diff --git a/states/remote/remote_test.go b/states/remote/remote_test.go index bbd73ae5f..1e8edc8b8 100644 --- a/states/remote/remote_test.go +++ b/states/remote/remote_test.go @@ -69,7 +69,6 @@ func (c nilClient) Delete() error { return nil } type mockClient struct { current []byte log []mockClientRequest - force bool } type mockClientRequest struct { @@ -90,11 +89,7 @@ func (c *mockClient) Get() (*Payload, error) { } func (c *mockClient) Put(data []byte) error { - if c.force { - c.appendLog("Force Put", data) - } else { - c.appendLog("Put", data) - } + c.appendLog("Put", data) c.current = data return nil } @@ -105,11 +100,6 @@ func (c *mockClient) Delete() error { return nil } -// Implements remote.ClientForcePusher -func (c *mockClient) EnableForcePush() { - c.force = true -} - func (c *mockClient) appendLog(method string, content []byte) { // For easier test assertions, we actually log the result of decoding // the content JSON rather than the raw bytes. Callers are in principle @@ -126,3 +116,54 @@ func (c *mockClient) appendLog(method string, content []byte) { } c.log = append(c.log, mockClientRequest{method, contentVal}) } + +// mockClientForcePusher is like mockClient, but also implements +// EnableForcePush, allowing testing for this behavior +type mockClientForcePusher struct { + current []byte + force bool + log []mockClientRequest +} + +func (c *mockClientForcePusher) Get() (*Payload, error) { + c.appendLog("Get", c.current) + if c.current == nil { + return nil, nil + } + checksum := md5.Sum(c.current) + return &Payload{ + Data: c.current, + MD5: checksum[:], + }, nil +} + +func (c *mockClientForcePusher) Put(data []byte) error { + if c.force { + c.appendLog("Force Put", data) + } else { + c.appendLog("Put", data) + } + c.current = data + return nil +} + +// Implements remote.ClientForcePusher +func (c *mockClientForcePusher) EnableForcePush() { + c.force = true +} + +func (c *mockClientForcePusher) Delete() error { + c.appendLog("Delete", c.current) + c.current = nil + return nil +} +func (c *mockClientForcePusher) appendLog(method string, content []byte) { + var contentVal map[string]interface{} + if content != nil { + err := json.Unmarshal(content, &contentVal) + if err != nil { + panic(err) // should never happen because our tests control this input + } + } + c.log = append(c.log, mockClientRequest{method, contentVal}) +} diff --git a/states/remote/state.go b/states/remote/state.go index bd9494f53..f4abd3fc8 100644 --- a/states/remote/state.go +++ b/states/remote/state.go @@ -72,20 +72,21 @@ func (s *State) WriteStateForMigration(f *statefile.File, force bool) error { s.mu.Lock() defer s.mu.Unlock() - // `force` is passed down from the CLI flag and terminates here. Actual - // force pushing with the remote backend happens when Put()'ing the contents - // in the backend. If force is specified we skip verifications and hand the - // context off to the client to use when persitence operations actually take place. - c, isForcePusher := s.Client.(ClientForcePusher) - if force && isForcePusher { - c.EnableForcePush() - } else { + if !force { checkFile := statefile.New(s.state, s.lineage, s.serial) if err := statemgr.CheckValidImport(f, checkFile); err != nil { return err } } + // The remote backend needs to pass the `force` flag through to its client. + // For backends that support such operations, inform the client + // that a force push has been requested + c, isForcePusher := s.Client.(ClientForcePusher) + if force && isForcePusher { + c.EnableForcePush() + } + // We create a deep copy of the state here, because the caller also has // a reference to the given object and can potentially go on to mutate // it after we return, but we want the snapshot at this point in time. diff --git a/states/remote/state_test.go b/states/remote/state_test.go index 413c6e75e..fb50651b0 100644 --- a/states/remote/state_test.go +++ b/states/remote/state_test.go @@ -302,6 +302,158 @@ func TestWriteStateForMigration(t *testing.T) { }, } + testCases := []migrationTestCase{ + // Refreshing state before we run the test loop causes a GET + { + name: "refresh state", + stateFile: func(mgr *State) *statefile.File { + return mgr.StateForMigration() + }, + expectedRequest: mockClientRequest{ + Method: "Get", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "mock-lineage", + "serial": 3.0, + "terraform_version": "0.0.0", + "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, + "resources": []interface{}{}, + }, + }, + }, + { + name: "cannot import lesser serial without force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, mgr.lineage, 1) + }, + expectedError: "cannot import state with serial 1 over newer state with serial 3", + }, + { + name: "cannot import differing lineage without force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, "different-lineage", mgr.serial) + }, + expectedError: `cannot import state with lineage "different-lineage" over unrelated state with lineage "mock-lineage"`, + }, + { + name: "can import lesser serial with force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, mgr.lineage, 1) + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "mock-lineage", + "serial": 2.0, + "terraform_version": version.Version, + "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, + "resources": []interface{}{}, + }, + }, + force: true, + }, + { + name: "cannot import differing lineage without force", + stateFile: func(mgr *State) *statefile.File { + return statefile.New(mgr.state, "different-lineage", mgr.serial) + }, + expectedRequest: mockClientRequest{ + Method: "Put", + Content: map[string]interface{}{ + "version": 4.0, + "lineage": "different-lineage", + "serial": 3.0, + "terraform_version": version.Version, + "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, + "resources": []interface{}{}, + }, + }, + force: true, + }, + } + + // In normal use (during a Terraform operation) we always refresh and read + // before any writes would happen, so we'll mimic that here for realism. + // NB This causes a GET to be logged so the first item in the test cases + // must account for this + if err := mgr.RefreshState(); err != nil { + t.Fatalf("failed to RefreshState: %s", err) + } + + if err := mgr.WriteState(mgr.State()); err != nil { + t.Fatalf("failed to write initial state: %s", err) + } + + // Our client is a mockClient which has a log we + // use to check that operations generate expected requests + mockClient := mgr.Client.(*mockClient) + + // logIdx tracks the current index of the log separate from + // the loop iteration so we can check operations that don't + // cause any requests to be generated + logIdx := 0 + + for _, tc := range testCases { + sf := tc.stateFile(mgr) + err := mgr.WriteStateForMigration(sf, tc.force) + shouldError := tc.expectedError != "" + + // If we are expecting and error check it and move on + if shouldError { + if err == nil { + t.Fatalf("test case %q should have failed with error %q", tc.name, tc.expectedError) + } else if err.Error() != tc.expectedError { + t.Fatalf("test case %q expected error %q but got %q", tc.name, tc.expectedError, err) + } + continue + } + + if err != nil { + t.Fatalf("test case %q failed: %v", tc.name, err) + } + + // At this point we should just do a normal write and persist + // as would happen from the CLI + mgr.WriteState(mgr.State()) + mgr.PersistState() + + if logIdx >= len(mockClient.log) { + t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) + } + loggedRequest := mockClient.log[logIdx] + logIdx++ + if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 { + t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff) + } + } + + logCnt := len(mockClient.log) + if logIdx != logCnt { + log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx) + } +} + +// This test runs the same test cases as above, but with +// a client that implements EnableForcePush -- this allows +// us to test that -force continues to work for backends without +// this interface, but that this interface works for those that do. +func TestWriteStateForMigrationWithForcePushClient(t *testing.T) { + mgr := &State{ + Client: &mockClientForcePusher{ + current: []byte(` + { + "version": 4, + "lineage": "mock-lineage", + "serial": 3, + "terraform_version":"0.0.0", + "outputs": {"foo": {"value":"bar", "type": "string"}}, + "resources": [] + } + `), + }, + } + testCases := []migrationTestCase{ // Refreshing state before we run the test loop causes a GET { @@ -385,9 +537,9 @@ func TestWriteStateForMigration(t *testing.T) { t.Fatalf("failed to write initial state: %s", err) } - // Our client is a mockClient which has a log we + // Our client is a mockClientForcePusher which has a log we // use to check that operations generate expected requests - mockClient := mgr.Client.(*mockClient) + mockClient := mgr.Client.(*mockClientForcePusher) if mockClient.force { t.Fatalf("client should not default to force")