Merge pull request #24696 from hashicorp/leetrout/remote-state-force-push

Add support for force pushing with the remote backend
This commit is contained in:
Pam Selle 2020-05-06 15:23:28 -04:00 committed by GitHub
commit 60b3815af4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 441 additions and 95 deletions

View File

@ -20,6 +20,7 @@ type remoteClient struct {
runID string
stateUploadErr bool
workspace *tfe.Workspace
forcePush bool
}
// Get the remote state.
@ -69,6 +70,7 @@ func (r *remoteClient) Put(state []byte) error {
Serial: tfe.Int64(int64(stateFile.Serial)),
MD5: tfe.String(fmt.Sprintf("%x", md5.Sum(state))),
State: tfe.String(base64.StdEncoding.EncodeToString(state)),
Force: tfe.Bool(r.forcePush),
}
// If we have a run ID, make sure to add it to the options
@ -97,6 +99,12 @@ func (r *remoteClient) Delete() error {
return nil
}
// EnableForcePush to allow the remote client to overwrite state
// by implementing remote.ClientForcePusher
func (r *remoteClient) EnableForcePush() {
r.forcePush = true
}
// Lock the remote state.
func (r *remoteClient) Lock(info *state.LockInfo) (string, error) {
ctx := context.Background()

View File

@ -15,6 +15,14 @@ type Client interface {
Delete() error
}
// ClientForcePusher is an optional interface that allows a remote
// state to force push by managing a flag on the client that is
// toggled on by a call to EnableForcePush.
type ClientForcePusher interface {
Client
EnableForcePush()
}
// ClientLocker is an optional interface that allows a remote state
// backend to enable state lock/unlock.
type ClientLocker interface {

View File

@ -69,6 +69,7 @@ func (c nilClient) Delete() error { return nil }
type mockClient struct {
current []byte
log []mockClientRequest
force bool
}
type mockClientRequest struct {
@ -89,7 +90,11 @@ func (c *mockClient) Get() (*Payload, error) {
}
func (c *mockClient) Put(data []byte) error {
c.appendLog("Put", data)
if c.force {
c.appendLog("Force Put", data)
} else {
c.appendLog("Put", data)
}
c.current = data
return nil
}
@ -100,6 +105,11 @@ func (c *mockClient) Delete() error {
return nil
}
// Implements remote.ClientForcePusher
func (c *mockClient) EnableForcePush() {
c.force = true
}
func (c *mockClient) appendLog(method string, content []byte) {
// For easier test assertions, we actually log the result of decoding
// the content JSON rather than the raw bytes. Callers are in principle

View File

@ -21,10 +21,19 @@ type State struct {
Client Client
lineage string
serial uint64
state, readState *states.State
disableLocks bool
// We track two pieces of meta data in addition to the state itself:
//
// lineage - the state's unique ID
// serial - the monotonic counter of "versions" of the state
//
// Both of these (along with state) have a sister field
// that represents the values read in from an existing source.
// All three of these values are used to determine if the new
// state has changed from an existing state we read in.
lineage, readLineage string
serial, readSerial uint64
state, readState *states.State
disableLocks bool
}
var _ statemgr.Full = (*State)(nil)
@ -64,8 +73,15 @@ func (s *State) WriteStateForMigration(f *statefile.File, force bool) error {
s.mu.Lock()
defer s.mu.Unlock()
checkFile := statefile.New(s.state, s.lineage, s.serial)
if !force {
// `force` is passed down from the CLI flag and terminates here. Actual
// force pushing with the remote backend happens when Put()'ing the contents
// in the backend. If force is specified we skip verifications and hand the
// context off to the client to use when persitence operations actually take place.
c, isForcePusher := s.Client.(ClientForcePusher)
if force && isForcePusher {
c.EnableForcePush()
} else {
checkFile := statefile.New(s.state, s.lineage, s.serial)
if err := statemgr.CheckValidImport(f, checkFile); err != nil {
return err
}
@ -113,7 +129,12 @@ func (s *State) refreshState() error {
s.lineage = stateFile.Lineage
s.serial = stateFile.Serial
s.state = stateFile.State
s.readState = s.state.DeepCopy() // our states must be separate instances so we can track changes
// Properties from the remote must be separate so we can
// track changes as lineage, serial and/or state are mutated
s.readLineage = stateFile.Lineage
s.readSerial = stateFile.Serial
s.readState = s.state.DeepCopy()
return nil
}
@ -123,8 +144,11 @@ func (s *State) PersistState() error {
defer s.mu.Unlock()
if s.readState != nil {
if statefile.StatesMarshalEqual(s.state, s.readState) {
// If the state hasn't changed at all then we have nothing to do.
lineageUnchanged := s.readLineage != "" && s.lineage == s.readLineage
serialUnchanged := s.readSerial != 0 && s.serial == s.readSerial
stateUnchanged := statefile.StatesMarshalEqual(s.state, s.readState)
if stateUnchanged && lineageUnchanged && serialUnchanged {
// If the state, lineage or serial haven't changed at all then we have nothing to do.
return nil
}
s.serial++
@ -161,7 +185,13 @@ func (s *State) PersistState() error {
// After we've successfully persisted, what we just wrote is our new
// reference state until someone calls RefreshState again.
// We've potentially overwritten (via force) the state, lineage
// and / or serial (and serial was incremented) so we copy over all
// three fields so everything matches the new state and a subsequent
// operation would correctly detect no changes to the lineage, serial or state.
s.readState = s.state.DeepCopy()
s.readLineage = s.lineage
s.readSerial = s.serial
return nil
}

View File

@ -1,6 +1,7 @@
package remote
import (
"log"
"sync"
"testing"
@ -8,6 +9,7 @@ import (
"github.com/zclconf/go-cty/cty"
"github.com/hashicorp/terraform/states"
"github.com/hashicorp/terraform/states/statefile"
"github.com/hashicorp/terraform/states/statemgr"
"github.com/hashicorp/terraform/version"
)
@ -41,12 +43,176 @@ func TestStateRace(t *testing.T) {
wg.Wait()
}
// testCase encapsulates a test state test
type testCase struct {
name string
// A function to mutate state and return a cleanup function
mutationFunc func(*State) (*states.State, func())
// The expected request to have taken place
expectedRequest mockClientRequest
// Mark this case as not having a request
noRequest bool
}
// isRequested ensures a test that is specified as not having
// a request doesn't have one by checking if a method exists
// on the expectedRequest.
func (tc testCase) isRequested(t *testing.T) bool {
hasMethod := tc.expectedRequest.Method != ""
if tc.noRequest && hasMethod {
t.Fatalf("expected no content for %q but got: %v", tc.name, tc.expectedRequest)
}
return !tc.noRequest
}
func TestStatePersist(t *testing.T) {
testCases := []testCase{
// Refreshing state before we run the test loop causes a GET
{
name: "refresh state",
mutationFunc: func(mgr *State) (*states.State, func()) {
return mgr.State(), func() {}
},
expectedRequest: mockClientRequest{
Method: "Get",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 1.0, // encoding/json decodes this as float64 by default
"terraform_version": "0.0.0",
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
},
{
name: "change lineage",
mutationFunc: func(mgr *State) (*states.State, func()) {
originalLineage := mgr.lineage
mgr.lineage = "some-new-lineage"
return mgr.State(), func() {
mgr.lineage = originalLineage
}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "some-new-lineage",
"serial": 2.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
},
{
name: "change serial",
mutationFunc: func(mgr *State) (*states.State, func()) {
originalSerial := mgr.serial
mgr.serial++
return mgr.State(), func() {
mgr.serial = originalSerial
}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 4.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
},
{
name: "add output to state",
mutationFunc: func(mgr *State) (*states.State, func()) {
s := mgr.State()
s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false)
return s, func() {}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 3.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "bar",
},
},
"resources": []interface{}{},
},
},
},
{
name: "mutate state bar -> baz",
mutationFunc: func(mgr *State) (*states.State, func()) {
s := mgr.State()
s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false)
return s, func() {}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 4.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "baz",
},
},
"resources": []interface{}{},
},
},
},
{
name: "nothing changed",
mutationFunc: func(mgr *State) (*states.State, func()) {
s := mgr.State()
return s, func() {}
},
noRequest: true,
},
{
name: "reset serial (force push style)",
mutationFunc: func(mgr *State) (*states.State, func()) {
mgr.serial = 2
return mgr.State(), func() {}
},
expectedRequest: mockClientRequest{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 3.0, // encoding/json decodes this as float64 by default
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "baz",
},
},
"resources": []interface{}{},
},
},
},
}
// Initial setup of state just to give us a fixed starting point for our
// test assertions below, or else we'd need to deal with
// random lineage.
mgr := &State{
Client: &mockClient{
// Initial state just to give us a fixed starting point for our
// test assertions below, or else we'd need to deal with
// random lineage.
current: []byte(`
{
"version": 4,
@ -62,94 +228,218 @@ func TestStatePersist(t *testing.T) {
// In normal use (during a Terraform operation) we always refresh and read
// before any writes would happen, so we'll mimic that here for realism.
// NB This causes a GET to be logged so the first item in the test cases
// must account for this
if err := mgr.RefreshState(); err != nil {
t.Fatalf("failed to RefreshState: %s", err)
}
s := mgr.State()
s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false)
if err := mgr.WriteState(s); err != nil {
t.Fatalf("failed to WriteState: %s", err)
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState: %s", err)
}
// Our client is a mockClient which has a log we
// use to check that operations generate expected requests
mockClient := mgr.Client.(*mockClient)
// Persisting the same state again should be a no-op: it doesn't fail,
// but it ought not to appear in the client's log either.
if err := mgr.WriteState(s); err != nil {
t.Fatalf("failed to WriteState: %s", err)
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState: %s", err)
}
// logIdx tracks the current index of the log separate from
// the loop iteration so we can check operations that don't
// cause any requests to be generated
logIdx := 0
// ...but if we _do_ change something in the state then we should see
// it re-persist.
s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false)
if err := mgr.WriteState(s); err != nil {
t.Fatalf("failed to WriteState: %s", err)
// Run tests in order.
for _, tc := range testCases {
s, cleanup := tc.mutationFunc(mgr)
if err := mgr.WriteState(s); err != nil {
t.Fatalf("failed to WriteState for %q: %s", tc.name, err)
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState for %q: %s", tc.name, err)
}
if tc.isRequested(t) {
// Get captured request from the mock client log
// based on the index of the current test
if logIdx >= len(mockClient.log) {
t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log))
}
loggedRequest := mockClient.log[logIdx]
logIdx++
if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 {
t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff)
}
}
cleanup()
}
if err := mgr.PersistState(); err != nil {
t.Fatalf("failed to PersistState: %s", err)
}
got := mgr.Client.(*mockClient).log
want := []mockClientRequest{
// The initial fetch from mgr.RefreshState above.
{
Method: "Get",
Content: map[string]interface{}{
"version": 4.0, // encoding/json decodes this as float64 by default
"lineage": "mock-lineage",
"serial": 1.0, // encoding/json decodes this as float64 by default
"terraform_version": "0.0.0",
"outputs": map[string]interface{}{},
"resources": []interface{}{},
},
},
// First call to PersistState, with output "foo" set to "bar".
{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "mock-lineage",
"serial": 2.0, // serial increases because the outputs changed
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "bar",
},
},
"resources": []interface{}{},
},
},
// Second call to PersistState generates no client requests, because
// nothing changed in the state itself.
// Third call to PersistState, with the "foo" output value updated
// to "baz".
{
Method: "Put",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "mock-lineage",
"serial": 3.0, // serial increases because the outputs changed
"terraform_version": version.Version,
"outputs": map[string]interface{}{
"foo": map[string]interface{}{
"type": "string",
"value": "baz",
},
},
"resources": []interface{}{},
},
},
}
if diff := cmp.Diff(want, got); len(diff) > 0 {
t.Errorf("incorrect client requests\n%s", diff)
logCnt := len(mockClient.log)
if logIdx != logCnt {
log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx)
}
}
type migrationTestCase struct {
name string
// A function to generate a statefile
stateFile func(*State) *statefile.File
// The expected request to have taken place
expectedRequest mockClientRequest
// Mark this case as not having a request
expectedError string
// force flag passed to client
force bool
}
func TestWriteStateForMigration(t *testing.T) {
mgr := &State{
Client: &mockClient{
current: []byte(`
{
"version": 4,
"lineage": "mock-lineage",
"serial": 3,
"terraform_version":"0.0.0",
"outputs": {"foo": {"value":"bar", "type": "string"}},
"resources": []
}
`),
},
}
testCases := []migrationTestCase{
// Refreshing state before we run the test loop causes a GET
{
name: "refresh state",
stateFile: func(mgr *State) *statefile.File {
return mgr.StateForMigration()
},
expectedRequest: mockClientRequest{
Method: "Get",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "mock-lineage",
"serial": 3.0,
"terraform_version": "0.0.0",
"outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}},
"resources": []interface{}{},
},
},
},
{
name: "cannot import lesser serial without force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, mgr.lineage, 1)
},
expectedError: "cannot import state with serial 1 over newer state with serial 3",
},
{
name: "cannot import differing lineage without force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, "different-lineage", mgr.serial)
},
expectedError: `cannot import state with lineage "different-lineage" over unrelated state with lineage "mock-lineage"`,
},
{
name: "can import lesser serial with force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, mgr.lineage, 1)
},
expectedRequest: mockClientRequest{
Method: "Force Put",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "mock-lineage",
"serial": 2.0,
"terraform_version": version.Version,
"outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}},
"resources": []interface{}{},
},
},
force: true,
},
{
name: "cannot import differing lineage without force",
stateFile: func(mgr *State) *statefile.File {
return statefile.New(mgr.state, "different-lineage", mgr.serial)
},
expectedRequest: mockClientRequest{
Method: "Force Put",
Content: map[string]interface{}{
"version": 4.0,
"lineage": "different-lineage",
"serial": 3.0,
"terraform_version": version.Version,
"outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}},
"resources": []interface{}{},
},
},
force: true,
},
}
// In normal use (during a Terraform operation) we always refresh and read
// before any writes would happen, so we'll mimic that here for realism.
// NB This causes a GET to be logged so the first item in the test cases
// must account for this
if err := mgr.RefreshState(); err != nil {
t.Fatalf("failed to RefreshState: %s", err)
}
if err := mgr.WriteState(mgr.State()); err != nil {
t.Fatalf("failed to write initial state: %s", err)
}
// Our client is a mockClient which has a log we
// use to check that operations generate expected requests
mockClient := mgr.Client.(*mockClient)
if mockClient.force {
t.Fatalf("client should not default to force")
}
// logIdx tracks the current index of the log separate from
// the loop iteration so we can check operations that don't
// cause any requests to be generated
logIdx := 0
for _, tc := range testCases {
// Always reset client to not be force pushing
mockClient.force = false
sf := tc.stateFile(mgr)
err := mgr.WriteStateForMigration(sf, tc.force)
shouldError := tc.expectedError != ""
// If we are expecting and error check it and move on
if shouldError {
if err == nil {
t.Fatalf("test case %q should have failed with error %q", tc.name, tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Fatalf("test case %q expected error %q but got %q", tc.name, tc.expectedError, err)
}
continue
}
if err != nil {
t.Fatalf("test case %q failed: %v", tc.name, err)
}
if tc.force && !mockClient.force {
t.Fatalf("test case %q should have enabled force push", tc.name)
}
// At this point we should just do a normal write and persist
// as would happen from the CLI
mgr.WriteState(mgr.State())
mgr.PersistState()
if logIdx >= len(mockClient.log) {
t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log))
}
loggedRequest := mockClient.log[logIdx]
logIdx++
if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 {
t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff)
}
}
logCnt := len(mockClient.log)
if logIdx != logCnt {
log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx)
}
}