Merge #15424: Improve robustness of state persistence handling

Previously the APIs for state persistence and management had some problematic cases where we depended on hidden mutations of the state structure as side-effects of otherwise-innocent-looking operations, which was a frequent cause of accidental regressions due to faulty assumptions.

This new model attempts to isolate certain state mutations to just within the state managers, and makes the state managers work on separated snapshots of the state rather than on the "live" object to reduce the risk of race conditions.
This commit is contained in:
Martin Atkins 2017-07-05 16:27:08 -07:00 committed by GitHub
commit 39c4d6ab1f
11 changed files with 350 additions and 383 deletions

View File

@ -45,16 +45,7 @@ func TestApply(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
@ -281,6 +272,14 @@ func TestApply_defaultState(t *testing.T) {
}, },
} }
// create an existing state file
localState := &state.LocalState{Path: statePath}
if err := localState.WriteState(terraform.NewState()); err != nil {
t.Fatal(err)
}
serial := localState.State().Serial
args := []string{ args := []string{
testFixturePath("apply"), testFixturePath("apply"),
} }
@ -292,19 +291,14 @@ func TestApply_defaultState(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
if state.Serial <= serial {
t.Fatalf("serial was not incremented. previous:%d, current%d", serial, state.Serial)
}
} }
func TestApply_error(t *testing.T) { func TestApply_error(t *testing.T) {
@ -360,16 +354,7 @@ func TestApply_error(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
@ -484,13 +469,7 @@ func TestApply_noArgs(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -538,16 +517,7 @@ func TestApply_plan(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
@ -582,19 +552,8 @@ func TestApply_plan_backup(t *testing.T) {
t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String())
} }
{ // Should have a backup file
// Should have a backup file testStateRead(t, backupPath)
f, err := os.Open(backupPath)
if err != nil {
t.Fatalf("err: %s", err)
}
_, err = terraform.ReadState(f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
} }
func TestApply_plan_noBackup(t *testing.T) { func TestApply_plan_noBackup(t *testing.T) {
@ -732,16 +691,7 @@ func TestApply_planWithVarFile(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
@ -847,31 +797,13 @@ func TestApply_refresh(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
// Should have a backup file // Should have a backup file
f, err = os.Open(statePath + DefaultBackupExtension) backupState := testStateRead(t, statePath+DefaultBackupExtension)
if err != nil {
t.Fatalf("err: %s", err)
}
backupState, err := terraform.ReadState(f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
actualStr := strings.TrimSpace(backupState.String()) actualStr := strings.TrimSpace(backupState.String())
expectedStr := strings.TrimSpace(originalState.String()) expectedStr := strings.TrimSpace(originalState.String())
@ -953,16 +885,7 @@ func TestApply_shutdown(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
@ -1035,31 +958,12 @@ func TestApply_state(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
// Should have a backup file backupState := testStateRead(t, statePath+DefaultBackupExtension)
f, err = os.Open(statePath + DefaultBackupExtension)
if err != nil {
t.Fatalf("err: %s", err)
}
backupState, err := terraform.ReadState(f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
// nil out the ConnInfo since that should not be restored // nil out the ConnInfo since that should not be restored
originalState.RootModule().Resources["test_instance.foo"].Primary.Ephemeral.ConnInfo = nil originalState.RootModule().Resources["test_instance.foo"].Primary.Ephemeral.ConnInfo = nil
@ -1142,17 +1046,7 @@ func TestApply_stateFuture(t *testing.T) {
t.Fatal("should fail") t.Fatal("should fail")
} }
f, err := os.Open(statePath) newState := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
newState, err := terraform.ReadState(f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
if !newState.Equal(originalState) { if !newState.Equal(originalState) {
t.Fatalf("bad: %#v", newState) t.Fatalf("bad: %#v", newState)
} }
@ -1422,31 +1316,12 @@ func TestApply_backup(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
// Should have a backup file backupState := testStateRead(t, backupPath)
f, err = os.Open(backupPath)
if err != nil {
t.Fatalf("err: %s", err)
}
backupState, err := terraform.ReadState(f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
actual := backupState.RootModule().Resources["test_instance.foo"] actual := backupState.RootModule().Resources["test_instance.foo"]
expected := originalState.RootModule().Resources["test_instance.foo"] expected := originalState.RootModule().Resources["test_instance.foo"]
@ -1504,22 +1379,13 @@ func TestApply_disableBackup(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
f, err := os.Open(statePath) state := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
defer f.Close()
state, err := terraform.ReadState(f)
if err != nil {
t.Fatalf("err: %s", err)
}
if state == nil { if state == nil {
t.Fatal("state should not be nil") t.Fatal("state should not be nil")
} }
// Ensure there is no backup // Ensure there is no backup
_, err = os.Stat(statePath + DefaultBackupExtension) _, err := os.Stat(statePath + DefaultBackupExtension)
if err == nil || !os.IsNotExist(err) { if err == nil || !os.IsNotExist(err) {
t.Fatalf("backup should not exist") t.Fatalf("backup should not exist")
} }

View File

@ -160,7 +160,6 @@ func testReadPlan(t *testing.T, path string) *terraform.Plan {
// testState returns a test State structure that we use for a lot of tests. // testState returns a test State structure that we use for a lot of tests.
func testState() *terraform.State { func testState() *terraform.State {
state := &terraform.State{ state := &terraform.State{
Version: 2,
Modules: []*terraform.ModuleState{ Modules: []*terraform.ModuleState{
&terraform.ModuleState{ &terraform.ModuleState{
Path: []string{"root"}, Path: []string{"root"},
@ -177,20 +176,7 @@ func testState() *terraform.State {
}, },
} }
state.Init() state.Init()
return state
// Write and read the state so that it is properly initialized. We
// do this since we didn't call the normal NewState constructor.
var buf bytes.Buffer
if err := terraform.WriteState(state, &buf); err != nil {
panic(err)
}
result, err := terraform.ReadState(&buf)
if err != nil {
panic(err)
}
return result
} }
func testStateFile(t *testing.T, s *terraform.State) string { func testStateFile(t *testing.T, s *terraform.State) string {
@ -252,9 +238,9 @@ func testStateRead(t *testing.T, path string) *terraform.State {
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
defer f.Close()
newState, err := terraform.ReadState(f) newState, err := terraform.ReadState(f)
f.Close()
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@ -10,6 +10,7 @@ import (
"testing" "testing"
"github.com/hashicorp/terraform/helper/copy" "github.com/hashicorp/terraform/helper/copy"
"github.com/hashicorp/terraform/state"
"github.com/hashicorp/terraform/terraform" "github.com/hashicorp/terraform/terraform"
"github.com/mitchellh/cli" "github.com/mitchellh/cli"
) )
@ -193,15 +194,11 @@ func TestRefresh_defaultState(t *testing.T) {
} }
statePath := filepath.Join(td, DefaultStateFilename) statePath := filepath.Join(td, DefaultStateFilename)
f, err := os.Create(statePath) localState := &state.LocalState{Path: statePath}
if err != nil { if err := localState.WriteState(originalState); err != nil {
t.Fatalf("err: %s", err) t.Fatal(err)
}
err = terraform.WriteState(originalState, f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
} }
serial := localState.State().Serial
// Change to that directory // Change to that directory
cwd, err := os.Getwd() cwd, err := os.Getwd()
@ -236,16 +233,7 @@ func TestRefresh_defaultState(t *testing.T) {
t.Fatal("refresh should be called") t.Fatal("refresh should be called")
} }
f, err = os.Open(statePath) newState := testStateRead(t, statePath)
if err != nil {
t.Fatalf("err: %s", err)
}
newState, err := terraform.ReadState(f)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
actual := newState.RootModule().Resources["test_instance.foo"].Primary actual := newState.RootModule().Resources["test_instance.foo"].Primary
expected := p.RefreshReturn expected := p.RefreshReturn
@ -254,16 +242,11 @@ func TestRefresh_defaultState(t *testing.T) {
t.Fatalf("bad:\n%#v", actual) t.Fatalf("bad:\n%#v", actual)
} }
f, err = os.Open(statePath + DefaultBackupExtension) if newState.Serial <= serial {
if err != nil { t.Fatalf("serial not incremented during refresh. previous:%d, current:%d", serial, newState.Serial)
t.Fatalf("err: %s", err)
} }
backupState, err := terraform.ReadState(f) backupState := testStateRead(t, statePath+DefaultBackupExtension)
f.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
actual = backupState.RootModule().Resources["test_instance.foo"].Primary actual = backupState.RootModule().Resources["test_instance.foo"].Primary
expected = originalState.RootModule().Resources["test_instance.foo"].Primary expected = originalState.RootModule().Resources["test_instance.foo"].Primary

View File

@ -254,6 +254,7 @@ func TestWorkspace_createWithState(t *testing.T) {
} }
newState := envState.State() newState := envState.State()
originalState.Version = newState.Version // the round-trip through the state manager implicitly populates version
if !originalState.Equal(newState) { if !originalState.Equal(newState) {
t.Fatalf("states not equal\norig: %s\nnew: %s", originalState, newState) t.Fatalf("states not equal\norig: %s\nnew: %s", originalState, newState)
} }

View File

@ -32,8 +32,18 @@ func (s *InmemState) WriteState(state *terraform.State) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
state.IncrementSerialMaybe(s.state) state = state.DeepCopy()
if s.state != nil {
state.Serial = s.state.Serial
if !s.state.MarshalEqual(state) {
state.Serial++
}
}
s.state = state s.state = state
return nil return nil
} }

View File

@ -48,8 +48,8 @@ func (s *LocalState) SetState(state *terraform.State) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.state = state s.state = state.DeepCopy()
s.readState = state s.readState = state.DeepCopy()
} }
// StateReader impl. // StateReader impl.
@ -74,7 +74,14 @@ func (s *LocalState) WriteState(state *terraform.State) error {
} }
defer s.stateFileOut.Sync() defer s.stateFileOut.Sync()
s.state = state s.state = state.DeepCopy() // don't want mutations before we actually get this written to disk
if s.readState != nil && s.state != nil {
// We don't trust callers to properly manage serials. Instead, we assume
// that a WriteState is always for the next version after what was
// most recently read.
s.state.Serial = s.readState.Serial
}
if _, err := s.stateFileOut.Seek(0, os.SEEK_SET); err != nil { if _, err := s.stateFileOut.Seek(0, os.SEEK_SET); err != nil {
return err return err
@ -88,8 +95,9 @@ func (s *LocalState) WriteState(state *terraform.State) error {
return nil return nil
} }
s.state.IncrementSerialMaybe(s.readState) if !s.state.MarshalEqual(s.readState) {
s.readState = s.state s.state.Serial++
}
if err := terraform.WriteState(s.state, s.stateFileOut); err != nil { if err := terraform.WriteState(s.state, s.stateFileOut); err != nil {
return err return err
@ -147,7 +155,7 @@ func (s *LocalState) RefreshState() error {
} }
s.state = state s.state = state
s.readState = state s.readState = s.state.DeepCopy()
return nil return nil
} }

View File

@ -2,6 +2,7 @@ package remote
import ( import (
"bytes" "bytes"
"fmt"
"sync" "sync"
"github.com/hashicorp/terraform/state" "github.com/hashicorp/terraform/state"
@ -33,7 +34,28 @@ func (s *State) WriteState(state *terraform.State) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.state = state if s.readState != nil && !state.SameLineage(s.readState) {
return fmt.Errorf("incompatible state lineage; given %s but want %s", state.Lineage, s.readState.Lineage)
}
// We create a deep copy of the state here, because the caller also has
// a reference to the given object and can potentially go on to mutate
// it after we return, but we want the snapshot at this point in time.
s.state = state.DeepCopy()
// Force our new state to have the same serial as our read state. We'll
// update this if PersistState is called later. (We don't require nor trust
// the caller to properly maintain serial for transient state objects since
// the rest of Terraform treats state as an openly mutable object.)
//
// If we have no read state then we assume we're either writing a new
// state for the first time or we're migrating a state from elsewhere,
// and in both cases we wish to retain the lineage and serial from
// the given state.
if s.readState != nil {
s.state.Serial = s.readState.Serial
}
return nil return nil
} }
@ -58,7 +80,7 @@ func (s *State) RefreshState() error {
} }
s.state = state s.state = state
s.readState = state s.readState = s.state.DeepCopy() // our states must be separate instances so we can track changes
return nil return nil
} }
@ -67,14 +89,28 @@ func (s *State) PersistState() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.state.IncrementSerialMaybe(s.readState) if !s.state.MarshalEqual(s.readState) {
// Our new state does not marshal as byte-for-byte identical to
// the old, so we need to increment the serial.
// Note that in WriteState we force the serial to match that of
// s.readState, if we have a readState.
s.state.Serial++
}
var buf bytes.Buffer var buf bytes.Buffer
if err := terraform.WriteState(s.state, &buf); err != nil { if err := terraform.WriteState(s.state, &buf); err != nil {
return err return err
} }
return s.Client.Put(buf.Bytes()) err := s.Client.Put(buf.Bytes())
if err != nil {
return err
}
// After we've successfully persisted, what we just wrote is our new
// reference state until someone calls RefreshState again.
s.readState = s.state.DeepCopy()
return nil
} }
// Lock calls the Client's Lock method if it's implemented. // Lock calls the Client's Lock method if it's implemented.

View File

@ -36,6 +36,11 @@ type State interface {
// the state here must not error. Loading the state fresh (an operation that // the state here must not error. Loading the state fresh (an operation that
// can likely error) should be implemented by RefreshState. If a state hasn't // can likely error) should be implemented by RefreshState. If a state hasn't
// been loaded yet, it is okay for State to return nil. // been loaded yet, it is okay for State to return nil.
//
// Each caller of this function must get a distinct copy of the state, and
// it must also be distinct from any instance cached inside the reader, to
// ensure that mutations of the returned state will not affect the values
// returned to other callers.
type StateReader interface { type StateReader interface {
State() *terraform.State State() *terraform.State
} }
@ -43,6 +48,15 @@ type StateReader interface {
// StateWriter is the interface that must be implemented by something that // StateWriter is the interface that must be implemented by something that
// can write a state. Writing the state can be cached or in-memory, as // can write a state. Writing the state can be cached or in-memory, as
// full persistence should be implemented by StatePersister. // full persistence should be implemented by StatePersister.
//
// Implementors that cache the state in memory _must_ take a copy of it
// before returning, since the caller may continue to modify it once
// control returns. The caller must ensure that the state instance is not
// concurrently modified _during_ the call, or behavior is undefined.
//
// If an object implements StatePersister in conjunction with StateReader
// then these methods must coordinate such that a subsequent read returns
// a copy of the most recent write, even if it has not yet been persisted.
type StateWriter interface { type StateWriter interface {
WriteState(*terraform.State) error WriteState(*terraform.State) error
} }
@ -57,6 +71,10 @@ type StateRefresher interface {
// StatePersister is implemented to truly persist a state. Whereas StateWriter // StatePersister is implemented to truly persist a state. Whereas StateWriter
// is allowed to perhaps be caching in memory, PersistState must write the // is allowed to perhaps be caching in memory, PersistState must write the
// state to some durable storage. // state to some durable storage.
//
// If an object implements StatePersister in conjunction with StateReader
// and/or StateRefresher then these methods must coordinate such that
// subsequent reads after a persist return an updated value.
type StatePersister interface { type StatePersister interface {
PersistState() error PersistState() error
} }

View File

@ -10,119 +10,126 @@ import (
// TestState is a helper for testing state implementations. It is expected // TestState is a helper for testing state implementations. It is expected
// that the given implementation is pre-loaded with the TestStateInitial // that the given implementation is pre-loaded with the TestStateInitial
// state. // state.
func TestState(t *testing.T, s interface{}) { func TestState(t *testing.T, s State) {
reader, ok := s.(StateReader) if err := s.RefreshState(); err != nil {
if !ok { t.Fatalf("err: %s", err)
t.Fatalf("must at least be a StateReader")
} }
// If it implements refresh, refresh // Check that the initial state is correct.
if rs, ok := s.(StateRefresher); ok { // These do have different Lineages, but we will replace current below.
if err := rs.RefreshState(); err != nil { initial := TestStateInitial()
t.Fatalf("err: %s", err) if state := s.State(); !state.Equal(initial) {
} t.Fatalf("state does not match expected initial state:\n%#v\n\n%#v", state, initial)
} }
// current will track our current state // Now we've proven that the state we're starting with is an initial
current := TestStateInitial() // state, we'll complete our work here with that state, since otherwise
// further writes would violate the invariant that we only try to write
// Check that the initial state is correct // states that share the same lineage as what was initially written.
if state := reader.State(); !current.Equal(state) { current := s.State()
t.Fatalf("not initial:\n%#v\n\n%#v", state, current)
}
// Write a new state and verify that we have it // Write a new state and verify that we have it
if ws, ok := s.(StateWriter); ok { current.AddModuleState(&terraform.ModuleState{
current.AddModuleState(&terraform.ModuleState{ Path: []string{"root"},
Path: []string{"root"}, Outputs: map[string]*terraform.OutputState{
Outputs: map[string]*terraform.OutputState{ "bar": &terraform.OutputState{
"bar": &terraform.OutputState{ Type: "string",
Type: "string", Sensitive: false,
Sensitive: false, Value: "baz",
Value: "baz",
},
}, },
}) },
})
if err := ws.WriteState(current); err != nil { if err := s.WriteState(current); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if actual := reader.State(); !actual.Equal(current) { if actual := s.State(); !actual.Equal(current) {
t.Fatalf("bad:\n%#v\n\n%#v", actual, current) t.Fatalf("bad:\n%#v\n\n%#v", actual, current)
}
} }
// Test persistence // Test persistence
if ps, ok := s.(StatePersister); ok { if err := s.PersistState(); err != nil {
if err := ps.PersistState(); err != nil { t.Fatalf("err: %s", err)
t.Fatalf("err: %s", err)
}
// Refresh if we got it
if rs, ok := s.(StateRefresher); ok {
if err := rs.RefreshState(); err != nil {
t.Fatalf("err: %s", err)
}
}
// Just set the serials the same... Then compare.
actual := reader.State()
if !actual.Equal(current) {
t.Fatalf("bad: %#v\n\n%#v", actual, current)
}
} }
// If we can write and persist then verify that the serial // Refresh if we got it
// is only implemented on change. if err := s.RefreshState(); err != nil {
writer, writeOk := s.(StateWriter) t.Fatalf("err: %s", err)
persister, persistOk := s.(StatePersister) }
if writeOk && persistOk {
// Same serial
serial := current.Serial
if err := writer.WriteState(current); err != nil {
t.Fatalf("err: %s", err)
}
if err := persister.PersistState(); err != nil {
t.Fatalf("err: %s", err)
}
if reader.State().Serial != serial { if s.State().Lineage != current.Lineage {
t.Fatalf("bad: expected %d, got %d", serial, reader.State().Serial) t.Fatalf("Lineage changed from %s to %s", s.State().Lineage, current.Lineage)
} }
// Change the serial // Just set the serials the same... Then compare.
current = current.DeepCopy() actual := s.State()
current.Modules = []*terraform.ModuleState{ if !actual.Equal(current) {
&terraform.ModuleState{ t.Fatalf("bad: %#v\n\n%#v", actual, current)
Path: []string{"root", "somewhere"}, }
Outputs: map[string]*terraform.OutputState{
"serialCheck": &terraform.OutputState{ // Same serial
Type: "string", serial := s.State().Serial
Sensitive: false, if err := s.WriteState(current); err != nil {
Value: "true", t.Fatalf("err: %s", err)
}, }
if err := s.PersistState(); err != nil {
t.Fatalf("err: %s", err)
}
if s.State().Serial != serial {
t.Fatalf("serial changed after persisting with no changes: got %d, want %d", s.State().Serial, serial)
}
// Change the serial
current = current.DeepCopy()
current.Modules = []*terraform.ModuleState{
&terraform.ModuleState{
Path: []string{"root", "somewhere"},
Outputs: map[string]*terraform.OutputState{
"serialCheck": &terraform.OutputState{
Type: "string",
Sensitive: false,
Value: "true",
}, },
}, },
} },
if err := writer.WriteState(current); err != nil { }
t.Fatalf("err: %s", err) if err := s.WriteState(current); err != nil {
} t.Fatalf("err: %s", err)
if err := persister.PersistState(); err != nil { }
t.Fatalf("err: %s", err) if err := s.PersistState(); err != nil {
} t.Fatalf("err: %s", err)
}
if reader.State().Serial <= serial { if s.State().Serial <= serial {
t.Fatalf("bad: expected %d, got %d", serial, reader.State().Serial) t.Fatalf("serial incorrect after persisting with changes: got %d, want > %d", s.State().Serial, serial)
} }
// Check that State() returns a copy by modifying the copy and comparing if s.State().Version != current.Version {
// to the current state. t.Fatalf("Version changed from %d to %d", s.State().Version, current.Version)
stateCopy := reader.State() }
stateCopy.Serial++
if reflect.DeepEqual(stateCopy, current) { if s.State().TFVersion != current.TFVersion {
t.Fatal("State() should return a copy") t.Fatalf("TFVersion changed from %s to %s", s.State().TFVersion, current.TFVersion)
} }
// verify that Lineage doesn't change along with Serial, or during copying.
if s.State().Lineage != current.Lineage {
t.Fatalf("Lineage changed from %s to %s", s.State().Lineage, current.Lineage)
}
// Check that State() returns a copy by modifying the copy and comparing
// to the current state.
stateCopy := s.State()
stateCopy.Serial++
if reflect.DeepEqual(stateCopy, s.State()) {
t.Fatal("State() should return a copy")
}
// our current expected state should also marhsal identically to the persisted state
if current.MarshalEqual(s.State()) {
t.Fatalf("Persisted state altered unexpectedly. Expected: %#v\b Got: %#v", current, s.State())
} }
} }

View File

@ -533,6 +533,43 @@ func (s *State) equal(other *State) bool {
return true return true
} }
// MarshalEqual is similar to Equal but provides a stronger definition of
// "equal", where two states are equal if and only if their serialized form
// is byte-for-byte identical.
//
// This is primarily useful for callers that are trying to save snapshots
// of state to persistent storage, allowing them to detect when a new
// snapshot must be taken.
//
// Note that the serial number and lineage are included in the serialized form,
// so it's the caller's responsibility to properly manage these attributes
// so that this method is only called on two states that have the same
// serial and lineage, unless detecting such differences is desired.
func (s *State) MarshalEqual(other *State) bool {
if s == nil && other == nil {
return true
} else if s == nil || other == nil {
return false
}
recvBuf := &bytes.Buffer{}
otherBuf := &bytes.Buffer{}
err := WriteState(s, recvBuf)
if err != nil {
// should never happen, since we're writing to a buffer
panic(err)
}
err = WriteState(other, otherBuf)
if err != nil {
// should never happen, since we're writing to a buffer
panic(err)
}
return bytes.Equal(recvBuf.Bytes(), otherBuf.Bytes())
}
type StateAgeComparison int type StateAgeComparison int
const ( const (
@ -603,6 +640,10 @@ func (s *State) SameLineage(other *State) bool {
// DeepCopy performs a deep copy of the state structure and returns // DeepCopy performs a deep copy of the state structure and returns
// a new structure. // a new structure.
func (s *State) DeepCopy() *State { func (s *State) DeepCopy() *State {
if s == nil {
return nil
}
copy, err := copystructure.Config{Lock: true}.Copy(s) copy, err := copystructure.Config{Lock: true}.Copy(s)
if err != nil { if err != nil {
panic(err) panic(err)
@ -611,30 +652,6 @@ func (s *State) DeepCopy() *State {
return copy.(*State) return copy.(*State)
} }
// IncrementSerialMaybe increments the serial number of this state
// if it different from the other state.
func (s *State) IncrementSerialMaybe(other *State) {
if s == nil {
return
}
if other == nil {
return
}
s.Lock()
defer s.Unlock()
if s.Serial > other.Serial {
return
}
if other.TFVersion != s.TFVersion || !s.equal(other) {
if other.Serial > s.Serial {
s.Serial = other.Serial
}
s.Serial++
}
}
// FromFutureTerraform checks if this state was written by a Terraform // FromFutureTerraform checks if this state was written by a Terraform
// version from the future. // version from the future.
func (s *State) FromFutureTerraform() bool { func (s *State) FromFutureTerraform() bool {
@ -660,6 +677,7 @@ func (s *State) init() {
if s.Version == 0 { if s.Version == 0 {
s.Version = StateVersion s.Version = StateVersion
} }
if s.moduleByPath(rootModulePath) == nil { if s.moduleByPath(rootModulePath) == nil {
s.addModule(rootModulePath) s.addModule(rootModulePath)
} }

View File

@ -631,87 +631,121 @@ func TestStateSameLineage(t *testing.T) {
} }
} }
func TestStateIncrementSerialMaybe(t *testing.T) { func TestStateMarshalEqual(t *testing.T) {
cases := map[string]struct { tests := map[string]struct {
S1, S2 *State S1, S2 *State
Serial int64 Want bool
}{ }{
"S2 is nil": { "both nil": {
nil,
nil,
true,
},
"first zero, second nil": {
&State{}, &State{},
nil, nil,
0, false,
}, },
"S2 is identical": { "first nil, second zero": {
nil,
&State{}, &State{},
&State{}, false,
0,
}, },
"S2 is different": { "both zero": {
// These are not equal because they both implicitly init with
// different lineage.
&State{}, &State{},
&State{},
false,
},
"both set, same lineage": {
&State{ &State{
Modules: []*ModuleState{ Lineage: "abc123",
&ModuleState{Path: rootModulePath},
},
}, },
1,
},
"S2 is different, but only via Instance Metadata": {
&State{ &State{
Serial: 3, Lineage: "abc123",
},
true,
},
"both set, same lineage, different serial": {
&State{
Lineage: "abc123",
Serial: 1,
},
&State{
Lineage: "abc123",
Serial: 2,
},
false,
},
"both set, same lineage, same serial, same resources": {
&State{
Lineage: "abc123",
Serial: 1,
Modules: []*ModuleState{ Modules: []*ModuleState{
&ModuleState{ {
Path: rootModulePath, Path: []string{"root"},
Resources: map[string]*ResourceState{ Resources: map[string]*ResourceState{
"test_instance.foo": &ResourceState{ "foo_bar.baz": {},
Primary: &InstanceState{
Meta: map[string]interface{}{},
},
},
}, },
}, },
}, },
}, },
&State{ &State{
Serial: 3, Lineage: "abc123",
Serial: 1,
Modules: []*ModuleState{ Modules: []*ModuleState{
&ModuleState{ {
Path: rootModulePath, Path: []string{"root"},
Resources: map[string]*ResourceState{ Resources: map[string]*ResourceState{
"test_instance.foo": &ResourceState{ "foo_bar.baz": {},
Primary: &InstanceState{
Meta: map[string]interface{}{
"schema_version": "1",
},
},
},
}, },
}, },
}, },
}, },
4, true,
}, },
"S1 serial is higher": { "both set, same lineage, same serial, different resources": {
&State{Serial: 5},
&State{ &State{
Serial: 3, Lineage: "abc123",
Serial: 1,
Modules: []*ModuleState{ Modules: []*ModuleState{
&ModuleState{Path: rootModulePath}, {
Path: []string{"root"},
Resources: map[string]*ResourceState{
"foo_bar.baz": {},
},
},
}, },
}, },
5, &State{
}, Lineage: "abc123",
"S2 has a different TFVersion": { Serial: 1,
&State{TFVersion: "0.1"}, Modules: []*ModuleState{
&State{TFVersion: "0.2"}, {
1, Path: []string{"root"},
Resources: map[string]*ResourceState{
"pizza_crust.tasty": {},
},
},
},
},
false,
}, },
} }
for name, tc := range cases { for name, test := range tests {
tc.S1.IncrementSerialMaybe(tc.S2) t.Run(name, func(t *testing.T) {
if tc.S1.Serial != tc.Serial { got := test.S1.MarshalEqual(test.S2)
t.Fatalf("Bad: %s\nGot: %d", name, tc.S1.Serial) if got != test.Want {
} t.Errorf("wrong result %#v; want %#v", got, test.Want)
s1Buf := &bytes.Buffer{}
s2Buf := &bytes.Buffer{}
_ = WriteState(test.S1, s1Buf)
_ = WriteState(test.S2, s2Buf)
t.Logf("\nState 1: %s\nState 2: %s", s1Buf.Bytes(), s2Buf.Bytes())
}
})
} }
} }