diff --git a/state/remote/atlas.go b/state/remote/atlas.go
index 9a93730f3..f33f407ce 100644
--- a/state/remote/atlas.go
+++ b/state/remote/atlas.go
@@ -6,6 +6,7 @@ import (
 	"encoding/base64"
 	"fmt"
 	"io"
+	"log"
 	"net/http"
 	"net/url"
 	"os"
@@ -13,6 +14,7 @@ import (
 	"strings"
 
 	"github.com/hashicorp/go-cleanhttp"
+	"github.com/hashicorp/terraform/terraform"
 )
 
 const (
@@ -75,6 +77,9 @@ type AtlasClient struct {
 	Name        string
 	AccessToken string
 	RunId       string
+	HTTPClient  *http.Client
+
+	conflictHandlingAttempted bool
 }
 
 func (c *AtlasClient) Get() (*Payload, error) {
@@ -85,7 +90,7 @@ func (c *AtlasClient) Get() (*Payload, error) {
 	}
 
 	// Request the url
-	client := cleanhttp.DefaultClient()
+	client := c.http()
 	resp, err := client.Do(req)
 	if err != nil {
 		return nil, err
@@ -164,7 +169,7 @@ func (c *AtlasClient) Put(state []byte) error {
 	req.ContentLength = int64(len(state))
 
 	// Make the request
-	client := cleanhttp.DefaultClient()
+	client := c.http()
 	resp, err := client.Do(req)
 	if err != nil {
 		return fmt.Errorf("Failed to upload state: %v", err)
@@ -175,6 +180,8 @@ func (c *AtlasClient) Put(state []byte) error {
 	switch resp.StatusCode {
 	case http.StatusOK:
 		return nil
+	case http.StatusConflict:
+		return c.handleConflict(c.readBody(resp.Body), state)
 	default:
 		return fmt.Errorf(
 			"HTTP error: %d\n\nBody: %s",
@@ -190,7 +197,7 @@ func (c *AtlasClient) Delete() error {
 	}
 
 	// Make the request
-	client := cleanhttp.DefaultClient()
+	client := c.http()
 	resp, err := client.Do(req)
 	if err != nil {
 		return fmt.Errorf("Failed to delete state: %v", err)
@@ -241,3 +248,90 @@ func (c *AtlasClient) url() *url.URL {
 		RawQuery: values.Encode(),
 	}
 }
+
+// http returns the caller-supplied HTTPClient when one is set and otherwise
+// a fresh client from cleanhttp.DefaultClient.
+func (c *AtlasClient) http() *http.Client {
+	if c.HTTPClient != nil {
+		return c.HTTPClient
+	}
+	return cleanhttp.DefaultClient()
+}
+
+// Atlas returns an HTTP 409 - Conflict if the pushed state reports the same
+// Serial number but the checksum of the raw content differs. This can
+// sometimes happen when Terraform changes state representation internally
+// between versions in a way that's semantically neutral but affects the JSON
+// output and therefore the checksum.
+//
+// Here we detect and handle this situation by ticking the serial and retrying
+// iff for the previous state and the proposed state:
+//
+// * the serials match
+// * the parsed states are Equal (semantically equivalent)
+//
+// In other words, in this situation Terraform can override Atlas's detected
+// conflict by asserting that the state it is pushing is indeed correct.
+//
+// At most one resolution attempt is made per client: the retried Put comes
+// back here if Atlas still answers 409, and the conflict is then returned.
+func (c *AtlasClient) handleConflict(msg string, state []byte) error {
+	log.Printf("[DEBUG] Handling Atlas conflict response: %s", msg)
+
+	conflictErr := fmt.Errorf(
+		"Atlas detected a remote state conflict.\n\nMessage: %s", msg)
+
+	if c.conflictHandlingAttempted {
+		log.Printf("[DEBUG] Already attempted conflict resolution; returning conflict.")
+		return conflictErr
+	}
+	c.conflictHandlingAttempted = true
+	log.Printf("[DEBUG] Atlas reported conflict, checking for equivalent states.")
+
+	payload, err := c.Get()
+	if err != nil {
+		return conflictHandlingError(err)
+	}
+	// Get's body is not part of this patch; guard against a nil payload
+	// returned without an error so that payload.Data cannot panic.
+	if payload == nil {
+		return conflictHandlingError(
+			fmt.Errorf("no remote state available to compare against"))
+	}
+
+	currentState, err := terraform.ReadState(bytes.NewReader(payload.Data))
+	if err != nil {
+		return conflictHandlingError(err)
+	}
+
+	proposedState, err := terraform.ReadState(bytes.NewReader(state))
+	if err != nil {
+		return conflictHandlingError(err)
+	}
+
+	if !statesAreEquivalent(currentState, proposedState) {
+		log.Printf("[DEBUG] States are not equivalent, returning conflict.")
+		return conflictErr
+	}
+
+	log.Printf("[DEBUG] States are equivalent, incrementing serial and retrying.")
+	proposedState.Serial++
+	var buf bytes.Buffer
+	if err := terraform.WriteState(proposedState, &buf); err != nil {
+		return conflictHandlingError(err)
+	}
+	return c.Put(buf.Bytes())
+}
+
+// conflictHandlingError wraps an error hit while attempting to resolve an
+// Atlas conflict response.
+func conflictHandlingError(err error) error {
+	return fmt.Errorf(
+		"Error while handling a conflict response from Atlas: %s", err)
+}
+
+// statesAreEquivalent reports whether both states carry the same serial and
+// compare Equal.
+func statesAreEquivalent(current, proposed *terraform.State) bool {
+	return current.Serial == proposed.Serial && current.Equal(proposed)
+}
diff --git a/state/remote/atlas_test.go b/state/remote/atlas_test.go
index 202e15dad..ae7ee8a1b 100644
--- a/state/remote/atlas_test.go
+++ b/state/remote/atlas_test.go
@@ -1,9 +1,15 @@
 package remote
 
 import (
+	"bytes"
+	"crypto/md5"
 	"net/http"
+	"net/http/httptest"
 	"os"
 	"testing"
+	"time"
+
+	"github.com/hashicorp/terraform/terraform"
 )
 
 func TestAtlasClient_impl(t *testing.T) {
@@ -30,3 +36,286 @@ func TestAtlasClient(t *testing.T) {
 
 	testClient(t, client)
 }
+
+// A state that is re-serialized differently but is semantically equal to the
+// stored one must be accepted after the client's one-time conflict handling.
+func TestAtlasClient_ReportedConflictEqualStates(t *testing.T) {
+	fakeAtlas := newFakeAtlas(t, testStateModuleOrderChange)
+	srv := fakeAtlas.Server()
+	defer srv.Close()
+	client, err := atlasFactory(map[string]string{
+		"access_token": "sometoken",
+		"name":         "someuser/some-test-remote-state",
+		"address":      srv.URL,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	state, err := terraform.ReadState(bytes.NewReader(testStateModuleOrderChange))
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	var stateJson bytes.Buffer
+	if err := terraform.WriteState(state, &stateJson); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if err := client.Put(stateJson.Bytes()); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+}
+
+// A Put that the fake server does not flag as a conflict must succeed; any
+// conflict response fails the test.
+func TestAtlasClient_NoConflict(t *testing.T) {
+	fakeAtlas := newFakeAtlas(t, testStateSimple)
+	srv := fakeAtlas.Server()
+	defer srv.Close()
+	client, err := atlasFactory(map[string]string{
+		"access_token": "sometoken",
+		"name":         "someuser/some-test-remote-state",
+		"address":      srv.URL,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	state, err := terraform.ReadState(bytes.NewReader(testStateSimple))
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	fakeAtlas.NoConflictAllowed(true)
+
+	var stateJson bytes.Buffer
+	if err := terraform.WriteState(state, &stateJson); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if err := client.Put(stateJson.Bytes()); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+}
+
+// Changed content pushed with an unchanged serial must surface as an error.
+func TestAtlasClient_LegitimateConflict(t *testing.T) {
+	fakeAtlas := newFakeAtlas(t, testStateSimple)
+	srv := fakeAtlas.Server()
+	defer srv.Close()
+	client, err := atlasFactory(map[string]string{
+		"access_token": "sometoken",
+		"name":         "someuser/some-test-remote-state",
+		"address":      srv.URL,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	state, err := terraform.ReadState(bytes.NewReader(testStateSimple))
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Changing the state but not the serial. Should generate a conflict.
+	state.RootModule().Outputs["drift"] = "happens"
+
+	var stateJson bytes.Buffer
+	if err := terraform.WriteState(state, &stateJson); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	if err := client.Put(stateJson.Bytes()); err == nil {
+		t.Fatal("Expected error from state conflict, got none.")
+	}
+}
+
+// When every Put conflicts, the client must return an error rather than retry
+// indefinitely.
+func TestAtlasClient_UnresolvableConflict(t *testing.T) {
+	fakeAtlas := newFakeAtlas(t, testStateSimple)
+
+	// Something unexpected causes Atlas to conflict in a way that we can't fix.
+	fakeAtlas.AlwaysConflict(true)
+
+	srv := fakeAtlas.Server()
+	defer srv.Close()
+	client, err := atlasFactory(map[string]string{
+		"access_token": "sometoken",
+		"name":         "someuser/some-test-remote-state",
+		"address":      srv.URL,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	state, err := terraform.ReadState(bytes.NewReader(testStateSimple))
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	var stateJson bytes.Buffer
+	if err := terraform.WriteState(state, &stateJson); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	doneCh := make(chan struct{})
+	go func() {
+		defer close(doneCh)
+		// Report with t.Error: FailNow (t.Fatal) must only be called from
+		// the goroutine running the test function.
+		if err := client.Put(stateJson.Bytes()); err == nil {
+			t.Error("Expected error from state conflict, got none.")
+		}
+	}()
+
+	select {
+	case <-doneCh:
+		// OK
+	case <-time.After(50 * time.Millisecond):
+		t.Fatalf("Timed out after 50ms, probably because retrying infinitely.")
+	}
+}
+
+// Stub Atlas HTTP API for a given state JSON string; does checksum-based
+// conflict detection equivalent to Atlas's.
+type fakeAtlas struct {
+	state []byte
+	t     *testing.T
+
+	// Used to test that we only do the special conflict handling retry once.
+	alwaysConflict bool
+
+	// Used to fail the test if a conflict happens.
+	noConflictAllowed bool
+}
+
+// newFakeAtlas returns a fakeAtlas seeded with the given raw state.
+func newFakeAtlas(t *testing.T, state []byte) *fakeAtlas {
+	return &fakeAtlas{
+		state: state,
+		t:     t,
+	}
+}
+
+// Server starts an httptest server backed by f.handler; callers close it.
+func (f *fakeAtlas) Server() *httptest.Server {
+	return httptest.NewServer(http.HandlerFunc(f.handler))
+}
+
+// CurrentState parses the stored state, failing the test on a parse error.
+func (f *fakeAtlas) CurrentState() *terraform.State {
+	currentState, err := terraform.ReadState(bytes.NewReader(f.state))
+	if err != nil {
+		f.t.Fatalf("err: %s", err)
+	}
+	return currentState
+}
+
+// CurrentSerial returns the serial of the stored state.
+func (f *fakeAtlas) CurrentSerial() int64 {
+	return f.CurrentState().Serial
+}
+
+// CurrentSum returns the MD5 checksum of the stored raw state.
+func (f *fakeAtlas) CurrentSum() [md5.Size]byte {
+	return md5.Sum(f.state)
+}
+
+// AlwaysConflict makes every PUT answer 409 when b is true.
+func (f *fakeAtlas) AlwaysConflict(b bool) {
+	f.alwaysConflict = b
+}
+
+// NoConflictAllowed marks any conflict as a test failure when b is true.
+func (f *fakeAtlas) NoConflictAllowed(b bool) {
+	f.noConflictAllowed = b
+}
+
+// handler serves GET by returning the stored state and PUT by applying the
+// serial/checksum conflict check, storing the upload when it is accepted.
+func (f *fakeAtlas) handler(resp http.ResponseWriter, req *http.Request) {
+	// It runs on the server's goroutines, not the test goroutine, so its own
+	// failures are reported with Error/Errorf rather than Fatal/Fatalf.
+	switch req.Method {
+	case "GET":
+		// Respond with the current stored state.
+		resp.Header().Set("Content-Type", "application/json")
+		resp.Write(f.state)
+	case "PUT":
+		var buf bytes.Buffer
+		buf.ReadFrom(req.Body)
+		// Keep the raw upload: it is both checksummed and, when accepted,
+		// stored verbatim. Parsing reads from a separate reader so the
+		// buffer is not drained before the bytes are stored.
+		body := buf.Bytes()
+		sum := md5.Sum(body)
+		state, err := terraform.ReadState(bytes.NewReader(body))
+		if err != nil {
+			f.t.Errorf("err: %s", err)
+			http.Error(resp, err.Error(), http.StatusBadRequest)
+			return
+		}
+		conflict := f.CurrentSerial() == state.Serial && f.CurrentSum() != sum
+		conflict = conflict || f.alwaysConflict
+		if conflict {
+			if f.noConflictAllowed {
+				f.t.Error("Got conflict when NoConflictAllowed was set.")
+			}
+			http.Error(resp, "Conflict", 409)
+		} else {
+			f.state = body
+			resp.WriteHeader(200)
+		}
+	}
+}
+
+// This is a tfstate file with the module order changed, which is a structural
+// but not a semantic difference. Terraform will sort these modules as it
+// loads the state.
+var testStateModuleOrderChange = []byte(
+	`{
+    "version": 1,
+    "serial": 1,
+    "modules": [
+        {
+            "path": [
+                "root",
+                "child2",
+                "grandchild"
+            ],
+            "outputs": {
+                "foo": "bar2"
+            },
+            "resources": null
+        },
+        {
+            "path": [
+                "root",
+                "child1",
+                "grandchild"
+            ],
+            "outputs": {
+                "foo": "bar1"
+            },
+            "resources": null
+        }
+    ]
+}
+`)
+
+// A minimal tfstate with a single root module.
+var testStateSimple = []byte(
+	`{
+    "version": 1,
+    "serial": 1,
+    "modules": [
+        {
+            "path": [
+                "root"
+            ],
+            "outputs": {
+                "foo": "bar"
+            },
+            "resources": null
+        }
+    ]
+}
+`)