From 5e0765c24a1c44acfa311e6e4d4db7370ab337cf Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Fri, 19 Sep 2014 16:24:17 -0700 Subject: [PATCH] terraform: Refresh handles tainted instances --- terraform/context.go | 56 ++++++++++++++++++++++++++++----------- terraform/context_test.go | 53 ++++++++++++++++++++++++++++++++++-- terraform/state.go | 2 +- 3 files changed, 93 insertions(+), 18 deletions(-) diff --git a/terraform/context.go b/terraform/context.go index 2fac4d1a1..39e37129a 100644 --- a/terraform/context.go +++ b/terraform/context.go @@ -17,6 +17,11 @@ import ( // tree internally on the Terraform structure. type genericWalkFunc func(*Resource) error +// This function is used to implement a walked for a resource that +// visits each instance, handling tainted resources first, then the +// primary. +type instanceWalkFunc func(*Resource, bool, **InstanceState) error + // Context represents all the context that Terraform needs in order to // perform operations on infrastructure. This structure is built using // ContextOpts and NewContext. See the documentation for those. @@ -231,6 +236,9 @@ func (c *Context) Refresh() (*State, error) { v := c.acquireRun() defer c.releaseRun(v) + // Update our state + c.state = c.state.deepcopy() + g, err := Graph(&GraphOpts{ Config: c.config, Providers: c.providers, @@ -241,10 +249,10 @@ func (c *Context) Refresh() (*State, error) { return c.state, err } - // Update our state - c.state = c.state.deepcopy() - err = g.Walk(c.refreshWalkFn()) + + // Prune the state + c.state.prune() return c.state, err } @@ -885,18 +893,19 @@ func (c *Context) planDestroyWalkFn(result *Plan) depgraph.WalkFunc { } func (c *Context) refreshWalkFn() depgraph.WalkFunc { - cb := func(r *Resource) error { - if r.State.Primary == nil || r.State.Primary.ID == "" { + cb := func(r *Resource, tainted bool, inst **InstanceState) error { + if *inst == nil || (*inst).ID == "" { log.Printf("[DEBUG] %s: Not refreshing, ID is empty", r.Id) return nil } + state := *inst for _, h := range c.hooks { - handleHook(h.PreRefresh(r.Id, r.State.Primary)) + handleHook(h.PreRefresh(r.Id, state)) } info := &InstanceInfo{Type: r.State.Type} - is, err := r.Provider.Refresh(info, r.State.Primary) + is, err := r.Provider.Refresh(info, state) if err != nil { return err } @@ -905,24 +914,23 @@ func (c *Context) refreshWalkFn() depgraph.WalkFunc { is.init() } - c.sl.Lock() + // Update the state + *inst = is + // TODO: Handle other modules + c.sl.Lock() mod := c.state.RootModule() - if len(r.State.Tainted) == 0 && (is == nil || is.ID == "") { - delete(mod.Resources, r.Id) - } else { - mod.Resources[r.Id].Primary = is - } + mod.Resources[r.Id] = r.State c.sl.Unlock() for _, h := range c.hooks { - handleHook(h.PostRefresh(r.Id, r.State.Primary)) + handleHook(h.PostRefresh(r.Id, is)) } return nil } - return c.genericWalkFn(cb) + return c.genericWalkFn(instanceWalk(cb)) } func (c *Context) validateWalkFn(rws *[]string, res *[]error) depgraph.WalkFunc { @@ -1009,6 +1017,24 @@ func (c *Context) validateWalkFn(rws *[]string, res *[]error) depgraph.WalkFunc } } +//type instanceWalkFunc func(*Resource, bool, **InstanceState) error +func instanceWalk(cb instanceWalkFunc) genericWalkFunc { + return func(r *Resource) error { + // Handle the tainted resources first + for idx := range r.State.Tainted { + if err := cb(r, true, &r.State.Tainted[idx]); err != nil { + return err + } + } + + // Handle the primary resource + if r.State.Primary == nil { + r.State.init() + } + return cb(r, false, &r.State.Primary) + } +} + func (c *Context) genericWalkFn(cb genericWalkFunc) depgraph.WalkFunc { // This will keep track of whether we're stopped or not var stop uint32 = 0 diff --git a/terraform/context_test.go b/terraform/context_test.go index 8319f0da3..e155bef17 100644 --- a/terraform/context_test.go +++ b/terraform/context_test.go @@ -2055,7 +2055,7 @@ func TestContextRefresh(t *testing.T) { t.Fatalf("bad: %#v", p.RefreshState) } if !reflect.DeepEqual(mod.Resources["aws_instance.web"].Primary, p.RefreshReturn) { - t.Fatalf("bad: %#v", mod.Resources["aws_instance.web"]) + t.Fatalf("bad: %#v %#v", mod.Resources["aws_instance.web"], p.RefreshReturn) } for _, r := range mod.Resources { @@ -2219,13 +2219,62 @@ func TestContextRefresh_state(t *testing.T) { t.Fatal("refresh should be called") } if !reflect.DeepEqual(p.RefreshState, originalMod.Resources["aws_instance.web"].Primary) { - t.Fatalf("bad: %#v", p.RefreshState) + t.Fatalf("bad: %#v %#v", p.RefreshState, originalMod.Resources["aws_instance.web"].Primary) } if !reflect.DeepEqual(mod.Resources["aws_instance.web"].Primary, p.RefreshReturn) { t.Fatalf("bad: %#v", mod.Resources) } } +func TestContextRefresh_tainted(t *testing.T) { + p := testProvider("aws") + c := testConfig(t, "refresh-basic") + state := &State{ + Modules: []*ModuleState{ + &ModuleState{ + Path: rootModulePath, + Resources: map[string]*ResourceState{ + "aws_instance.web": &ResourceState{ + Tainted: []*InstanceState{ + &InstanceState{ + ID: "bar", + }, + }, + }, + }, + }, + }, + } + ctx := testContext(t, &ContextOpts{ + Config: c, + Providers: map[string]ResourceProviderFactory{ + "aws": testProviderFuncFixed(p), + }, + State: state, + }) + + p.RefreshFn = nil + p.RefreshReturn = &InstanceState{ + ID: "foo", + } + + s, err := ctx.Refresh() + if err != nil { + t.Fatalf("err: %s", err) + } + originalMod := state.RootModule() + mod := s.RootModule() + if !p.RefreshCalled { + t.Fatal("refresh should be called") + } + if !reflect.DeepEqual(p.RefreshState, originalMod.Resources["aws_instance.web"].Tainted[0]) { + t.Fatalf("bad: %#v %#v", p.RefreshState, originalMod.Resources["aws_instance.web"].Tainted[0]) + } + if !reflect.DeepEqual(mod.Resources["aws_instance.web"].Tainted[0], p.RefreshReturn) { + t.Fatalf("bad: %#v", mod.Resources) + } +} + func TestContextRefresh_vars(t *testing.T) { p := testProvider("aws") c := testConfig(t, "refresh-vars") diff --git a/terraform/state.go b/terraform/state.go index 327fd8351..18cfa6200 100644 --- a/terraform/state.go +++ b/terraform/state.go @@ -418,8 +418,8 @@ func (s *InstanceState) MergeDiff(d *InstanceDiff) *InstanceState { result := s.deepcopy() if result == nil { result = new(InstanceState) - result.init() } + result.init() if s != nil { for k, v := range s.Attributes {