diff --git a/terraform/context.go b/terraform/context.go index 39e37129a..eb12e102d 100644 --- a/terraform/context.go +++ b/terraform/context.go @@ -893,19 +893,23 @@ func (c *Context) planDestroyWalkFn(result *Plan) depgraph.WalkFunc { } func (c *Context) refreshWalkFn() depgraph.WalkFunc { - cb := func(r *Resource, tainted bool, inst **InstanceState) error { - if *inst == nil || (*inst).ID == "" { + cb := func(r *Resource) error { + is := r.State.Primary + if r.Tainted { + is = r.State.Tainted[r.TaintedIndex] + } + + if is == nil || is.ID == "" { log.Printf("[DEBUG] %s: Not refreshing, ID is empty", r.Id) return nil } - state := *inst for _, h := range c.hooks { - handleHook(h.PreRefresh(r.Id, state)) + handleHook(h.PreRefresh(r.Id, is)) } info := &InstanceInfo{Type: r.State.Type} - is, err := r.Provider.Refresh(info, state) + is, err := r.Provider.Refresh(info, is) if err != nil { return err } @@ -914,8 +918,11 @@ func (c *Context) refreshWalkFn() depgraph.WalkFunc { is.init() } - // Update the state - *inst = is + if r.Tainted { + r.State.Tainted[r.TaintedIndex] = is + } else { + r.State.Primary = is + } // TODO: Handle other modules c.sl.Lock() @@ -930,7 +937,7 @@ func (c *Context) refreshWalkFn() depgraph.WalkFunc { return nil } - return c.genericWalkFn(instanceWalk(cb)) + return c.genericWalkFn(cb) } func (c *Context) validateWalkFn(rws *[]string, res *[]error) depgraph.WalkFunc { diff --git a/terraform/context_test.go b/terraform/context_test.go index e155bef17..d5858d64c 100644 --- a/terraform/context_test.go +++ b/terraform/context_test.go @@ -2235,6 +2235,7 @@ func TestContextRefresh_tainted(t *testing.T) { Path: rootModulePath, Resources: map[string]*ResourceState{ "aws_instance.web": &ResourceState{ + Type: "aws_instance", Tainted: []*InstanceState{ &InstanceState{ ID: "bar", @@ -2262,16 +2263,14 @@ func TestContextRefresh_tainted(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - originalMod := state.RootModule() - mod := s.RootModule() if !p.RefreshCalled { t.Fatal("refresh should be called") } - if !reflect.DeepEqual(p.RefreshState, originalMod.Resources["aws_instance.web"].Tainted[0]) { - t.Fatalf("bad: %#v %#v", p.RefreshState, originalMod.Resources["aws_instance.web"].Tainted[0]) - } - if !reflect.DeepEqual(mod.Resources["aws_instance.web"].Tainted[0], p.RefreshReturn) { - t.Fatalf("bad: %#v", mod.Resources) + + actual := strings.TrimSpace(s.String()) + expected := strings.TrimSpace(testContextRefreshTaintedStr) + if actual != expected { + t.Fatalf("bad:\n\n%s\n\n%s", actual, expected) } } @@ -2476,3 +2475,9 @@ root root -> aws_instance.bar root -> aws_instance.foo ` + +const testContextRefreshTaintedStr = ` +aws_instance.web: (1 tainted) + ID = + Tainted ID 1 = foo +` diff --git a/terraform/state.go b/terraform/state.go index 18cfa6200..b0cd16c54 100644 --- a/terraform/state.go +++ b/terraform/state.go @@ -364,6 +364,12 @@ func (s *ResourceState) GoString() string { return fmt.Sprintf("*%#v", *s) } +func (s *ResourceState) String() string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("Type = %s", s.Type)) + return buf.String() +} + // InstanceState is used to track the unique state information belonging // to a given instance. type InstanceState struct { @@ -451,7 +457,7 @@ func (i *InstanceState) GoString() string { func (i *InstanceState) String() string { var buf bytes.Buffer - if i.ID == "" { + if i == nil || i.ID == "" { return "" }