Merge pull request #14746 from hashicorp/jbardin/s3-consistency

store and verify s3 remote state checksum to avoid consistency issues.
This commit is contained in:
James Bardin 2017-05-24 16:47:57 -04:00 committed by GitHub
commit ef1d53934c
2 changed files with 314 additions and 5 deletions

View File

@ -2,10 +2,14 @@ package s3
import ( import (
"bytes" "bytes"
"crypto/md5"
"encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
@ -17,6 +21,9 @@ import (
"github.com/hashicorp/terraform/state/remote" "github.com/hashicorp/terraform/state/remote"
) )
// Store the last saved serial in dynamo with this suffix for consistency checks.
const stateIDSuffix = "-md5"
type RemoteClient struct { type RemoteClient struct {
s3Client *s3.S3 s3Client *s3.S3
dynClient *dynamodb.DynamoDB dynClient *dynamodb.DynamoDB
@ -28,7 +35,55 @@ type RemoteClient struct {
lockTable string lockTable string
} }
func (c *RemoteClient) Get() (*remote.Payload, error) { var (
// The amount of time we will retry a state waiting for it to match the
// expected checksum.
consistencyRetryTimeout = 10 * time.Second
// delay when polling the state
consistencyRetryPollInterval = 2 * time.Second
)
// test hook called when checksums don't match
var testChecksumHook func()
func (c *RemoteClient) Get() (payload *remote.Payload, err error) {
deadline := time.Now().Add(consistencyRetryTimeout)
// If we have a checksum, and the returned payload doesn't match, we retry
// up until deadline.
for {
payload, err = c.get()
if err != nil {
return nil, err
}
// verify that this state is what we expect
if expected, err := c.getMD5(); err != nil {
log.Printf("[WARNING] failed to fetch state md5: %s", err)
} else if len(expected) > 0 && !bytes.Equal(expected, payload.MD5) {
log.Printf("[WARNING] state md5 mismatch: expected '%x', got '%x'", expected, payload.MD5)
if testChecksumHook != nil {
testChecksumHook()
}
if time.Now().Before(deadline) {
time.Sleep(consistencyRetryPollInterval)
log.Println("[INFO] retrying S3 RemoteClient.Get...")
continue
}
return nil, fmt.Errorf(errBadChecksumFmt, payload.MD5)
}
break
}
return payload, err
}
func (c *RemoteClient) get() (*remote.Payload, error) {
output, err := c.s3Client.GetObject(&s3.GetObjectInput{ output, err := c.s3Client.GetObject(&s3.GetObjectInput{
Bucket: &c.bucketName, Bucket: &c.bucketName,
Key: &c.path, Key: &c.path,
@ -53,8 +108,10 @@ func (c *RemoteClient) Get() (*remote.Payload, error) {
return nil, fmt.Errorf("Failed to read remote state: %s", err) return nil, fmt.Errorf("Failed to read remote state: %s", err)
} }
sum := md5.Sum(buf.Bytes())
payload := &remote.Payload{ payload := &remote.Payload{
Data: buf.Bytes(), Data: buf.Bytes(),
MD5: sum[:],
} }
// If there was no data, then return nil // If there was no data, then return nil
@ -92,11 +149,20 @@ func (c *RemoteClient) Put(data []byte) error {
log.Printf("[DEBUG] Uploading remote state to S3: %#v", i) log.Printf("[DEBUG] Uploading remote state to S3: %#v", i)
if _, err := c.s3Client.PutObject(i); err == nil { _, err := c.s3Client.PutObject(i)
return nil if err != nil {
} else {
return fmt.Errorf("Failed to upload state: %v", err) return fmt.Errorf("Failed to upload state: %v", err)
} }
sum := md5.Sum(data)
if err := c.putMD5(sum[:]); err != nil {
// if this errors out, we unfortunately have to error out altogether,
// since the next Get will inevitably fail.
return fmt.Errorf("failed to store state MD5: %s", err)
}
return nil
} }
func (c *RemoteClient) Delete() error { func (c *RemoteClient) Delete() error {
@ -105,9 +171,17 @@ func (c *RemoteClient) Delete() error {
Key: &c.path, Key: &c.path,
}) })
if err != nil {
return err return err
} }
if err := c.deleteMD5(); err != nil {
log.Printf("error deleting state md5: %s", err)
}
return nil
}
func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) { func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
if c.lockTable == "" { if c.lockTable == "" {
return "", nil return "", nil
@ -146,9 +220,84 @@ func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
} }
return "", lockErr return "", lockErr
} }
return info.ID, nil return info.ID, nil
} }
func (c *RemoteClient) getMD5() ([]byte, error) {
if c.lockTable == "" {
return nil, nil
}
getParams := &dynamodb.GetItemInput{
Key: map[string]*dynamodb.AttributeValue{
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
},
ProjectionExpression: aws.String("LockID, Digest"),
TableName: aws.String(c.lockTable),
}
resp, err := c.dynClient.GetItem(getParams)
if err != nil {
return nil, err
}
var val string
if v, ok := resp.Item["Digest"]; ok && v.S != nil {
val = *v.S
}
sum, err := hex.DecodeString(val)
if err != nil || len(sum) != md5.Size {
return nil, errors.New("invalid md5")
}
return sum, nil
}
// store the hash of the state to that clients can check for stale state files.
func (c *RemoteClient) putMD5(sum []byte) error {
if c.lockTable == "" {
return nil
}
if len(sum) != md5.Size {
return errors.New("invalid payload md5")
}
putParams := &dynamodb.PutItemInput{
Item: map[string]*dynamodb.AttributeValue{
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
"Digest": {S: aws.String(hex.EncodeToString(sum))},
},
TableName: aws.String(c.lockTable),
}
_, err := c.dynClient.PutItem(putParams)
if err != nil {
log.Printf("[WARNING] failed to record state serial in dynamodb: %s", err)
}
return nil
}
// remove the hash value for a deleted state
func (c *RemoteClient) deleteMD5() error {
if c.lockTable == "" {
return nil
}
params := &dynamodb.DeleteItemInput{
Key: map[string]*dynamodb.AttributeValue{
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
},
TableName: aws.String(c.lockTable),
}
if _, err := c.dynClient.DeleteItem(params); err != nil {
return err
}
return nil
}
func (c *RemoteClient) getLockInfo() (*state.LockInfo, error) { func (c *RemoteClient) getLockInfo() (*state.LockInfo, error) {
getParams := &dynamodb.GetItemInput{ getParams := &dynamodb.GetItemInput{
Key: map[string]*dynamodb.AttributeValue{ Key: map[string]*dynamodb.AttributeValue{
@ -217,3 +366,12 @@ func (c *RemoteClient) Unlock(id string) error {
func (c *RemoteClient) lockPath() string { func (c *RemoteClient) lockPath() string {
return fmt.Sprintf("%s/%s", c.bucketName, c.path) return fmt.Sprintf("%s/%s", c.bucketName, c.path)
} }
const errBadChecksumFmt = `state data in S3 does not have the expected content.
This may be caused by unusually long delays in S3 processing a previous state
update. Please wait for a minute or two and try again. If this problem
persists, and neither S3 nor DynamoDB are experiencing an outage, you may need
to manually verify the remote state and update the Digest value stored in the
DynamoDB table to the following value: %x
`

View File

@ -1,13 +1,17 @@
package s3 package s3
import ( import (
"bytes"
"crypto/md5"
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
"github.com/hashicorp/terraform/backend" "github.com/hashicorp/terraform/backend"
"github.com/hashicorp/terraform/state" "github.com/hashicorp/terraform/state"
"github.com/hashicorp/terraform/state/remote" "github.com/hashicorp/terraform/state/remote"
"github.com/hashicorp/terraform/terraform"
) )
func TestRemoteClient_impl(t *testing.T) { func TestRemoteClient_impl(t *testing.T) {
@ -150,3 +154,150 @@ func TestForceUnlock(t *testing.T) {
t.Fatal("failed to force-unlock named state") t.Fatal("failed to force-unlock named state")
} }
} }
func TestRemoteClient_clientMD5(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"
b := backend.TestBackendConfig(t, New(), map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"lock_table": bucketName,
}).(*Backend)
createDynamoDBTable(t, b.dynClient, bucketName)
defer deleteDynamoDBTable(t, b.dynClient, bucketName)
s, err := b.State(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client := s.(*remote.State).Client.(*RemoteClient)
sum := md5.Sum([]byte("test"))
if err := client.putMD5(sum[:]); err != nil {
t.Fatal(err)
}
getSum, err := client.getMD5()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(getSum, sum[:]) {
t.Fatalf("getMD5 returned the wrong checksum: expected %x, got %x", sum[:], getSum)
}
if err := client.deleteMD5(); err != nil {
t.Fatal(err)
}
if getSum, err := client.getMD5(); err == nil {
t.Fatalf("expected getMD5 error, got none. checksum: %x", getSum)
}
}
// verify that a client won't return a state with an incorrect checksum.
func TestRemoteClient_stateChecksum(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"
b1 := backend.TestBackendConfig(t, New(), map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"lock_table": bucketName,
}).(*Backend)
createS3Bucket(t, b1.s3Client, bucketName)
defer deleteS3Bucket(t, b1.s3Client, bucketName)
createDynamoDBTable(t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(t, b1.dynClient, bucketName)
s1, err := b1.State(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client1 := s1.(*remote.State).Client
// create a old and new state version to persist
s := state.TestStateInitial()
var oldState bytes.Buffer
if err := terraform.WriteState(s, &oldState); err != nil {
t.Fatal(err)
}
s.Serial++
var newState bytes.Buffer
if err := terraform.WriteState(s, &newState); err != nil {
t.Fatal(err)
}
// Use b2 without a lock_table to bypass the lock table to write the state directly.
// client2 will write the "incorrect" state, simulating s3 eventually consistency delays
b2 := backend.TestBackendConfig(t, New(), map[string]interface{}{
"bucket": bucketName,
"key": keyName,
}).(*Backend)
s2, err := b2.State(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client2 := s2.(*remote.State).Client
// write the new state through client2 so that there is no checksum yet
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
// verify that we can pull a state without a checksum
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
// write the new state back with its checksum
if err := client1.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
// put the old state in place of the new, without updating the checksum
if err := client2.Put(oldState.Bytes()); err != nil {
t.Fatal(err)
}
// remove the timeouts so we can fail immediately
origTimeout := consistencyRetryTimeout
origInterval := consistencyRetryPollInterval
defer func() {
consistencyRetryTimeout = origTimeout
consistencyRetryPollInterval = origInterval
}()
consistencyRetryTimeout = 0
consistencyRetryPollInterval = 0
// fetching the state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}
// update the state with the correct one after we Get again
testChecksumHook = func() {
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
testChecksumHook = nil
}
consistencyRetryTimeout = origTimeout
// this final Get will fail to fail the checksum verification, the above
// callback will update the state with the correct version, and Get should
// retry automatically.
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
}