depgraph: add Walk function

This commit is contained in:
Mitchell Hashimoto 2014-06-05 02:11:28 -07:00
parent d731d033f1
commit 0d9fb53a5a
2 changed files with 154 additions and 0 deletions

View File

@ -12,6 +12,9 @@ import (
"github.com/hashicorp/terraform/digraph"
)
// WalkFunc is the type used for the callback for Walk.
type WalkFunc func(*Noun) error
// Graph is used to represent a dependency graph.
type Graph struct {
Name string
@ -180,3 +183,88 @@ CHECK_CYCLES:
}
return nil
}
// Walk will walk the tree depth-first (dependency first) and call
// the callback.
//
// The callbacks will be called in parallel, so if you need non-parallelism,
// then introduce a lock in your callback.
func (g *Graph) Walk(fn WalkFunc) error {
// Set so we don't callback for a single noun multiple times
seenMap := make(map[*Noun]chan struct{})
seenMap[g.Root] = make(chan struct{})
// Build the list of things to visit
tovisit := make([]*Noun, 1, len(g.Nouns))
tovisit[0] = g.Root
// Spawn off all our goroutines to walk the tree
errCh := make(chan error)
quitCh := make(chan struct{})
for len(tovisit) > 0 {
// Grab the current thing to use
n := len(tovisit)
current := tovisit[n-1]
tovisit = tovisit[:n-1]
// Go through each dependency and run that first
for _, dep := range current.Deps {
if _, ok := seenMap[dep.Target]; !ok {
seenMap[dep.Target] = make(chan struct{})
tovisit = append(tovisit, dep.Target)
}
}
// Spawn off a goroutine to execute our callback once
// all our dependencies are satisified.
go func(current *Noun) {
defer close(seenMap[current])
// Wait for all our dependencies
for _, dep := range current.Deps {
select {
case <-seenMap[dep.Target]:
case <-quitCh:
return
}
}
// Call our callback!
if err := fn(current); err != nil {
errCh <- err
}
}(current)
}
// Aggregate channel that is closed when all goroutines finish
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
for _, ch := range seenMap {
<-ch
}
}()
// Wait for finish OR an error
select {
case <-doneCh:
return nil
case err := <-errCh:
// Close the quit channel so all our goroutines will end now
close(quitCh)
// Drain the error channel
go func() {
for _ = range errCh {
// Nothing
}
}()
// Wait for the goroutines to end
<-doneCh
close(errCh)
return err
}
}

View File

@ -2,7 +2,9 @@ package depgraph
import (
"fmt"
"reflect"
"strings"
"sync"
"testing"
)
@ -313,3 +315,67 @@ c -> e`)
t.Fatalf("err: %v", err)
}
}
func TestGraphWalk(t *testing.T) {
nodes := ParseNouns(`a -> b
a -> c
b -> d
b -> e
c -> d
c -> e`)
list := NounMapToList(nodes)
g := &Graph{Name: "Test", Nouns: list}
if err := g.Validate(); err != nil {
t.Fatalf("err: %s", err)
}
var namesLock sync.Mutex
names := make([]string, 0, 0)
err := g.Walk(func(n *Noun) error {
namesLock.Lock()
defer namesLock.Unlock()
names = append(names, n.Name)
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
expected := [][]string{
{"e", "d", "c", "b", "a"},
{"e", "d", "b", "c", "a"},
{"d", "e", "c", "b", "a"},
{"d", "e", "b", "c", "a"},
}
found := false
for _, expect := range expected {
if reflect.DeepEqual(expect, names) {
found = true
break
}
}
if !found {
t.Fatalf("bad: %#v", names)
}
}
func TestGraphWalk_error(t *testing.T) {
nodes := ParseNouns(`a -> b
a -> c
b -> d
b -> e
c -> d
c -> e`)
list := NounMapToList(nodes)
g := &Graph{Name: "Test", Nouns: list}
if err := g.Validate(); err != nil {
t.Fatalf("err: %s", err)
}
err := g.Walk(func(n *Noun) error {
return fmt.Errorf("foo")
})
if err == nil {
t.Fatal("should error")
}
}