From 928fce71f72c3821508b935d166504a04d91dbd6 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Thu, 19 Jan 2017 18:10:21 -0800 Subject: [PATCH] config: parse "when" and "on_failure" on provisioners --- config/config.go | 18 ++++++++- config/config_string.go | 11 ++++- config/config_test.go | 7 ++++ config/loader_hcl.go | 36 ++++++++++++++++- config/loader_test.go | 27 +++++++++++++ config/provisioner_enums.go | 40 +++++++++++++++++++ config/test-fixtures/provisioners-destroy.tf | 14 +++++++ .../validate-basic-provisioners/main.tf | 14 +++++++ 8 files changed, 164 insertions(+), 3 deletions(-) create mode 100644 config/provisioner_enums.go create mode 100644 config/test-fixtures/provisioners-destroy.tf create mode 100644 config/test-fixtures/validate-basic-provisioners/main.tf diff --git a/config/config.go b/config/config.go index e1543a1af..724a78144 100644 --- a/config/config.go +++ b/config/config.go @@ -136,6 +136,9 @@ type Provisioner struct { Type string RawConfig *RawConfig ConnInfo *RawConfig + + When ProvisionerWhen + OnFailure ProvisionerOnFailure } // Copy returns a copy of this Provisioner @@ -144,6 +147,8 @@ func (p *Provisioner) Copy() *Provisioner { Type: p.Type, RawConfig: p.RawConfig.Copy(), ConnInfo: p.ConnInfo.Copy(), + When: p.When, + OnFailure: p.OnFailure, } } @@ -553,7 +558,7 @@ func (c *Config) Validate() error { // Validate DependsOn errs = append(errs, c.validateDependsOn(n, r.DependsOn, resources, modules)...) - // Verify provisioners don't contain any splats + // Verify provisioners for _, p := range r.Provisioners { // This validation checks that there are now splat variables // referencing ourself. This currently is not allowed. @@ -585,6 +590,17 @@ func (c *Config) Validate() error { break } } + + // Check for invalid when/onFailure values, though this should be + // picked up by the loader we check here just in case. + if p.When == ProvisionerWhenInvalid { + errs = append(errs, fmt.Errorf( + "%s: provisioner 'when' value is invalid", n)) + } + if p.OnFailure == ProvisionerOnFailureInvalid { + errs = append(errs, fmt.Errorf( + "%s: provisioner 'on_failure' value is invalid", n)) + } } // Verify ignore_changes contains valid entries diff --git a/config/config_string.go b/config/config_string.go index f11290e87..a5ef7d5cd 100644 --- a/config/config_string.go +++ b/config/config_string.go @@ -214,7 +214,16 @@ func resourcesStr(rs []*Resource) string { if len(r.Provisioners) > 0 { result += fmt.Sprintf(" provisioners\n") for _, p := range r.Provisioners { - result += fmt.Sprintf(" %s\n", p.Type) + when := "" + if p.When != ProvisionerWhenCreate { + when = fmt.Sprintf(" (%s)", p.When.String()) + } + + result += fmt.Sprintf(" %s%s\n", p.Type, when) + + if p.OnFailure != ProvisionerOnFailureFail { + result += fmt.Sprintf(" on_failure = %s\n", p.OnFailure.String()) + } ks := make([]string, 0, len(p.RawConfig.Raw)) for k, _ := range p.RawConfig.Raw { diff --git a/config/config_test.go b/config/config_test.go index c73ed6100..95acd28b7 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -168,6 +168,13 @@ func TestConfigValidate_table(t *testing.T) { true, "data sources cannot have", }, + + { + "basic provisioners", + "validate-basic-provisioners", + false, + "", + }, } for i, tc := range cases { diff --git a/config/loader_hcl.go b/config/loader_hcl.go index 88f6fd0cf..dfadce9d6 100644 --- a/config/loader_hcl.go +++ b/config/loader_hcl.go @@ -849,8 +849,40 @@ func loadProvisionersHcl(list *ast.ObjectList, connInfo map[string]interface{}) return nil, err } - // Delete the "connection" section, handle separately + // Parse the "when" value + when := ProvisionerWhenCreate + if v, ok := config["when"]; ok { + switch v { + case "create": + when = ProvisionerWhenCreate + case "destroy": + when = ProvisionerWhenDestroy + default: + return nil, fmt.Errorf( + "position %s: 'provisioner' when must be 'create' or 'destroy'", + item.Pos()) + } + } + + // Parse the "on_failure" value + onFailure := ProvisionerOnFailureFail + if v, ok := config["on_failure"]; ok { + switch v { + case "continue": + onFailure = ProvisionerOnFailureContinue + case "fail": + onFailure = ProvisionerOnFailureFail + default: + return nil, fmt.Errorf( + "position %s: 'provisioner' on_failure must be 'continue' or 'fail'", + item.Pos()) + } + } + + // Delete fields we special case delete(config, "connection") + delete(config, "when") + delete(config, "on_failure") rawConfig, err := NewRawConfig(config) if err != nil { @@ -889,6 +921,8 @@ func loadProvisionersHcl(list *ast.ObjectList, connInfo map[string]interface{}) Type: n, RawConfig: rawConfig, ConnInfo: connRaw, + When: when, + OnFailure: onFailure, }) } diff --git a/config/loader_test.go b/config/loader_test.go index 1cdc4e561..80c0d8d1f 100644 --- a/config/loader_test.go +++ b/config/loader_test.go @@ -629,6 +629,22 @@ func TestLoadFile_provisioners(t *testing.T) { } } +func TestLoadFile_provisionersDestroy(t *testing.T) { + c, err := LoadFile(filepath.Join(fixtureDir, "provisioners-destroy.tf")) + if err != nil { + t.Fatalf("err: %s", err) + } + + if c == nil { + t.Fatal("config should not be nil") + } + + actual := resourcesStr(c.Resources) + if actual != strings.TrimSpace(provisionerDestroyResourcesStr) { + t.Fatalf("bad:\n%s", actual) + } +} + func TestLoadFile_unnamedOutput(t *testing.T) { _, err := LoadFile(filepath.Join(fixtureDir, "output-unnamed.tf")) if err == nil { @@ -1126,6 +1142,17 @@ aws_instance.web (x1) user: var.foo ` +const provisionerDestroyResourcesStr = ` +aws_instance.web (x1) + provisioners + shell + shell (destroy) + path + shell (destroy) + on_failure = continue + path +` + const connectionResourcesStr = ` aws_instance.web (x1) ami diff --git a/config/provisioner_enums.go b/config/provisioner_enums.go new file mode 100644 index 000000000..00fd43fce --- /dev/null +++ b/config/provisioner_enums.go @@ -0,0 +1,40 @@ +package config + +// ProvisionerWhen is an enum for valid values for when to run provisioners. +type ProvisionerWhen int + +const ( + ProvisionerWhenInvalid ProvisionerWhen = iota + ProvisionerWhenCreate + ProvisionerWhenDestroy +) + +var provisionerWhenStrs = map[ProvisionerWhen]string{ + ProvisionerWhenInvalid: "invalid", + ProvisionerWhenCreate: "create", + ProvisionerWhenDestroy: "destroy", +} + +func (v ProvisionerWhen) String() string { + return provisionerWhenStrs[v] +} + +// ProvisionerOnFailure is an enum for valid values for on_failure options +// for provisioners. +type ProvisionerOnFailure int + +const ( + ProvisionerOnFailureInvalid ProvisionerOnFailure = iota + ProvisionerOnFailureContinue + ProvisionerOnFailureFail +) + +var provisionerOnFailureStrs = map[ProvisionerOnFailure]string{ + ProvisionerOnFailureInvalid: "invalid", + ProvisionerOnFailureContinue: "continue", + ProvisionerOnFailureFail: "fail", +} + +func (v ProvisionerOnFailure) String() string { + return provisionerOnFailureStrs[v] +} diff --git a/config/test-fixtures/provisioners-destroy.tf b/config/test-fixtures/provisioners-destroy.tf new file mode 100644 index 000000000..0ad4f557b --- /dev/null +++ b/config/test-fixtures/provisioners-destroy.tf @@ -0,0 +1,14 @@ +resource "aws_instance" "web" { + provisioner "shell" {} + + provisioner "shell" { + path = "foo" + when = "destroy" + } + + provisioner "shell" { + path = "foo" + when = "destroy" + on_failure = "continue" + } +} diff --git a/config/test-fixtures/validate-basic-provisioners/main.tf b/config/test-fixtures/validate-basic-provisioners/main.tf new file mode 100644 index 000000000..0ad4f557b --- /dev/null +++ b/config/test-fixtures/validate-basic-provisioners/main.tf @@ -0,0 +1,14 @@ +resource "aws_instance" "web" { + provisioner "shell" {} + + provisioner "shell" { + path = "foo" + when = "destroy" + } + + provisioner "shell" { + path = "foo" + when = "destroy" + on_failure = "continue" + } +}