From 01ec680019eeb22f73f32d4628e5980060126705 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 27 Jan 2015 21:48:46 -0800 Subject: [PATCH] terraform: ProviderTransform to map resources to providers by dep --- dag/graph.go | 7 ++- dag/tarjan_test.go | 2 +- terraform/graph_config_node.go | 10 ++++ .../transform-provider-basic/main.tf | 2 + terraform/transform_provider.go | 55 +++++++++++++++++++ terraform/transform_provider_test.go | 35 ++++++++++++ terraform/util.go | 14 +++++ 7 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 terraform/test-fixtures/transform-provider-basic/main.tf create mode 100644 terraform/transform_provider.go create mode 100644 terraform/transform_provider_test.go diff --git a/dag/graph.go b/dag/graph.go index d3a79a7be..2e26f1ff9 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -88,7 +88,7 @@ func (g *Graph) String() string { names := make([]string, 0, len(g.vertices)) mapping := make(map[string]Vertex, len(g.vertices)) for _, v := range g.vertices { - name := vertName(v) + name := VertexName(v) names = append(names, name) mapping[name] = v } @@ -104,7 +104,7 @@ func (g *Graph) String() string { // Alphabetize dependencies deps := make([]string, 0, targets.Len()) for _, target := range targets.List() { - deps = append(deps, vertName(target)) + deps = append(deps, VertexName(target)) } sort.Strings(deps) @@ -124,7 +124,8 @@ func (g *Graph) init() { g.upEdges = make(map[Vertex]*set) } -func vertName(raw Vertex) string { +// VertexName returns the name of a vertex. +func VertexName(raw Vertex) string { switch v := raw.(type) { case NamedVertex: return v.Name() diff --git a/dag/tarjan_test.go b/dag/tarjan_test.go index b5dcd3f0f..9f749dabc 100644 --- a/dag/tarjan_test.go +++ b/dag/tarjan_test.go @@ -62,7 +62,7 @@ func testSCCStr(list [][]Vertex) string { for _, vs := range list { result := make([]string, len(vs)) for i, v := range vs { - result[i] = vertName(v) + result[i] = VertexName(v) } buf.WriteString(fmt.Sprintf("%s\n", strings.Join(result, ","))) diff --git a/terraform/graph_config_node.go b/terraform/graph_config_node.go index 28868b803..c586245c3 100644 --- a/terraform/graph_config_node.go +++ b/terraform/graph_config_node.go @@ -73,6 +73,11 @@ func (n *GraphNodeConfigProvider) DependentOn() []string { return result } +// GraphNodeProvider implementation +func (n *GraphNodeConfigProvider) ProviderName() string { + return n.Provider.Name +} + // GraphNodeConfigResource represents a resource within the config graph. type GraphNodeConfigResource struct { Resource *config.Resource @@ -105,3 +110,8 @@ func (n *GraphNodeConfigResource) DependentOn() []string { func (n *GraphNodeConfigResource) Name() string { return n.Resource.Id() } + +// GraphNodeProviderConsumer +func (n *GraphNodeConfigResource) ProvidedBy() string { + return resourceProvider(n.Resource.Type) +} diff --git a/terraform/test-fixtures/transform-provider-basic/main.tf b/terraform/test-fixtures/transform-provider-basic/main.tf new file mode 100644 index 000000000..8a44e1dcb --- /dev/null +++ b/terraform/test-fixtures/transform-provider-basic/main.tf @@ -0,0 +1,2 @@ +provider "aws" {} +resource "aws_instance" "web" {} diff --git a/terraform/transform_provider.go b/terraform/transform_provider.go new file mode 100644 index 000000000..1658e396d --- /dev/null +++ b/terraform/transform_provider.go @@ -0,0 +1,55 @@ +package terraform + +import ( + "fmt" + + "github.com/hashicorp/go-multierror" + "github.com/hashicorp/terraform/dag" +) + +// ProviderTransformer is a GraphTransformer that maps resources to +// providers within the graph. This will error if there are any resources +// that don't map to proper resources. +type ProviderTransformer struct{} + +func (t *ProviderTransformer) Transform(g *Graph) error { + // First, build a map of the providers + m := make(map[string]dag.Vertex) + for _, v := range g.Vertices() { + if pv, ok := v.(GraphNodeProvider); ok { + m[pv.ProviderName()] = v + } + } + + // Go through the other nodes and match them to providers they need + var err error + for _, v := range g.Vertices() { + if pv, ok := v.(GraphNodeProviderConsumer); ok { + target := m[pv.ProvidedBy()] + if target == nil { + err = multierror.Append(err, fmt.Errorf( + "%s: provider %s couldn't be found", + dag.VertexName(v), pv.ProvidedBy())) + continue + } + + g.Connect(dag.BasicEdge(v, target)) + } + } + + return err +} + +// GraphNodeProvider is an interface that nodes that can be a provider +// must implement. The ProviderName returned is the name of the provider +// they satisfy. +type GraphNodeProvider interface { + ProviderName() string +} + +// GraphNodeProviderConsumer is an interface that nodes that require +// a provider must implement. ProvidedBy must return the name of the provider +// to use. +type GraphNodeProviderConsumer interface { + ProvidedBy() string +} diff --git a/terraform/transform_provider_test.go b/terraform/transform_provider_test.go new file mode 100644 index 000000000..254d4a38d --- /dev/null +++ b/terraform/transform_provider_test.go @@ -0,0 +1,35 @@ +package terraform + +import ( + "strings" + "testing" +) + +func TestProviderTransformer(t *testing.T) { + mod := testModule(t, "transform-provider-basic") + + g := Graph{Path: RootModulePath} + { + tf := &ConfigTransformer{Module: mod} + if err := tf.Transform(&g); err != nil { + t.Fatalf("err: %s", err) + } + } + + transform := &ProviderTransformer{} + if err := transform.Transform(&g); err != nil { + t.Fatalf("err: %s", err) + } + + actual := strings.TrimSpace(g.String()) + expected := strings.TrimSpace(testTransformProviderBasicStr) + if actual != expected { + t.Fatalf("bad:\n\n%s", actual) + } +} + +const testTransformProviderBasicStr = ` +aws_instance.web + provider.aws +provider.aws +` diff --git a/terraform/util.go b/terraform/util.go index ce7198d5b..7c8e0aaa2 100644 --- a/terraform/util.go +++ b/terraform/util.go @@ -1,5 +1,9 @@ package terraform +import ( + "strings" +) + // Semaphore is a wrapper around a channel to provide // utility methods to clarify that we are treating the // channel as a semaphore @@ -42,6 +46,16 @@ func (s Semaphore) Release() { } } +// resourceProvider returns the provider name for the given type. +func resourceProvider(t string) string { + idx := strings.IndexRune(t, '_') + if idx == -1 { + return "" + } + + return t[:idx] +} + // strSliceContains checks if a given string is contained in a slice // When anybody asks why Go needs generics, here you go. func strSliceContains(haystack []string, needle string) bool {