Add locks to all exported State methods

move implementation into private methods when needed
This commit is contained in:
James Bardin 2016-08-31 16:11:37 -04:00
parent 16e3a11da3
commit 0be90c8994
1 changed files with 70 additions and 6 deletions

View File

@ -98,8 +98,14 @@ func NewState() *State {
// the given path. If the path is "root", for example, then children
// returned might be "root.child", but not "root.child.grandchild".
func (s *State) Children(path []string) []*ModuleState {
s.Lock()
defer s.Unlock()
// TODO: test
return s.children(path)
}
func (s *State) children(path []string) []*ModuleState {
result := make([]*ModuleState, 0)
for _, m := range s.Modules {
if len(m.Path) != len(path)+1 {
@ -120,8 +126,10 @@ func (s *State) Children(path []string) []*ModuleState {
// This should be the preferred method to add module states since it
// allows us to optimize lookups later as well as control sorting.
func (s *State) AddModule(path []string) *ModuleState {
s.Lock()
defer s.Unlock()
// check if the module exists first
m := s.ModuleByPath(path)
m := s.moduleByPath(path)
if m != nil {
return m
}
@ -140,6 +148,13 @@ func (s *State) ModuleByPath(path []string) *ModuleState {
if s == nil {
return nil
}
s.Lock()
defer s.Unlock()
return s.moduleByPath(path)
}
func (s *State) moduleByPath(path []string) *ModuleState {
for _, mod := range s.Modules {
if mod.Path == nil {
panic("missing module path")
@ -155,6 +170,14 @@ func (s *State) ModuleByPath(path []string) *ModuleState {
// returning their full paths. These paths can be used with ModuleByPath
// to return the actual state.
func (s *State) ModuleOrphans(path []string, c *config.Config) [][]string {
s.Lock()
defer s.Unlock()
return s.moduleOrphans(path, c)
}
func (s *State) moduleOrphans(path []string, c *config.Config) [][]string {
// direct keeps track of what direct children we have both in our config
// and in our state. childrenKeys keeps track of what isn't an orphan.
direct := make(map[string]struct{})
@ -168,7 +191,7 @@ func (s *State) ModuleOrphans(path []string, c *config.Config) [][]string {
// Go over the direct children and find any that aren't in our keys.
var orphans [][]string
for _, m := range s.Children(path) {
for _, m := range s.children(path) {
key := m.Path[len(m.Path)-1]
// Record that we found this key as a direct child. We use this
@ -228,6 +251,8 @@ func (s *State) Empty() bool {
if s == nil {
return true
}
s.Lock()
defer s.Unlock()
return len(s.Modules) == 0
}
@ -238,6 +263,9 @@ func (s *State) IsRemote() bool {
if s == nil {
return false
}
s.Lock()
defer s.Unlock()
if s.Remote == nil {
return false
}
@ -258,6 +286,9 @@ func (s *State) IsRemote() bool {
// If this returns an error, then the user should be notified. The error
// response will include detailed information on the nature of the error.
func (s *State) Validate() error {
s.Lock()
defer s.Unlock()
var result error
// !!!! FOR DEVELOPERS !!!!
@ -295,6 +326,9 @@ func (s *State) Validate() error {
// all children as well. To check what will be deleted, use a StateFilter
// first.
func (s *State) Remove(addr ...string) error {
s.Lock()
defer s.Unlock()
// Filter out what we need to delete
filter := &StateFilter{State: s}
results, err := filter.Filter(addr...)
@ -362,7 +396,7 @@ func (s *State) removeModule(path []string, v *ModuleState) {
func (s *State) removeResource(path []string, v *ResourceState) {
// Get the module this resource lives in. If it doesn't exist, we're done.
mod := s.ModuleByPath(path)
mod := s.moduleByPath(path)
if mod == nil {
return
}
@ -420,6 +454,16 @@ func (s *State) Equal(other *State) bool {
return s == other
}
s.Lock()
defer s.Unlock()
return s.equal(other)
}
func (s *State) equal(other *State) bool {
if s == nil || other == nil {
return s == other
}
// If the versions are different, they're certainly not equal
if s.Version != other.Version {
return false
@ -431,7 +475,7 @@ func (s *State) Equal(other *State) bool {
}
for _, m := range s.Modules {
// This isn't very optimal currently but works.
otherM := other.ModuleByPath(m.Path)
otherM := other.moduleByPath(m.Path)
if otherM == nil {
return false
}
@ -466,7 +510,6 @@ const (
// An error is returned if the two states are not of the same lineage,
// in which case the integer returned has no meaning.
func (s *State) CompareAges(other *State) (StateAgeComparison, error) {
// nil states are "older" than actual states
switch {
case s != nil && other == nil:
@ -483,6 +526,9 @@ func (s *State) CompareAges(other *State) (StateAgeComparison, error) {
)
}
s.Lock()
defer s.Unlock()
switch {
case s.Serial < other.Serial:
return StateAgeReceiverOlder, nil
@ -496,6 +542,9 @@ func (s *State) CompareAges(other *State) (StateAgeComparison, error) {
// SameLineage returns true only if the state given in argument belongs
// to the same "lineage" of states as the reciever.
func (s *State) SameLineage(other *State) bool {
s.Lock()
defer s.Unlock()
// If one of the states has no lineage then it is assumed to predate
// this concept, and so we'll accept it as belonging to any lineage
// so that a lineage string can be assigned to newer versions
@ -527,10 +576,13 @@ func (s *State) IncrementSerialMaybe(other *State) {
if other == nil {
return
}
s.Lock()
defer s.Unlock()
if s.Serial > other.Serial {
return
}
if other.TFVersion != s.TFVersion || !s.Equal(other) {
if other.TFVersion != s.TFVersion || !s.equal(other) {
if other.Serial > s.Serial {
s.Serial = other.Serial
}
@ -542,6 +594,9 @@ func (s *State) IncrementSerialMaybe(other *State) {
// FromFutureTerraform checks if this state was written by a Terraform
// version from the future.
func (s *State) FromFutureTerraform() bool {
s.Lock()
defer s.Unlock()
// No TF version means it is certainly from the past
if s.TFVersion == "" {
return false
@ -552,6 +607,8 @@ func (s *State) FromFutureTerraform() bool {
}
func (s *State) Init() {
s.Lock()
defer s.Unlock()
s.init()
}
@ -574,6 +631,9 @@ func (s *State) init() {
}
func (s *State) EnsureHasLineage() {
s.Lock()
defer s.Unlock()
if s.Lineage == "" {
s.Lineage = uuid.NewV4().String()
log.Printf("[DEBUG] New state was assigned lineage %q\n", s.Lineage)
@ -585,6 +645,8 @@ func (s *State) EnsureHasLineage() {
// AddModuleState insert this module state and override any existing ModuleState
func (s *State) AddModuleState(mod *ModuleState) {
mod.init()
s.Lock()
defer s.Unlock()
for i, m := range s.Modules {
if reflect.DeepEqual(m.Path, mod.Path) {
@ -624,6 +686,8 @@ func (s *State) String() string {
if s == nil {
return "<nil>"
}
s.Lock()
defer s.Unlock()
var buf bytes.Buffer
for _, m := range s.Modules {