diff --git a/backend/remote/backend_state.go b/backend/remote/backend_state.go index 5089c4075..29dc8550e 100644 --- a/backend/remote/backend_state.go +++ b/backend/remote/backend_state.go @@ -20,6 +20,7 @@ type remoteClient struct { runID string stateUploadErr bool workspace *tfe.Workspace + forcePush bool } // Get the remote state. @@ -69,6 +70,7 @@ func (r *remoteClient) Put(state []byte) error { Serial: tfe.Int64(int64(stateFile.Serial)), MD5: tfe.String(fmt.Sprintf("%x", md5.Sum(state))), State: tfe.String(base64.StdEncoding.EncodeToString(state)), + Force: tfe.Bool(r.forcePush), } // If we have a run ID, make sure to add it to the options @@ -97,6 +99,12 @@ func (r *remoteClient) Delete() error { return nil } +// EnableForcePush to allow the remote client to overwrite state +// by implementing remote.ClientForcePusher +func (r *remoteClient) EnableForcePush() { + r.forcePush = true +} + // Lock the remote state. func (r *remoteClient) Lock(info *state.LockInfo) (string, error) { ctx := context.Background() diff --git a/state/remote/remote.go b/state/remote/remote.go index 58f2578d2..f9947bac7 100644 --- a/state/remote/remote.go +++ b/state/remote/remote.go @@ -15,6 +15,14 @@ type Client interface { Delete() error } +// ClientForcePusher is an optional interface that allows a remote +// state to force push by managing a flag on the client that is +// toggled on by a call to EnableForcePush. +type ClientForcePusher interface { + Client + EnableForcePush() +} + // ClientLocker is an optional interface that allows a remote state // backend to enable state lock/unlock. type ClientLocker interface { diff --git a/state/remote/remote_test.go b/state/remote/remote_test.go index 77cdd5b4f..b88766c85 100644 --- a/state/remote/remote_test.go +++ b/state/remote/remote_test.go @@ -69,6 +69,7 @@ func (c nilClient) Delete() error { return nil } type mockClient struct { current []byte log []mockClientRequest + force bool } type mockClientRequest struct { @@ -89,7 +90,11 @@ func (c *mockClient) Get() (*Payload, error) { } func (c *mockClient) Put(data []byte) error { - c.appendLog("Put", data) + if c.force { + c.appendLog("Force Put", data) + } else { + c.appendLog("Put", data) + } c.current = data return nil } @@ -100,6 +105,11 @@ func (c *mockClient) Delete() error { return nil } +// Implements remote.ClientForcePusher +func (c *mockClient) EnableForcePush() { + c.force = true +} + func (c *mockClient) appendLog(method string, content []byte) { // For easier test assertions, we actually log the result of decoding // the content JSON rather than the raw bytes. Callers are in principle diff --git a/state/remote/state.go b/state/remote/state.go index 6d701da80..3069aeb89 100644 --- a/state/remote/state.go +++ b/state/remote/state.go @@ -21,10 +21,19 @@ type State struct { Client Client - lineage string - serial uint64 - state, readState *states.State - disableLocks bool + // We track two pieces of meta data in addition to the state itself: + // + // lineage - the state's unique ID + // serial - the monotonic counter of "versions" of the state + // + // Both of these (along with state) have a sister field + // that represents the values read in from an existing source. + // All three of these values are used to determine if the new + // state has changed from an existing state we read in. + lineage, readLineage string + serial, readSerial uint64 + state, readState *states.State + disableLocks bool } var _ statemgr.Full = (*State)(nil) @@ -64,8 +73,15 @@ func (s *State) WriteStateForMigration(f *statefile.File, force bool) error { s.mu.Lock() defer s.mu.Unlock() - checkFile := statefile.New(s.state, s.lineage, s.serial) - if !force { + // `force` is passed down from the CLI flag and terminates here. Actual + // force pushing with the remote backend happens when Put()'ing the contents + // in the backend. If force is specified we skip verifications and hand the + // context off to the client to use when persitence operations actually take place. + c, isForcePusher := s.Client.(ClientForcePusher) + if force && isForcePusher { + c.EnableForcePush() + } else { + checkFile := statefile.New(s.state, s.lineage, s.serial) if err := statemgr.CheckValidImport(f, checkFile); err != nil { return err } @@ -113,7 +129,12 @@ func (s *State) refreshState() error { s.lineage = stateFile.Lineage s.serial = stateFile.Serial s.state = stateFile.State - s.readState = s.state.DeepCopy() // our states must be separate instances so we can track changes + + // Properties from the remote must be separate so we can + // track changes as lineage, serial and/or state are mutated + s.readLineage = stateFile.Lineage + s.readSerial = stateFile.Serial + s.readState = s.state.DeepCopy() return nil } @@ -123,8 +144,11 @@ func (s *State) PersistState() error { defer s.mu.Unlock() if s.readState != nil { - if statefile.StatesMarshalEqual(s.state, s.readState) { - // If the state hasn't changed at all then we have nothing to do. + lineageUnchanged := s.readLineage != "" && s.lineage == s.readLineage + serialUnchanged := s.readSerial != 0 && s.serial == s.readSerial + stateUnchanged := statefile.StatesMarshalEqual(s.state, s.readState) + if stateUnchanged && lineageUnchanged && serialUnchanged { + // If the state, lineage or serial haven't changed at all then we have nothing to do. return nil } s.serial++ @@ -161,7 +185,13 @@ func (s *State) PersistState() error { // After we've successfully persisted, what we just wrote is our new // reference state until someone calls RefreshState again. + // We've potentially overwritten (via force) the state, lineage + // and / or serial (and serial was incremented) so we copy over all + // three fields so everything matches the new state and a subsequent + // operation would correctly detect no changes to the lineage, serial or state. s.readState = s.state.DeepCopy() + s.readLineage = s.lineage + s.readSerial = s.serial return nil } diff --git a/state/remote/state_test.go b/state/remote/state_test.go index 949eefe11..413c6e75e 100644 --- a/state/remote/state_test.go +++ b/state/remote/state_test.go @@ -1,6 +1,7 @@ package remote import ( + "log" "sync" "testing" @@ -8,6 +9,7 @@ import ( "github.com/zclconf/go-cty/cty" "github.com/hashicorp/terraform/states" + "github.com/hashicorp/terraform/states/statefile" "github.com/hashicorp/terraform/states/statemgr" "github.com/hashicorp/terraform/version" ) @@ -41,12 +43,176 @@ func TestStateRace(t *testing.T) { wg.Wait() } +// testCase encapsulates a test state test +type testCase struct { + name string + // A function to mutate state and return a cleanup function + mutationFunc func(*State) (*states.State, func()) + // The expected request to have taken place + expectedRequest mockClientRequest + // Mark this case as not having a request + noRequest bool +} + +// isRequested ensures a test that is specified as not having +// a request doesn't have one by checking if a method exists +// on the expectedRequest. +func (tc testCase) isRequested(t *testing.T) bool { + hasMethod := tc.expectedRequest.Method != "" + if tc.noRequest && hasMethod { + t.Fatalf("expected no content for %q but got: %v", tc.name, tc.expectedRequest) + } + return !tc.noRequest +} + func TestStatePersist(t *testing.T) { + testCases := []testCase{ + // Refreshing state before we run the test loop causes a GET + { + name: "refresh state", + mutationFunc: func(mgr *State) (*states.State, func()) { + return mgr.State(), func() {} + }, + expectedRequest: mockClientRequest{ + Method: "Get", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "mock-lineage", + "serial": 1.0, // encoding/json decodes this as float64 by default + "terraform_version": "0.0.0", + "outputs": map[string]interface{}{}, + "resources": []interface{}{}, + }, + }, + }, + { + name: "change lineage", + mutationFunc: func(mgr *State) (*states.State, func()) { + originalLineage := mgr.lineage + mgr.lineage = "some-new-lineage" + return mgr.State(), func() { + mgr.lineage = originalLineage + } + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "some-new-lineage", + "serial": 2.0, // encoding/json decodes this as float64 by default + "terraform_version": version.Version, + "outputs": map[string]interface{}{}, + "resources": []interface{}{}, + }, + }, + }, + { + name: "change serial", + mutationFunc: func(mgr *State) (*states.State, func()) { + originalSerial := mgr.serial + mgr.serial++ + return mgr.State(), func() { + mgr.serial = originalSerial + } + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "mock-lineage", + "serial": 4.0, // encoding/json decodes this as float64 by default + "terraform_version": version.Version, + "outputs": map[string]interface{}{}, + "resources": []interface{}{}, + }, + }, + }, + { + name: "add output to state", + mutationFunc: func(mgr *State) (*states.State, func()) { + s := mgr.State() + s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false) + return s, func() {} + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "mock-lineage", + "serial": 3.0, // encoding/json decodes this as float64 by default + "terraform_version": version.Version, + "outputs": map[string]interface{}{ + "foo": map[string]interface{}{ + "type": "string", + "value": "bar", + }, + }, + "resources": []interface{}{}, + }, + }, + }, + { + name: "mutate state bar -> baz", + mutationFunc: func(mgr *State) (*states.State, func()) { + s := mgr.State() + s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false) + return s, func() {} + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "mock-lineage", + "serial": 4.0, // encoding/json decodes this as float64 by default + "terraform_version": version.Version, + "outputs": map[string]interface{}{ + "foo": map[string]interface{}{ + "type": "string", + "value": "baz", + }, + }, + "resources": []interface{}{}, + }, + }, + }, + { + name: "nothing changed", + mutationFunc: func(mgr *State) (*states.State, func()) { + s := mgr.State() + return s, func() {} + }, + noRequest: true, + }, + { + name: "reset serial (force push style)", + mutationFunc: func(mgr *State) (*states.State, func()) { + mgr.serial = 2 + return mgr.State(), func() {} + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, // encoding/json decodes this as float64 by default + "lineage": "mock-lineage", + "serial": 3.0, // encoding/json decodes this as float64 by default + "terraform_version": version.Version, + "outputs": map[string]interface{}{ + "foo": map[string]interface{}{ + "type": "string", + "value": "baz", + }, + }, + "resources": []interface{}{}, + }, + }, + }, + } + + // Initial setup of state just to give us a fixed starting point for our + // test assertions below, or else we'd need to deal with + // random lineage. mgr := &State{ Client: &mockClient{ - // Initial state just to give us a fixed starting point for our - // test assertions below, or else we'd need to deal with - // random lineage. current: []byte(` { "version": 4, @@ -62,94 +228,218 @@ func TestStatePersist(t *testing.T) { // In normal use (during a Terraform operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. + // NB This causes a GET to be logged so the first item in the test cases + // must account for this if err := mgr.RefreshState(); err != nil { t.Fatalf("failed to RefreshState: %s", err) } - s := mgr.State() - s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false) - if err := mgr.WriteState(s); err != nil { - t.Fatalf("failed to WriteState: %s", err) - } - if err := mgr.PersistState(); err != nil { - t.Fatalf("failed to PersistState: %s", err) - } + // Our client is a mockClient which has a log we + // use to check that operations generate expected requests + mockClient := mgr.Client.(*mockClient) - // Persisting the same state again should be a no-op: it doesn't fail, - // but it ought not to appear in the client's log either. - if err := mgr.WriteState(s); err != nil { - t.Fatalf("failed to WriteState: %s", err) - } - if err := mgr.PersistState(); err != nil { - t.Fatalf("failed to PersistState: %s", err) - } + // logIdx tracks the current index of the log separate from + // the loop iteration so we can check operations that don't + // cause any requests to be generated + logIdx := 0 - // ...but if we _do_ change something in the state then we should see - // it re-persist. - s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false) - if err := mgr.WriteState(s); err != nil { - t.Fatalf("failed to WriteState: %s", err) + // Run tests in order. + for _, tc := range testCases { + s, cleanup := tc.mutationFunc(mgr) + + if err := mgr.WriteState(s); err != nil { + t.Fatalf("failed to WriteState for %q: %s", tc.name, err) + } + if err := mgr.PersistState(); err != nil { + t.Fatalf("failed to PersistState for %q: %s", tc.name, err) + } + + if tc.isRequested(t) { + // Get captured request from the mock client log + // based on the index of the current test + if logIdx >= len(mockClient.log) { + t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) + } + loggedRequest := mockClient.log[logIdx] + logIdx++ + if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 { + t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff) + } + } + cleanup() } - if err := mgr.PersistState(); err != nil { - t.Fatalf("failed to PersistState: %s", err) - } - - got := mgr.Client.(*mockClient).log - want := []mockClientRequest{ - // The initial fetch from mgr.RefreshState above. - { - Method: "Get", - Content: map[string]interface{}{ - "version": 4.0, // encoding/json decodes this as float64 by default - "lineage": "mock-lineage", - "serial": 1.0, // encoding/json decodes this as float64 by default - "terraform_version": "0.0.0", - "outputs": map[string]interface{}{}, - "resources": []interface{}{}, - }, - }, - - // First call to PersistState, with output "foo" set to "bar". - { - Method: "Put", - Content: map[string]interface{}{ - "version": 4.0, - "lineage": "mock-lineage", - "serial": 2.0, // serial increases because the outputs changed - "terraform_version": version.Version, - "outputs": map[string]interface{}{ - "foo": map[string]interface{}{ - "type": "string", - "value": "bar", - }, - }, - "resources": []interface{}{}, - }, - }, - - // Second call to PersistState generates no client requests, because - // nothing changed in the state itself. - - // Third call to PersistState, with the "foo" output value updated - // to "baz". - { - Method: "Put", - Content: map[string]interface{}{ - "version": 4.0, - "lineage": "mock-lineage", - "serial": 3.0, // serial increases because the outputs changed - "terraform_version": version.Version, - "outputs": map[string]interface{}{ - "foo": map[string]interface{}{ - "type": "string", - "value": "baz", - }, - }, - "resources": []interface{}{}, - }, - }, - } - if diff := cmp.Diff(want, got); len(diff) > 0 { - t.Errorf("incorrect client requests\n%s", diff) + logCnt := len(mockClient.log) + if logIdx != logCnt { + log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx) + } +} + +type migrationTestCase struct { + name string + // A function to generate a statefile + stateFile func(*State) *statefile.File + // The expected request to have taken place + expectedRequest mockClientRequest + // Mark this case as not having a request + expectedError string + // force flag passed to client + force bool +} + +func TestWriteStateForMigration(t *testing.T) { + mgr := &State{ + Client: &mockClient{ + current: []byte(` + { + "version": 4, + "lineage": "mock-lineage", + "serial": 3, + "terraform_version":"0.0.0", + "outputs": {"foo": {"value":"bar", "type": "string"}}, + "resources": [] + } + `), + }, + } + + testCases := []migrationTestCase{ + // Refreshing state before we run the test loop causes a GET + { + name: "refresh state", + stateFile: func(mgr *State) *statefile.File { + return mgr.StateForMigration() + }, + expectedRequest: mockClientRequest{ + Method: "Get", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "mock-lineage", + "serial": 3.0, + "terraform_version": "0.0.0", + "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, + "resources": []interface{}{}, + }, + }, + }, + { + name: "cannot import lesser serial without force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, mgr.lineage, 1) + }, + expectedError: "cannot import state with serial 1 over newer state with serial 3", + }, + { + name: "cannot import differing lineage without force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, "different-lineage", mgr.serial) + }, + expectedError: `cannot import state with lineage "different-lineage" over unrelated state with lineage "mock-lineage"`, + }, + { + name: "can import lesser serial with force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, mgr.lineage, 1) + }, + expectedRequest: mockClientRequest{ + Method: "Force Put", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "mock-lineage", + "serial": 2.0, + "terraform_version": version.Version, + "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, + "resources": []interface{}{}, + }, + }, + force: true, + }, + { + name: "cannot import differing lineage without force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, "different-lineage", mgr.serial) + }, + expectedRequest: mockClientRequest{ + Method: "Force Put", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "different-lineage", + "serial": 3.0, + "terraform_version": version.Version, + "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, + "resources": []interface{}{}, + }, + }, + force: true, + }, + } + + // In normal use (during a Terraform operation) we always refresh and read + // before any writes would happen, so we'll mimic that here for realism. + // NB This causes a GET to be logged so the first item in the test cases + // must account for this + if err := mgr.RefreshState(); err != nil { + t.Fatalf("failed to RefreshState: %s", err) + } + + if err := mgr.WriteState(mgr.State()); err != nil { + t.Fatalf("failed to write initial state: %s", err) + } + + // Our client is a mockClient which has a log we + // use to check that operations generate expected requests + mockClient := mgr.Client.(*mockClient) + + if mockClient.force { + t.Fatalf("client should not default to force") + } + + // logIdx tracks the current index of the log separate from + // the loop iteration so we can check operations that don't + // cause any requests to be generated + logIdx := 0 + + for _, tc := range testCases { + // Always reset client to not be force pushing + mockClient.force = false + sf := tc.stateFile(mgr) + err := mgr.WriteStateForMigration(sf, tc.force) + shouldError := tc.expectedError != "" + + // If we are expecting and error check it and move on + if shouldError { + if err == nil { + t.Fatalf("test case %q should have failed with error %q", tc.name, tc.expectedError) + } else if err.Error() != tc.expectedError { + t.Fatalf("test case %q expected error %q but got %q", tc.name, tc.expectedError, err) + } + continue + } + + if err != nil { + t.Fatalf("test case %q failed: %v", tc.name, err) + } + + if tc.force && !mockClient.force { + t.Fatalf("test case %q should have enabled force push", tc.name) + } + + // At this point we should just do a normal write and persist + // as would happen from the CLI + mgr.WriteState(mgr.State()) + mgr.PersistState() + + if logIdx >= len(mockClient.log) { + t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) + } + loggedRequest := mockClient.log[logIdx] + logIdx++ + if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 { + t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff) + } + } + + logCnt := len(mockClient.log) + if logIdx != logCnt { + log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx) } }