diff --git a/dag/dag.go b/dag/dag.go index ed7d77e99..4af78448b 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -2,11 +2,8 @@ package dag import ( "fmt" - "log" "sort" "strings" - "sync" - "time" "github.com/hashicorp/go-multierror" ) @@ -169,94 +166,9 @@ func (g *AcyclicGraph) Cycles() [][]Vertex { func (g *AcyclicGraph) Walk(cb WalkFunc) error { defer g.debug.BeginOperation(typeWalk, "").End("") - // Cache the vertices since we use it multiple times - vertices := g.Vertices() - - // Build the waitgroup that signals when we're done - var wg sync.WaitGroup - wg.Add(len(vertices)) - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) - wg.Wait() - }() - - // The map of channels to watch to wait for vertices to finish - vertMap := make(map[Vertex]chan struct{}) - for _, v := range vertices { - vertMap[v] = make(chan struct{}) - } - - // The map of whether a vertex errored or not during the walk - var errLock sync.Mutex - var errs error - errMap := make(map[Vertex]bool) - for _, v := range vertices { - // Build our list of dependencies and the list of channels to - // wait on until we start executing for this vertex. - deps := AsVertexList(g.DownEdges(v)) - depChs := make([]<-chan struct{}, len(deps)) - for i, dep := range deps { - depChs[i] = vertMap[dep] - } - - // Get our channel so that we can close it when we're done - ourCh := vertMap[v] - - // Start the goroutine to wait for our dependencies - readyCh := make(chan bool) - go func(v Vertex, deps []Vertex, chs []<-chan struct{}, readyCh chan<- bool) { - // First wait for all the dependencies - for i, ch := range chs { - DepSatisfied: - for { - select { - case <-ch: - break DepSatisfied - case <-time.After(time.Second * 5): - log.Printf("[DEBUG] vertex %q, waiting for: %q", - VertexName(v), VertexName(deps[i])) - } - } - log.Printf("[DEBUG] vertex %q, got dep: %q", - VertexName(v), VertexName(deps[i])) - } - - // Then, check the map to see if any of our dependencies failed - errLock.Lock() - defer errLock.Unlock() - for _, dep := range deps { - if errMap[dep] { - errMap[v] = true - readyCh <- false - return - } - } - - readyCh <- true - }(v, deps, depChs, readyCh) - - // Start the goroutine that executes - go func(v Vertex, doneCh chan<- struct{}, readyCh <-chan bool) { - defer close(doneCh) - defer wg.Done() - - var err error - if ready := <-readyCh; ready { - err = cb(v) - } - - errLock.Lock() - defer errLock.Unlock() - if err != nil { - errMap[v] = true - errs = multierror.Append(errs, err) - } - }(v, ourCh, readyCh) - } - - <-doneCh - return errs + w := &walker{Callback: cb, Reverse: true} + w.Update(g.vertices, g.edges) + return w.Wait() } // simple convenience helper for converting a dag.Set to a []Vertex diff --git a/dag/walk.go b/dag/walk.go index eb49b0c41..30dd207a6 100644 --- a/dag/walk.go +++ b/dag/walk.go @@ -1,6 +1,7 @@ package dag import ( + "errors" "fmt" "log" "sync" @@ -12,6 +13,9 @@ import ( // walker performs a graph walk and supports walk-time changing of vertices // and edges. // +// The walk is depth first by default. This can be changed with the Reverse +// option. +// // A single walker is only valid for one graph walk. After the walk is complete // you must construct a new walker to walk again. State for the walk is never // deleted in case vertices or edges are changed. @@ -19,6 +23,10 @@ type walker struct { // Callback is what is called for each vertex Callback WalkFunc + // Reverse, if true, causes the source of an edge to depend on a target. + // When false (default), the target depends on the source. + Reverse bool + // changeLock must be held to modify any of the fields below. Only Update // should modify these fields. Modifying them outside of Update can cause // serious problems. @@ -44,7 +52,7 @@ type walkerVertex struct { // Dependency information. Any changes to any of these fields requires // holding DepsLock. - DepsCh chan struct{} + DepsCh chan bool DepsUpdateCh chan struct{} DepsLock sync.Mutex @@ -54,6 +62,11 @@ type walkerVertex struct { depsCancelCh chan struct{} } +// errWalkUpstream is used in the errMap of a walk to note that an upstream +// dependency failed so this vertex wasn't run. This is not shown in the final +// user-returned error. +var errWalkUpstream = errors.New("upstream dependency failed") + // Wait waits for the completion of the walk and returns any errors ( // in the form of a multierror) that occurred. Update should be called // to populate the walk with vertices and edges prior to calling this. @@ -72,8 +85,10 @@ func (w *walker) Wait() error { // Build the error var result error for v, err := range w.errMap { - result = multierror.Append(result, fmt.Errorf( - "%s: %s", VertexName(v), err)) + if err != nil && err != errWalkUpstream { + result = multierror.Append(result, fmt.Errorf( + "%s: %s", VertexName(v), err)) + } } return result @@ -116,12 +131,12 @@ func (w *walker) Update(v, e *Set) { info := &walkerVertex{ DoneCh: make(chan struct{}), CancelCh: make(chan struct{}), - DepsCh: make(chan struct{}), + DepsCh: make(chan bool, 1), deps: make(map[Vertex]chan struct{}), } - // Close the deps channel immediately so it passes - close(info.DepsCh) + // Pass dependencies immediately assuming we have no edges + info.DepsCh <- true // Add it to the map and kick off the walk w.vertexMap[v] = info @@ -153,12 +168,7 @@ func (w *walker) Update(v, e *Set) { var changedDeps Set for _, raw := range newEdges.List() { edge := raw.(Edge) - - // waiter is the vertex that is "waiting" on this edge - waiter := edge.Target() - - // dep is the dependency we're waiting on - dep := edge.Source() + waiter, dep := w.edgeParts(edge) // Get the info for the waiter waiterInfo, ok := w.vertexMap[waiter] @@ -189,12 +199,7 @@ func (w *walker) Update(v, e *Set) { // Process reoved edges for _, raw := range oldEdges.List() { edge := raw.(Edge) - - // waiter is the vertex that is "waiting" on this edge - waiter := edge.Target() - - // dep is the dependency we're waiting on - dep := edge.Source() + waiter, dep := w.edgeParts(edge) // Get the info for the waiter waiterInfo, ok := w.vertexMap[waiter] @@ -226,7 +231,7 @@ func (w *walker) Update(v, e *Set) { } // Create a new done channel - doneCh := make(chan struct{}) + doneCh := make(chan bool, 1) // Create the channel we close for cancellation cancelCh := make(chan struct{}) @@ -252,6 +257,10 @@ func (w *walker) Update(v, e *Set) { } info.depsCancelCh = cancelCh + log.Printf( + "[DEBUG] dag/walk: dependencies changed for %q, sending new deps", + VertexName(v)) + // Start the waiter go w.waitDeps(v, deps, doneCh, cancelCh) } @@ -264,6 +273,16 @@ func (w *walker) Update(v, e *Set) { } } +// edgeParts returns the waiter and the dependency, in that order. +// The waiter is waiting on the dependency. +func (w *walker) edgeParts(e Edge) (Vertex, Vertex) { + if w.Reverse { + return e.Source(), e.Target() + } + + return e.Target(), e.Source() +} + // walkVertex walks a single vertex, waiting for any dependencies before // executing the callback. func (w *walker) walkVertex(v Vertex, info *walkerVertex) { @@ -273,16 +292,20 @@ func (w *walker) walkVertex(v Vertex, info *walkerVertex) { // When we're done, always close our done channel defer close(info.DoneCh) - // Wait for our dependencies - depsCh := info.DepsCh + // Wait for our dependencies. We create a [closed] deps channel so + // that we can immediately fall through to load our actual DepsCh. + var depsSuccess bool + depsCh := make(chan bool, 1) + depsCh <- true + close(depsCh) for { select { case <-info.CancelCh: // Cancel return - case <-depsCh: - // Deps complete! + case depsSuccess = <-depsCh: + // Deps complete! Mark as nil to trigger completion handling. depsCh = nil case <-info.DepsUpdateCh: @@ -306,9 +329,27 @@ func (w *walker) walkVertex(v Vertex, info *walkerVertex) { } } - // Call our callback - log.Printf("[DEBUG] dag/walk: walking %q", VertexName(v)) - if err := w.Callback(v); err != nil { + // If we passed dependencies, we just want to check once more that + // we're not cancelled, since this can happen just as dependencies pass. + select { + case <-info.CancelCh: + // Cancelled during an update while dependencies completed. + return + default: + } + + // Run our callback or note that our upstream failed + var err error + if depsSuccess { + log.Printf("[DEBUG] dag/walk: walking %q", VertexName(v)) + err = w.Callback(v) + } else { + log.Printf("[DEBUG] dag/walk: upstream errored, not walking %q", VertexName(v)) + err = errWalkUpstream + } + + // Record the error + if err != nil { w.errLock.Lock() defer w.errLock.Unlock() @@ -322,11 +363,8 @@ func (w *walker) walkVertex(v Vertex, info *walkerVertex) { func (w *walker) waitDeps( v Vertex, deps map[Vertex]<-chan struct{}, - doneCh chan<- struct{}, + doneCh chan<- bool, cancelCh <-chan struct{}) { - // Whenever we return, mark ourselves as complete - defer close(doneCh) - // For each dependency given to us, wait for it to complete for dep, depCh := range deps { DepSatisfied: @@ -337,13 +375,29 @@ func (w *walker) waitDeps( break DepSatisfied case <-cancelCh: - // Wait cancelled + // Wait cancelled. Note that we didn't satisfy dependencies + // so that anything waiting on us also doesn't run. + doneCh <- false return case <-time.After(time.Second * 5): - log.Printf("[DEBUG] vertex %q, waiting for: %q", + log.Printf("[DEBUG] dag/walk: vertex %q, waiting for: %q", VertexName(v), VertexName(dep)) } } } + + // Dependencies satisfied! We need to check if any errored + w.errLock.Lock() + defer w.errLock.Unlock() + for dep, _ := range deps { + if w.errMap[dep] != nil { + // One of our dependencies failed, so return false + doneCh <- false + return + } + } + + // All dependencies satisfied and successful + doneCh <- true } diff --git a/dag/walk_test.go b/dag/walk_test.go index c720e8a54..9927eb48e 100644 --- a/dag/walk_test.go +++ b/dag/walk_test.go @@ -33,6 +33,44 @@ func TestWalker_basic(t *testing.T) { } } +func TestWalker_error(t *testing.T) { + var g Graph + g.Add(1) + g.Add(2) + g.Add(3) + g.Add(4) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 3)) + g.Connect(BasicEdge(3, 4)) + + // Record function + var order []interface{} + recordF := walkCbRecord(&order) + + // Build a callback that delays until we close a channel + cb := func(v Vertex) error { + if v == 2 { + return fmt.Errorf("error!") + } + + return recordF(v) + } + + w := &walker{Callback: cb} + w.Update(g.vertices, g.edges) + + // Wait + if err := w.Wait(); err == nil { + t.Fatal("expect error") + } + + // Check + expected := []interface{}{1} + if !reflect.DeepEqual(order, expected) { + t.Fatalf("bad: %#v", order) + } +} + func TestWalker_newVertex(t *testing.T) { // Run it a bunch of times since it is timing dependent for i := 0; i < 50; i++ { @@ -82,26 +120,20 @@ func TestWalker_removeVertex(t *testing.T) { recordF := walkCbRecord(&order) // Build a callback that delays until we close a channel - gateCh := make(chan struct{}) + var w *walker cb := func(v Vertex) error { if v == 1 { - <-gateCh + g.Remove(2) + w.Update(g.vertices, g.edges) } return recordF(v) } // Add the initial vertices - w := &walker{Callback: cb} + w = &walker{Callback: cb} w.Update(g.vertices, g.edges) - // Remove a vertex - g.Remove(2) - w.Update(g.vertices, g.edges) - - // Open gate - close(gateCh) - // Wait if err := w.Wait(); err != nil { t.Fatalf("err: %s", err)