diff --git a/command/workspace_command_test.go b/command/workspace_command_test.go index 3329e8565..576dca43b 100644 --- a/command/workspace_command_test.go +++ b/command/workspace_command_test.go @@ -186,6 +186,7 @@ func TestWorkspace_createWithState(t *testing.T) { } newState := envState.State() + originalState.Version = newState.Version // the round-trip through the state manager implicitly populates version if !originalState.Equal(newState) { t.Fatalf("states not equal\norig: %s\nnew: %s", originalState, newState) } diff --git a/state/inmem.go b/state/inmem.go index 4e031896c..36fa34147 100644 --- a/state/inmem.go +++ b/state/inmem.go @@ -32,8 +32,18 @@ func (s *InmemState) WriteState(state *terraform.State) error { s.mu.Lock() defer s.mu.Unlock() - state.IncrementSerialMaybe(s.state) + state = state.DeepCopy() + + if s.state != nil { + state.Serial = s.state.Serial + + if !s.state.MarshalEqual(state) { + state.Serial++ + } + } + s.state = state + return nil } diff --git a/state/local.go b/state/local.go index 5ce02ce57..a6d17653b 100644 --- a/state/local.go +++ b/state/local.go @@ -48,8 +48,8 @@ func (s *LocalState) SetState(state *terraform.State) { s.mu.Lock() defer s.mu.Unlock() - s.state = state - s.readState = state + s.state = state.DeepCopy() + s.readState = state.DeepCopy() } // StateReader impl. @@ -74,7 +74,14 @@ func (s *LocalState) WriteState(state *terraform.State) error { } defer s.stateFileOut.Sync() - s.state = state + s.state = state.DeepCopy() // don't want mutations before we actually get this written to disk + + if s.readState != nil && s.state != nil { + // We don't trust callers to properly manage serials. Instead, we assume + // that a WriteState is always for the next version after what was + // most recently read. + s.state.Serial = s.readState.Serial + } if _, err := s.stateFileOut.Seek(0, os.SEEK_SET); err != nil { return err @@ -88,8 +95,9 @@ func (s *LocalState) WriteState(state *terraform.State) error { return nil } - s.state.IncrementSerialMaybe(s.readState) - s.readState = s.state + if !s.state.MarshalEqual(s.readState) { + s.state.Serial++ + } if err := terraform.WriteState(s.state, s.stateFileOut); err != nil { return err @@ -147,7 +155,7 @@ func (s *LocalState) RefreshState() error { } s.state = state - s.readState = state + s.readState = s.state.DeepCopy() return nil } diff --git a/state/remote/state.go b/state/remote/state.go index dccbab18a..8e157101d 100644 --- a/state/remote/state.go +++ b/state/remote/state.go @@ -2,6 +2,7 @@ package remote import ( "bytes" + "fmt" "sync" "github.com/hashicorp/terraform/state" @@ -33,7 +34,28 @@ func (s *State) WriteState(state *terraform.State) error { s.mu.Lock() defer s.mu.Unlock() - s.state = state + if s.readState != nil && !state.SameLineage(s.readState) { + return fmt.Errorf("incompatible state lineage; given %s but want %s", state.Lineage, s.readState.Lineage) + } + + // We create a deep copy of the state here, because the caller also has + // a reference to the given object and can potentially go on to mutate + // it after we return, but we want the snapshot at this point in time. + s.state = state.DeepCopy() + + // Force our new state to have the same serial as our read state. We'll + // update this if PersistState is called later. (We don't require nor trust + // the caller to properly maintain serial for transient state objects since + // the rest of Terraform treats state as an openly mutable object.) + // + // If we have no read state then we assume we're either writing a new + // state for the first time or we're migrating a state from elsewhere, + // and in both cases we wish to retain the lineage and serial from + // the given state. + if s.readState != nil { + s.state.Serial = s.readState.Serial + } + return nil } @@ -58,7 +80,7 @@ func (s *State) RefreshState() error { } s.state = state - s.readState = state + s.readState = s.state.DeepCopy() // our states must be separate instances so we can track changes return nil } @@ -67,14 +89,28 @@ func (s *State) PersistState() error { s.mu.Lock() defer s.mu.Unlock() - s.state.IncrementSerialMaybe(s.readState) + if !s.state.MarshalEqual(s.readState) { + // Our new state does not marshal as byte-for-byte identical to + // the old, so we need to increment the serial. + // Note that in WriteState we force the serial to match that of + // s.readState, if we have a readState. + s.state.Serial++ + } var buf bytes.Buffer if err := terraform.WriteState(s.state, &buf); err != nil { return err } - return s.Client.Put(buf.Bytes()) + err := s.Client.Put(buf.Bytes()) + if err != nil { + return err + } + + // After we've successfully persisted, what we just wrote is our new + // reference state until someone calls RefreshState again. + s.readState = s.state.DeepCopy() + return nil } // Lock calls the Client's Lock method if it's implemented. diff --git a/state/state.go b/state/state.go index 45163852b..293fc0daa 100644 --- a/state/state.go +++ b/state/state.go @@ -36,6 +36,11 @@ type State interface { // the state here must not error. Loading the state fresh (an operation that // can likely error) should be implemented by RefreshState. If a state hasn't // been loaded yet, it is okay for State to return nil. +// +// Each caller of this function must get a distinct copy of the state, and +// it must also be distinct from any instance cached inside the reader, to +// ensure that mutations of the returned state will not affect the values +// returned to other callers. type StateReader interface { State() *terraform.State } @@ -43,6 +48,15 @@ type StateReader interface { // StateWriter is the interface that must be implemented by something that // can write a state. Writing the state can be cached or in-memory, as // full persistence should be implemented by StatePersister. +// +// Implementors that cache the state in memory _must_ take a copy of it +// before returning, since the caller may continue to modify it once +// control returns. The caller must ensure that the state instance is not +// concurrently modified _during_ the call, or behavior is undefined. +// +// If an object implements StatePersister in conjunction with StateReader +// then these methods must coordinate such that a subsequent read returns +// a copy of the most recent write, even if it has not yet been persisted. type StateWriter interface { WriteState(*terraform.State) error } @@ -57,6 +71,10 @@ type StateRefresher interface { // StatePersister is implemented to truly persist a state. Whereas StateWriter // is allowed to perhaps be caching in memory, PersistState must write the // state to some durable storage. +// +// If an object implements StatePersister in conjunction with StateReader +// and/or StateRefresher then these methods must coordinate such that +// subsequent reads after a persist return an updated value. type StatePersister interface { PersistState() error } diff --git a/state/testing.go b/state/testing.go index 61f972ca1..be948d409 100644 --- a/state/testing.go +++ b/state/testing.go @@ -31,6 +31,12 @@ func TestState(t *testing.T, s interface{}) { t.Fatalf("not initial:\n%#v\n\n%#v", state, current) } + // Now we've proven that the state we're starting with is an initial + // state, we'll complete our work here with that state, since otherwise + // further writes would violate the invariant that we only try to write + // states that share the same lineage as what was initially written. + current = reader.State() + // Write a new state and verify that we have it if ws, ok := s.(StateWriter); ok { current.AddModuleState(&terraform.ModuleState{ @@ -74,12 +80,12 @@ func TestState(t *testing.T, s interface{}) { } // If we can write and persist then verify that the serial - // is only implemented on change. + // is only incremented on change. writer, writeOk := s.(StateWriter) persister, persistOk := s.(StatePersister) if writeOk && persistOk { // Same serial - serial := current.Serial + serial := reader.State().Serial if err := writer.WriteState(current); err != nil { t.Fatalf("err: %s", err) } @@ -88,7 +94,7 @@ func TestState(t *testing.T, s interface{}) { } if reader.State().Serial != serial { - t.Fatalf("bad: expected %d, got %d", serial, reader.State().Serial) + t.Fatalf("serial changed after persisting with no changes: got %d, want %d", reader.State().Serial, serial) } // Change the serial @@ -113,7 +119,7 @@ func TestState(t *testing.T, s interface{}) { } if reader.State().Serial <= serial { - t.Fatalf("bad: expected %d, got %d", serial, reader.State().Serial) + t.Fatalf("serial incorrect after persisting with changes: got %d, want > %d", reader.State().Serial, serial) } // Check that State() returns a copy by modifying the copy and comparing diff --git a/terraform/state.go b/terraform/state.go index 074b68245..a8fc5277b 100644 --- a/terraform/state.go +++ b/terraform/state.go @@ -533,6 +533,43 @@ func (s *State) equal(other *State) bool { return true } +// MarshalEqual is similar to Equal but provides a stronger definition of +// "equal", where two states are equal if and only if their serialized form +// is byte-for-byte identical. +// +// This is primarily useful for callers that are trying to save snapshots +// of state to persistent storage, allowing them to detect when a new +// snapshot must be taken. +// +// Note that the serial number and lineage are included in the serialized form, +// so it's the caller's responsibility to properly manage these attributes +// so that this method is only called on two states that have the same +// serial and lineage, unless detecting such differences is desired. +func (s *State) MarshalEqual(other *State) bool { + if s == nil && other == nil { + return true + } else if s == nil || other == nil { + return false + } + + recvBuf := &bytes.Buffer{} + otherBuf := &bytes.Buffer{} + + err := WriteState(s, recvBuf) + if err != nil { + // should never happen, since we're writing to a buffer + panic(err) + } + + err = WriteState(other, otherBuf) + if err != nil { + // should never happen, since we're writing to a buffer + panic(err) + } + + return bytes.Equal(recvBuf.Bytes(), otherBuf.Bytes()) +} + type StateAgeComparison int const ( @@ -603,6 +640,10 @@ func (s *State) SameLineage(other *State) bool { // DeepCopy performs a deep copy of the state structure and returns // a new structure. func (s *State) DeepCopy() *State { + if s == nil { + return nil + } + copy, err := copystructure.Config{Lock: true}.Copy(s) if err != nil { panic(err) @@ -611,30 +652,6 @@ func (s *State) DeepCopy() *State { return copy.(*State) } -// IncrementSerialMaybe increments the serial number of this state -// if it different from the other state. -func (s *State) IncrementSerialMaybe(other *State) { - if s == nil { - return - } - if other == nil { - return - } - s.Lock() - defer s.Unlock() - - if s.Serial > other.Serial { - return - } - if other.TFVersion != s.TFVersion || !s.equal(other) { - if other.Serial > s.Serial { - s.Serial = other.Serial - } - - s.Serial++ - } -} - // FromFutureTerraform checks if this state was written by a Terraform // version from the future. func (s *State) FromFutureTerraform() bool { diff --git a/terraform/state_test.go b/terraform/state_test.go index 5578f89c9..e56e04d44 100644 --- a/terraform/state_test.go +++ b/terraform/state_test.go @@ -631,87 +631,121 @@ func TestStateSameLineage(t *testing.T) { } } -func TestStateIncrementSerialMaybe(t *testing.T) { - cases := map[string]struct { +func TestStateMarshalEqual(t *testing.T) { + tests := map[string]struct { S1, S2 *State - Serial int64 + Want bool }{ - "S2 is nil": { + "both nil": { + nil, + nil, + true, + }, + "first zero, second nil": { &State{}, nil, - 0, + false, }, - "S2 is identical": { + "first nil, second zero": { + nil, &State{}, - &State{}, - 0, + false, }, - "S2 is different": { + "both zero": { + // These are not equal because they both implicitly init with + // different lineage. &State{}, + &State{}, + false, + }, + "both set, same lineage": { &State{ - Modules: []*ModuleState{ - &ModuleState{Path: rootModulePath}, - }, + Lineage: "abc123", }, - 1, - }, - "S2 is different, but only via Instance Metadata": { &State{ - Serial: 3, + Lineage: "abc123", + }, + true, + }, + "both set, same lineage, different serial": { + &State{ + Lineage: "abc123", + Serial: 1, + }, + &State{ + Lineage: "abc123", + Serial: 2, + }, + false, + }, + "both set, same lineage, same serial, same resources": { + &State{ + Lineage: "abc123", + Serial: 1, Modules: []*ModuleState{ - &ModuleState{ - Path: rootModulePath, + { + Path: []string{"root"}, Resources: map[string]*ResourceState{ - "test_instance.foo": &ResourceState{ - Primary: &InstanceState{ - Meta: map[string]interface{}{}, - }, - }, + "foo_bar.baz": {}, }, }, }, }, &State{ - Serial: 3, + Lineage: "abc123", + Serial: 1, Modules: []*ModuleState{ - &ModuleState{ - Path: rootModulePath, + { + Path: []string{"root"}, Resources: map[string]*ResourceState{ - "test_instance.foo": &ResourceState{ - Primary: &InstanceState{ - Meta: map[string]interface{}{ - "schema_version": "1", - }, - }, - }, + "foo_bar.baz": {}, }, }, }, }, - 4, + true, }, - "S1 serial is higher": { - &State{Serial: 5}, + "both set, same lineage, same serial, different resources": { &State{ - Serial: 3, + Lineage: "abc123", + Serial: 1, Modules: []*ModuleState{ - &ModuleState{Path: rootModulePath}, + { + Path: []string{"root"}, + Resources: map[string]*ResourceState{ + "foo_bar.baz": {}, + }, + }, }, }, - 5, - }, - "S2 has a different TFVersion": { - &State{TFVersion: "0.1"}, - &State{TFVersion: "0.2"}, - 1, + &State{ + Lineage: "abc123", + Serial: 1, + Modules: []*ModuleState{ + { + Path: []string{"root"}, + Resources: map[string]*ResourceState{ + "pizza_crust.tasty": {}, + }, + }, + }, + }, + false, }, } - for name, tc := range cases { - tc.S1.IncrementSerialMaybe(tc.S2) - if tc.S1.Serial != tc.Serial { - t.Fatalf("Bad: %s\nGot: %d", name, tc.S1.Serial) - } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + got := test.S1.MarshalEqual(test.S2) + if got != test.Want { + t.Errorf("wrong result %#v; want %#v", got, test.Want) + s1Buf := &bytes.Buffer{} + s2Buf := &bytes.Buffer{} + _ = WriteState(test.S1, s1Buf) + _ = WriteState(test.S2, s2Buf) + t.Logf("\nState 1: %s\nState 2: %s", s1Buf.Bytes(), s2Buf.Bytes()) + } + }) } }