diff --git a/dag/dag.go b/dag/dag.go index e41de6053..c53ec284a 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -1,14 +1,41 @@ package dag +import ( + "fmt" +) + // AcyclicGraph is a specialization of Graph that cannot have cycles. With // this property, we get the property of sane graph traversal. type AcyclicGraph struct { - *Graph + Graph } // WalkFunc is the callback used for walking the graph. type WalkFunc func(Vertex) +// Root returns the root of the DAG, or an error. +// +// Complexity: O(V) +func (g *AcyclicGraph) Root() (Vertex, error) { + roots := make([]Vertex, 0, 1) + for _, v := range g.Vertices() { + if g.UpEdges(v).Len() == 0 { + roots = append(roots, v) + } + } + + if len(roots) > 1 { + // TODO(mitchellh): make this error message a lot better + return nil, fmt.Errorf("multiple roots: %#v", roots) + } + + if len(roots) == 0 { + return nil, fmt.Errorf("no roots found") + } + + return roots[0], nil +} + // Walk walks the graph, calling your callback as each node is visited. func (g *AcyclicGraph) Walk(cb WalkFunc) { } diff --git a/dag/dag_test.go b/dag/dag_test.go new file mode 100644 index 000000000..94f25d6e9 --- /dev/null +++ b/dag/dag_test.go @@ -0,0 +1,46 @@ +package dag + +import ( + "testing" +) + +func TestAcyclicGraphRoot(t *testing.T) { + var g AcyclicGraph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(3, 2)) + g.Connect(BasicEdge(3, 1)) + + if root, err := g.Root(); err != nil { + t.Fatalf("err: %s", err) + } else if root != 3 { + t.Fatalf("bad: %#v", root) + } +} + +func TestAcyclicGraphRoot_cycle(t *testing.T) { + var g AcyclicGraph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 3)) + g.Connect(BasicEdge(3, 1)) + + if _, err := g.Root(); err == nil { + t.Fatal("should error") + } +} + +func TestAcyclicGraphRoot_multiple(t *testing.T) { + var g AcyclicGraph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(3, 2)) + + if _, err := g.Root(); err == nil { + t.Fatal("should error") + } +}