From 0d9fb53a5ac57a1ba7d0e448aa7d8696dbaa89d8 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Thu, 5 Jun 2014 02:11:28 -0700 Subject: [PATCH] depgraph: add Walk function --- depgraph/graph.go | 88 ++++++++++++++++++++++++++++++++++++++++++ depgraph/graph_test.go | 66 +++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) diff --git a/depgraph/graph.go b/depgraph/graph.go index fd0bb4220..9f719bebd 100644 --- a/depgraph/graph.go +++ b/depgraph/graph.go @@ -12,6 +12,9 @@ import ( "github.com/hashicorp/terraform/digraph" ) +// WalkFunc is the type used for the callback for Walk. +type WalkFunc func(*Noun) error + // Graph is used to represent a dependency graph. type Graph struct { Name string @@ -180,3 +183,88 @@ CHECK_CYCLES: } return nil } + +// Walk will walk the tree depth-first (dependency first) and call +// the callback. +// +// The callbacks will be called in parallel, so if you need non-parallelism, +// then introduce a lock in your callback. +func (g *Graph) Walk(fn WalkFunc) error { + // Set so we don't callback for a single noun multiple times + seenMap := make(map[*Noun]chan struct{}) + seenMap[g.Root] = make(chan struct{}) + + // Build the list of things to visit + tovisit := make([]*Noun, 1, len(g.Nouns)) + tovisit[0] = g.Root + + // Spawn off all our goroutines to walk the tree + errCh := make(chan error) + quitCh := make(chan struct{}) + for len(tovisit) > 0 { + // Grab the current thing to use + n := len(tovisit) + current := tovisit[n-1] + tovisit = tovisit[:n-1] + + // Go through each dependency and run that first + for _, dep := range current.Deps { + if _, ok := seenMap[dep.Target]; !ok { + seenMap[dep.Target] = make(chan struct{}) + tovisit = append(tovisit, dep.Target) + } + } + + // Spawn off a goroutine to execute our callback once + // all our dependencies are satisified. + go func(current *Noun) { + defer close(seenMap[current]) + + // Wait for all our dependencies + for _, dep := range current.Deps { + select { + case <-seenMap[dep.Target]: + case <-quitCh: + return + } + } + + // Call our callback! + if err := fn(current); err != nil { + errCh <- err + } + }(current) + } + + // Aggregate channel that is closed when all goroutines finish + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + + for _, ch := range seenMap { + <-ch + } + }() + + // Wait for finish OR an error + select { + case <-doneCh: + return nil + case err := <-errCh: + // Close the quit channel so all our goroutines will end now + close(quitCh) + + // Drain the error channel + go func() { + for _ = range errCh { + // Nothing + } + }() + + // Wait for the goroutines to end + <-doneCh + close(errCh) + + return err + } +} diff --git a/depgraph/graph_test.go b/depgraph/graph_test.go index 92b7cc5f7..d7dbb6f8e 100644 --- a/depgraph/graph_test.go +++ b/depgraph/graph_test.go @@ -2,7 +2,9 @@ package depgraph import ( "fmt" + "reflect" "strings" + "sync" "testing" ) @@ -313,3 +315,67 @@ c -> e`) t.Fatalf("err: %v", err) } } + +func TestGraphWalk(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +b -> e +c -> d +c -> e`) + list := NounMapToList(nodes) + g := &Graph{Name: "Test", Nouns: list} + if err := g.Validate(); err != nil { + t.Fatalf("err: %s", err) + } + + var namesLock sync.Mutex + names := make([]string, 0, 0) + err := g.Walk(func(n *Noun) error { + namesLock.Lock() + defer namesLock.Unlock() + names = append(names, n.Name) + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := [][]string{ + {"e", "d", "c", "b", "a"}, + {"e", "d", "b", "c", "a"}, + {"d", "e", "c", "b", "a"}, + {"d", "e", "b", "c", "a"}, + } + found := false + for _, expect := range expected { + if reflect.DeepEqual(expect, names) { + found = true + break + } + } + if !found { + t.Fatalf("bad: %#v", names) + } +} + +func TestGraphWalk_error(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +b -> e +c -> d +c -> e`) + list := NounMapToList(nodes) + g := &Graph{Name: "Test", Nouns: list} + if err := g.Validate(); err != nil { + t.Fatalf("err: %s", err) + } + + err := g.Walk(func(n *Noun) error { + return fmt.Errorf("foo") + }) + if err == nil { + t.Fatal("should error") + } +}