add mutexes to remote.State

This commit is contained in:
James Bardin 2017-05-25 11:01:25 -04:00
parent 9e4c0ff2ad
commit f0f2220abb
2 changed files with 43 additions and 0 deletions

View File

@ -2,6 +2,7 @@ package remote
import ( import (
"bytes" "bytes"
"sync"
"github.com/hashicorp/terraform/state" "github.com/hashicorp/terraform/state"
"github.com/hashicorp/terraform/terraform" "github.com/hashicorp/terraform/terraform"
@ -12,6 +13,8 @@ import (
// local caching so every persist will go to the remote storage and local // local caching so every persist will go to the remote storage and local
// writes will go to memory. // writes will go to memory.
type State struct { type State struct {
mu sync.Mutex
Client Client Client Client
state, readState *terraform.State state, readState *terraform.State
@ -19,17 +22,26 @@ type State struct {
// StateReader impl. // StateReader impl.
func (s *State) State() *terraform.State { func (s *State) State() *terraform.State {
s.mu.Lock()
defer s.mu.Unlock()
return s.state.DeepCopy() return s.state.DeepCopy()
} }
// StateWriter impl. // StateWriter impl.
func (s *State) WriteState(state *terraform.State) error { func (s *State) WriteState(state *terraform.State) error {
s.mu.Lock()
defer s.mu.Unlock()
s.state = state s.state = state
return nil return nil
} }
// StateRefresher impl. // StateRefresher impl.
func (s *State) RefreshState() error { func (s *State) RefreshState() error {
s.mu.Lock()
defer s.mu.Unlock()
payload, err := s.Client.Get() payload, err := s.Client.Get()
if err != nil { if err != nil {
return err return err
@ -52,6 +64,9 @@ func (s *State) RefreshState() error {
// StatePersister impl. // StatePersister impl.
func (s *State) PersistState() error { func (s *State) PersistState() error {
s.mu.Lock()
defer s.mu.Unlock()
s.state.IncrementSerialMaybe(s.readState) s.state.IncrementSerialMaybe(s.readState)
var buf bytes.Buffer var buf bytes.Buffer
@ -64,6 +79,9 @@ func (s *State) PersistState() error {
// Lock calls the Client's Lock method if it's implemented. // Lock calls the Client's Lock method if it's implemented.
func (s *State) Lock(info *state.LockInfo) (string, error) { func (s *State) Lock(info *state.LockInfo) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if c, ok := s.Client.(ClientLocker); ok { if c, ok := s.Client.(ClientLocker); ok {
return c.Lock(info) return c.Lock(info)
} }
@ -72,6 +90,9 @@ func (s *State) Lock(info *state.LockInfo) (string, error) {
// Unlock calls the Client's Unlock method if it's implemented. // Unlock calls the Client's Unlock method if it's implemented.
func (s *State) Unlock(id string) error { func (s *State) Unlock(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
if c, ok := s.Client.(ClientLocker); ok { if c, ok := s.Client.(ClientLocker); ok {
return c.Unlock(id) return c.Unlock(id)
} }

View File

@ -1,6 +1,7 @@
package remote package remote
import ( import (
"sync"
"testing" "testing"
"github.com/hashicorp/terraform/state" "github.com/hashicorp/terraform/state"
@ -13,3 +14,24 @@ func TestState_impl(t *testing.T) {
var _ state.StateRefresher = new(State) var _ state.StateRefresher = new(State)
var _ state.Locker = new(State) var _ state.Locker = new(State)
} }
func TestStateRace(t *testing.T) {
s := &State{
Client: nilClient{},
}
current := state.TestStateInitial()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.WriteState(current)
s.PersistState()
s.RefreshState()
}()
}
wg.Wait()
}