Merge pull request #24696 from hashicorp/leetrout/remote-state-force-push

Add support for force pushing with the remote backend
This commit is contained in:
Pam Selle 2020-05-06 15:23:28 -04:00 committed by GitHub
commit 60b3815af4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 441 additions and 95 deletions

View File

@ -20,6 +20,7 @@ type remoteClient struct {
runID string runID string
stateUploadErr bool stateUploadErr bool
workspace *tfe.Workspace workspace *tfe.Workspace
forcePush bool
} }
// Get the remote state. // Get the remote state.
@ -69,6 +70,7 @@ func (r *remoteClient) Put(state []byte) error {
Serial: tfe.Int64(int64(stateFile.Serial)), Serial: tfe.Int64(int64(stateFile.Serial)),
MD5: tfe.String(fmt.Sprintf("%x", md5.Sum(state))), MD5: tfe.String(fmt.Sprintf("%x", md5.Sum(state))),
State: tfe.String(base64.StdEncoding.EncodeToString(state)), State: tfe.String(base64.StdEncoding.EncodeToString(state)),
Force: tfe.Bool(r.forcePush),
} }
// If we have a run ID, make sure to add it to the options // If we have a run ID, make sure to add it to the options
@ -97,6 +99,12 @@ func (r *remoteClient) Delete() error {
return nil return nil
} }
// EnableForcePush to allow the remote client to overwrite state
// by implementing remote.ClientForcePusher
func (r *remoteClient) EnableForcePush() {
r.forcePush = true
}
// Lock the remote state. // Lock the remote state.
func (r *remoteClient) Lock(info *state.LockInfo) (string, error) { func (r *remoteClient) Lock(info *state.LockInfo) (string, error) {
ctx := context.Background() ctx := context.Background()

View File

@ -15,6 +15,14 @@ type Client interface {
Delete() error Delete() error
} }
// ClientForcePusher is an optional interface that allows a remote
// state to force push by managing a flag on the client that is
// toggled on by a call to EnableForcePush.
type ClientForcePusher interface {
Client
EnableForcePush()
}
// ClientLocker is an optional interface that allows a remote state // ClientLocker is an optional interface that allows a remote state
// backend to enable state lock/unlock. // backend to enable state lock/unlock.
type ClientLocker interface { type ClientLocker interface {

View File

@ -69,6 +69,7 @@ func (c nilClient) Delete() error { return nil }
type mockClient struct { type mockClient struct {
current []byte current []byte
log []mockClientRequest log []mockClientRequest
force bool
} }
type mockClientRequest struct { type mockClientRequest struct {
@ -89,7 +90,11 @@ func (c *mockClient) Get() (*Payload, error) {
} }
func (c *mockClient) Put(data []byte) error { func (c *mockClient) Put(data []byte) error {
c.appendLog("Put", data) if c.force {
c.appendLog("Force Put", data)
} else {
c.appendLog("Put", data)
}
c.current = data c.current = data
return nil return nil
} }
@ -100,6 +105,11 @@ func (c *mockClient) Delete() error {
return nil return nil
} }
// Implements remote.ClientForcePusher
func (c *mockClient) EnableForcePush() {
c.force = true
}
func (c *mockClient) appendLog(method string, content []byte) { func (c *mockClient) appendLog(method string, content []byte) {
// For easier test assertions, we actually log the result of decoding // For easier test assertions, we actually log the result of decoding
// the content JSON rather than the raw bytes. Callers are in principle // the content JSON rather than the raw bytes. Callers are in principle

View File

@ -21,10 +21,19 @@ type State struct {
Client Client Client Client
lineage string // We track two pieces of meta data in addition to the state itself:
serial uint64 //
state, readState *states.State // lineage - the state's unique ID
disableLocks bool // serial - the monotonic counter of "versions" of the state
//
// Both of these (along with state) have a sister field
// that represents the values read in from an existing source.
// All three of these values are used to determine if the new
// state has changed from an existing state we read in.
lineage, readLineage string
serial, readSerial uint64
state, readState *states.State
disableLocks bool
} }
var _ statemgr.Full = (*State)(nil) var _ statemgr.Full = (*State)(nil)
@ -64,8 +73,15 @@ func (s *State) WriteStateForMigration(f *statefile.File, force bool) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
checkFile := statefile.New(s.state, s.lineage, s.serial) // `force` is passed down from the CLI flag and terminates here. Actual
if !force { // force pushing with the remote backend happens when Put()'ing the contents
// in the backend. If force is specified we skip verifications and hand the
// context off to the client to use when persitence operations actually take place.
c, isForcePusher := s.Client.(ClientForcePusher)
if force && isForcePusher {
c.EnableForcePush()
} else {
checkFile := statefile.New(s.state, s.lineage, s.serial)
if err := statemgr.CheckValidImport(f, checkFile); err != nil { if err := statemgr.CheckValidImport(f, checkFile); err != nil {
return err return err
} }
@ -113,7 +129,12 @@ func (s *State) refreshState() error {
s.lineage = stateFile.Lineage s.lineage = stateFile.Lineage
s.serial = stateFile.Serial s.serial = stateFile.Serial
s.state = stateFile.State s.state = stateFile.State
s.readState = s.state.DeepCopy() // our states must be separate instances so we can track changes
// Properties from the remote must be separate so we can
// track changes as lineage, serial and/or state are mutated
s.readLineage = stateFile.Lineage
s.readSerial = stateFile.Serial
s.readState = s.state.DeepCopy()
return nil return nil
} }
@ -123,8 +144,11 @@ func (s *State) PersistState() error {
defer s.mu.Unlock() defer s.mu.Unlock()
if s.readState != nil { if s.readState != nil {
if statefile.StatesMarshalEqual(s.state, s.readState) { lineageUnchanged := s.readLineage != "" && s.lineage == s.readLineage
// If the state hasn't changed at all then we have nothing to do. serialUnchanged := s.readSerial != 0 && s.serial == s.readSerial
stateUnchanged := statefile.StatesMarshalEqual(s.state, s.readState)
if stateUnchanged && lineageUnchanged && serialUnchanged {
// If the state, lineage or serial haven't changed at all then we have nothing to do.
return nil return nil
} }
s.serial++ s.serial++
@ -161,7 +185,13 @@ func (s *State) PersistState() error {
// After we've successfully persisted, what we just wrote is our new // After we've successfully persisted, what we just wrote is our new
// reference state until someone calls RefreshState again. // reference state until someone calls RefreshState again.
// We've potentially overwritten (via force) the state, lineage
// and / or serial (and serial was incremented) so we copy over all
// three fields so everything matches the new state and a subsequent
// operation would correctly detect no changes to the lineage, serial or state.
s.readState = s.state.DeepCopy() s.readState = s.state.DeepCopy()
s.readLineage = s.lineage
s.readSerial = s.serial
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package remote package remote
import ( import (
"log"
"sync" "sync"
"testing" "testing"
@ -8,6 +9,7 @@ import (
"github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty"
"github.com/hashicorp/terraform/states" "github.com/hashicorp/terraform/states"
"github.com/hashicorp/terraform/states/statefile"
"github.com/hashicorp/terraform/states/statemgr" "github.com/hashicorp/terraform/states/statemgr"
"github.com/hashicorp/terraform/version" "github.com/hashicorp/terraform/version"
) )
@ -41,12 +43,176 @@ func TestStateRace(t *testing.T) {
wg.Wait() wg.Wait()
} }
// testCase encapsulates a test state test
type testCase struct {
name string
// A function to mutate state and return a cleanup function
mutationFunc func(*State) (*states.State, func())
// The expected request to have taken place
expectedRequest mockClientRequest
// Mark this case as not having a request
noRequest bool
}
// isRequested ensures a test that is specified as not having
// a request doesn't have one by checking if a method exists
// on the expectedRequest.
func (tc testCase) isRequested(t *testing.T) bool {
hasMethod := tc.expectedRequest.Method != ""
if tc.noRequest && hasMethod {
t.Fatalf("expected no content for %q but got: %v", tc.name, tc.expectedRequest)
}
return !tc.noRequest
}
func TestStatePersist(t *testing.T) { func TestStatePersist(t *testing.T) {
testCases := []testCase{
// Refreshing state before we run the test loop causes a GET
{
name: "refresh state",
mutationFunc: func(mgr *State) (*states.State, func()) {
return mgr.State(), func() {}
},
expectedRequest: mockClientRequest{
Method: "Get",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 1.0, // encoding/json decodes this as float64 by default
"terraform_version": "0.0.0",
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
},
{
name: "change lineage",
mutationFunc: func(mgr *State) (*states.State, func()) {
originalLineage := mgr.lineage
mgr.lineage = "some-new-lineage"
return mgr.State(), func() {
mgr.lineage = originalLineage
}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "some-new-lineage",
"serial": 2.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
},
{
name: "change serial",
mutationFunc: func(mgr *State) (*states.State, func()) {
originalSerial := mgr.serial
mgr.serial++
return mgr.State(), func() {
mgr.serial = originalSerial
}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 4.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
},
{
name: "add output to state",
mutationFunc: func(mgr *State) (*states.State, func()) {
s := mgr.State()
s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false)
return s, func() {}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 3.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "bar",
},
},
"resources": []interface{}{},
},
},
},
{
name: "mutate state bar -> baz",
mutationFunc: func(mgr *State) (*states.State, func()) {
s := mgr.State()
s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false)
return s, func() {}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 4.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "baz",
},
},
"resources": []interface{}{},
},
},
},
{
name: "nothing changed",
mutationFunc: func(mgr *State) (*states.State, func()) {
s := mgr.State()
return s, func() {}
},
noRequest: true,
},
{
name: "reset serial (force push style)",
mutationFunc: func(mgr *State) (*states.State, func()) {
mgr.serial = 2
return mgr.State(), func() {}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 3.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "baz",
},
},
"resources": []interface{}{},
},
},
},
}
// Initial setup of state just to give us a fixed starting point for our
// test assertions below, or else we'd need to deal with
// random lineage.
mgr := &State{ mgr := &State{
Client: &mockClient{ Client: &mockClient{
// Initial state just to give us a fixed starting point for our
// test assertions below, or else we'd need to deal with
// random lineage.
current: []byte(` current: []byte(`
{ {
"version": 4, "version": 4,
@ -62,94 +228,218 @@ func TestStatePersist(t *testing.T) {
// In normal use (during a Terraform operation) we always refresh and read // In normal use (during a Terraform operation) we always refresh and read
// before any writes would happen, so we'll mimic that here for realism. // before any writes would happen, so we'll mimic that here for realism.
// NB This causes a GET to be logged so the first item in the test cases
// must account for this
if err := mgr.RefreshState(); err != nil { if err := mgr.RefreshState(); err != nil {
t.Fatalf("failed to RefreshState: %s", err) t.Fatalf("failed to RefreshState: %s", err)
} }
s := mgr.State()
s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false) // Our client is a mockClient which has a log we
if err := mgr.WriteState(s); err != nil { // use to check that operations generate expected requests
t.Fatalf("failed to WriteState: %s", err) mockClient := mgr.Client.(*mockClient)
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState: %s", err)
}
// Persisting the same state again should be a no-op: it doesn't fail, // logIdx tracks the current index of the log separate from
// but it ought not to appear in the client's log either. // the loop iteration so we can check operations that don't
if err := mgr.WriteState(s); err != nil { // cause any requests to be generated
t.Fatalf("failed to WriteState: %s", err) logIdx := 0
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState: %s", err)
}
// ...but if we _do_ change something in the state then we should see // Run tests in order.
// it re-persist. for _, tc := range testCases {
s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false) s, cleanup := tc.mutationFunc(mgr)
if err := mgr.WriteState(s); err != nil {
t.Fatalf("failed to WriteState: %s", err) if err := mgr.WriteState(s); err != nil {
t.Fatalf("failed to WriteState for %q: %s", tc.name, err)
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState for %q: %s", tc.name, err)
}
if tc.isRequested(t) {
// Get captured request from the mock client log
// based on the index of the current test
if logIdx >= len(mockClient.log) {
t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log))
}
loggedRequest := mockClient.log[logIdx]
logIdx++
if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 {
t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff)
}
}
cleanup()
} }
if err := mgr.PersistState(); err != nil { logCnt := len(mockClient.log)
t.Fatalf("failed to PersistState: %s", err) if logIdx != logCnt {
} log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx)
}
got := mgr.Client.(*mockClient).log }
want := []mockClientRequest{
// The initial fetch from mgr.RefreshState above. type migrationTestCase struct {
{ name string
Method: "Get", // A function to generate a statefile
Content: map[string]interface{}{ stateFile func(*State) *statefile.File
"version": 4.0, // encoding/json decodes this as float64 by default // The expected request to have taken place
"lineage": "mock-lineage", expectedRequest mockClientRequest
"serial": 1.0, // encoding/json decodes this as float64 by default // Mark this case as not having a request
"terraform_version": "0.0.0", expectedError string
"outputs": map[string]interface{}{}, // force flag passed to client
"resources": []interface{}{}, force bool
}, }
},
func TestWriteStateForMigration(t *testing.T) {
// First call to PersistState, with output "foo" set to "bar". mgr := &State{
{ Client: &mockClient{
Method: "Put", current: []byte(`
Content: map[string]interface{}{ {
"version": 4.0, "version": 4,
"lineage": "mock-lineage", "lineage": "mock-lineage",
"serial": 2.0, // serial increases because the outputs changed "serial": 3,
"terraform_version": version.Version, "terraform_version":"0.0.0",
"outputs": map[string]interface{}{ "outputs": {"foo": {"value":"bar", "type": "string"}},
"foo": map[string]interface{}{ "resources": []
"type": "string", }
"value": "bar", `),
}, },
}, }
"resources": []interface{}{},
}, testCases := []migrationTestCase{
}, // Refreshing state before we run the test loop causes a GET
{
// Second call to PersistState generates no client requests, because name: "refresh state",
// nothing changed in the state itself. stateFile: func(mgr *State) *statefile.File {
return mgr.StateForMigration()
// Third call to PersistState, with the "foo" output value updated },
// to "baz". expectedRequest: mockClientRequest{
{ Method: "Get",
Method: "Put", Content: map[string]interface{}{
Content: map[string]interface{}{ "version": 4.0,
"version": 4.0, "lineage": "mock-lineage",
"lineage": "mock-lineage", "serial": 3.0,
"serial": 3.0, // serial increases because the outputs changed "terraform_version": "0.0.0",
"terraform_version": version.Version, "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}},
"outputs": map[string]interface{}{ "resources": []interface{}{},
"foo": map[string]interface{}{ },
"type": "string", },
"value": "baz", },
}, {
}, name: "cannot import lesser serial without force",
"resources": []interface{}{}, stateFile: func(mgr *State) *statefile.File {
}, return statefile.New(mgr.state, mgr.lineage, 1)
}, },
} expectedError: "cannot import state with serial 1 over newer state with serial 3",
if diff := cmp.Diff(want, got); len(diff) > 0 { },
t.Errorf("incorrect client requests\n%s", diff) {
name: "cannot import differing lineage without force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, "different-lineage", mgr.serial)
},
expectedError: `cannot import state with lineage "different-lineage" over unrelated state with lineage "mock-lineage"`,
},
{
name: "can import lesser serial with force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, mgr.lineage, 1)
},
expectedRequest: mockClientRequest{
Method: "Force Put",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "mock-lineage",
"serial": 2.0,
"terraform_version": version.Version,
"outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}},
"resources": []interface{}{},
},
},
force: true,
},
{
name: "cannot import differing lineage without force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, "different-lineage", mgr.serial)
},
expectedRequest: mockClientRequest{
Method: "Force Put",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "different-lineage",
"serial": 3.0,
"terraform_version": version.Version,
"outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}},
"resources": []interface{}{},
},
},
force: true,
},
}
// In normal use (during a Terraform operation) we always refresh and read
// before any writes would happen, so we'll mimic that here for realism.
// NB This causes a GET to be logged so the first item in the test cases
// must account for this
if err := mgr.RefreshState(); err != nil {
t.Fatalf("failed to RefreshState: %s", err)
}
if err := mgr.WriteState(mgr.State()); err != nil {
t.Fatalf("failed to write initial state: %s", err)
}
// Our client is a mockClient which has a log we
// use to check that operations generate expected requests
mockClient := mgr.Client.(*mockClient)
if mockClient.force {
t.Fatalf("client should not default to force")
}
// logIdx tracks the current index of the log separate from
// the loop iteration so we can check operations that don't
// cause any requests to be generated
logIdx := 0
for _, tc := range testCases {
// Always reset client to not be force pushing
mockClient.force = false
sf := tc.stateFile(mgr)
err := mgr.WriteStateForMigration(sf, tc.force)
shouldError := tc.expectedError != ""
// If we are expecting and error check it and move on
if shouldError {
if err == nil {
t.Fatalf("test case %q should have failed with error %q", tc.name, tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Fatalf("test case %q expected error %q but got %q", tc.name, tc.expectedError, err)
}
continue
}
if err != nil {
t.Fatalf("test case %q failed: %v", tc.name, err)
}
if tc.force && !mockClient.force {
t.Fatalf("test case %q should have enabled force push", tc.name)
}
// At this point we should just do a normal write and persist
// as would happen from the CLI
mgr.WriteState(mgr.State())
mgr.PersistState()
if logIdx >= len(mockClient.log) {
t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log))
}
loggedRequest := mockClient.log[logIdx]
logIdx++
if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 {
t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff)
}
}
logCnt := len(mockClient.log)
if logIdx != logCnt {
log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx)
} }
} }