diff --git a/depgraph/dependency.go b/depgraph/dependency.go new file mode 100644 index 000000000..0d84e9fbc --- /dev/null +++ b/depgraph/dependency.go @@ -0,0 +1,36 @@ +package depgraph + +import ( + "github.com/hashicorp/terraform/digraph" +) + +// Dependency is used to create a directed edge between two nouns. +// One noun may depend on another and provide version constraints +// that cannot be violated +type Dependency struct { + Name string + Meta interface{} + Constraints []Constraint + Source *Noun + Target *Noun +} + +// Constraint is used by dependencies to allow arbitrary constraints +// between nouns +type Constraint interface { + Satisfied(head, tail *Noun) (bool, error) +} + +// Head returns the source, or dependent noun +func (d *Dependency) Head() digraph.Node { + return d.Source +} + +// Tail returns the target, or depended upon noun +func (d *Dependency) Tail() digraph.Node { + return d.Target +} + +func (d *Dependency) String() string { + return d.Name +} diff --git a/depgraph/graph.go b/depgraph/graph.go new file mode 100644 index 000000000..1cf4c15b9 --- /dev/null +++ b/depgraph/graph.go @@ -0,0 +1,155 @@ +// The depgraph package is used to create and model a dependency graph +// of nouns. Each noun can represent a service, server, application, +// network switch, etc. Nouns can depend on other nouns, and provide +// versioning constraints. Nouns can also have various meta data that +// may be relevant to their construction or configuration. +package depgraph + +import ( + "fmt" + + "github.com/hashicorp/terraform/digraph" +) + +// Graph is used to represent the entire dependency graph +type Graph struct { + Name string + Meta interface{} + Nouns []*Noun + Root *Noun +} + +// Validate is used to ensure that a few properties of the graph are not violated: +// 1) There must be a single "root", or source on which nothing depends. +// 2) All nouns in the graph must be reachable from the root +// 3) The graph must be cycle free, meaning there are no cicular dependencies +func (g *Graph) Validate() error { + // Convert to node list + nodes := make([]digraph.Node, len(g.Nouns)) + for i, n := range g.Nouns { + nodes[i] = n + } + + // Create a validate erro + vErr := &ValidateError{} + + // Search for all the sources, if we have only 1, it must be the root + if sources := digraph.Sources(nodes); len(sources) != 1 { + vErr.MissingRoot = true + goto CHECK_CYCLES + } else { + g.Root = sources[0].(*Noun) + } + + // Check reachability + if unreached := digraph.Unreachable(g.Root, nodes); len(unreached) > 0 { + vErr.Unreachable = make([]*Noun, len(unreached)) + for i, u := range unreached { + vErr.Unreachable[i] = u.(*Noun) + } + } + +CHECK_CYCLES: + // Check for cycles + if cycles := digraph.StronglyConnectedComponents(nodes, true); len(cycles) > 0 { + vErr.Cycles = make([][]*Noun, len(cycles)) + for i, cycle := range cycles { + group := make([]*Noun, len(cycle)) + for j, n := range cycle { + group[j] = n.(*Noun) + } + vErr.Cycles[i] = group + } + } + + // Return the detailed error + if vErr.MissingRoot || vErr.Unreachable != nil || vErr.Cycles != nil { + return vErr + } + return nil +} + +// ValidateError implements the Error interface but provides +// additional information on a validation error +type ValidateError struct { + // If set, then the graph is missing a single root, on which + // there are no depdendencies + MissingRoot bool + + // Unreachable are nodes that could not be reached from + // the root noun. + Unreachable []*Noun + + // Cycles are groups of strongly connected nodes, which + // form a cycle. This is disallowed. + Cycles [][]*Noun +} + +func (v *ValidateError) Error() string { + return "The depedency graph is not valid" +} + +// CheckConstraints walks the graph and ensures that all +// user imposed constraints are satisfied. +func (g *Graph) CheckConstraints() error { + // Ensure we have a root + if g.Root == nil { + return fmt.Errorf("Graph must be validated before checking constraint violations") + } + + // Create a constraint error + cErr := &ConstraintError{} + + // Walk from the root + digraph.DepthFirstWalk(g.Root, func(n digraph.Node) bool { + noun := n.(*Noun) + for _, dep := range noun.Deps { + target := dep.Target + for _, constraint := range dep.Constraints { + ok, err := constraint.Satisfied(noun, target) + if ok { + continue + } + violation := &Violation{ + Source: noun, + Target: target, + Dependency: dep, + Constraint: constraint, + Err: err, + } + cErr.Violations = append(cErr.Violations, violation) + } + } + return true + }) + + if cErr.Violations != nil { + return cErr + } + return nil +} + +// ConstraintError is used to return detailed violation +// information from CheckConstraints +type ConstraintError struct { + Violations []*Violation +} + +func (c *ConstraintError) Error() string { + return fmt.Sprintf("%d constraint violations", len(c.Violations)) +} + +// Violation is used to pass along information about +// a constraint violation +type Violation struct { + Source *Noun + Target *Noun + Dependency *Dependency + Constraint Constraint + Err error +} + +func (v *Violation) Error() string { + return fmt.Sprintf("Constraint %v between %v and %v violated: %v", + v.Constraint, v.Source, v.Target, v.Err) +} diff --git a/depgraph/graph_test.go b/depgraph/graph_test.go new file mode 100644 index 000000000..f0bbc8889 --- /dev/null +++ b/depgraph/graph_test.go @@ -0,0 +1,277 @@ +package depgraph + +import ( + "fmt" + "strings" + "testing" +) + +// ParseNouns is used to parse a string in the format of: +// a -> b ; edge name +// b -> c +// Into a series of nouns and dependencies +func ParseNouns(s string) map[string]*Noun { + lines := strings.Split(s, "\n") + nodes := make(map[string]*Noun) + for _, line := range lines { + var edgeName string + if idx := strings.Index(line, ";"); idx >= 0 { + edgeName = strings.Trim(line[idx+1:], " \t\r\n") + line = line[:idx] + } + parts := strings.SplitN(line, "->", 2) + if len(parts) != 2 { + continue + } + head_name := strings.Trim(parts[0], " \t\r\n") + tail_name := strings.Trim(parts[1], " \t\r\n") + head := nodes[head_name] + if head == nil { + head = &Noun{Name: head_name} + nodes[head_name] = head + } + tail := nodes[tail_name] + if tail == nil { + tail = &Noun{Name: tail_name} + nodes[tail_name] = tail + } + edge := &Dependency{ + Name: edgeName, + Source: head, + Target: tail, + } + head.Deps = append(head.Deps, edge) + } + return nodes +} + +func NounMapToList(m map[string]*Noun) []*Noun { + list := make([]*Noun, 0, len(m)) + for _, n := range m { + list = append(list, n) + } + return list +} + +func TestGraph_Validate_NoRoot(t *testing.T) { + nodes := ParseNouns(`a -> b +b -> a`) + list := NounMapToList(nodes) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err == nil { + t.Fatalf("expected err") + } + + vErr, ok := err.(*ValidateError) + if !ok { + t.Fatalf("expected validate error") + } + + if !vErr.MissingRoot { + t.Fatalf("expected missing root") + } +} + +func TestGraph_Validate_MultiRoot(t *testing.T) { + nodes := ParseNouns(`a -> b +c -> d`) + list := NounMapToList(nodes) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err == nil { + t.Fatalf("expected err") + } + + vErr, ok := err.(*ValidateError) + if !ok { + t.Fatalf("expected validate error") + } + + if !vErr.MissingRoot { + t.Fatalf("expected missing root") + } +} + +func TestGraph_Validate_Unreachable(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +x -> x`) + list := NounMapToList(nodes) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err == nil { + t.Fatalf("expected err") + } + + vErr, ok := err.(*ValidateError) + if !ok { + t.Fatalf("expected validate error") + } + + if len(vErr.Unreachable) != 1 { + t.Fatalf("expected unreachable") + } + + if vErr.Unreachable[0].Name != "x" { + t.Fatalf("bad: %v", vErr.Unreachable[0]) + } +} + +func TestGraph_Validate_Cycle(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +d -> b`) + list := NounMapToList(nodes) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err == nil { + t.Fatalf("expected err") + } + + vErr, ok := err.(*ValidateError) + if !ok { + t.Fatalf("expected validate error") + } + + if len(vErr.Cycles) != 1 { + t.Fatalf("expected cycles") + } + + cycle := vErr.Cycles[0] + if cycle[0].Name != "d" { + t.Fatalf("bad: %v", cycle) + } + if cycle[1].Name != "b" { + t.Fatalf("bad: %v", cycle) + } +} + +func TestGraph_Validate(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +b -> e +c -> d +c -> e`) + list := NounMapToList(nodes) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err != nil { + t.Fatalf("err: %v", err) + } +} + +type VersionMeta int +type VersionConstraint struct { + Min int + Max int +} + +func (v *VersionConstraint) Satisfied(head, tail *Noun) (bool, error) { + vers := int(tail.Meta.(VersionMeta)) + if vers < v.Min { + return false, fmt.Errorf("version %d below minimum %d", + vers, v.Min) + } else if vers > v.Max { + return false, fmt.Errorf("version %d above maximum %d", + vers, v.Max) + } + return true, nil +} + +func (v *VersionConstraint) String() string { + return "version" +} + +func TestGraph_ConstraintViolation(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +b -> e +c -> d +c -> e`) + list := NounMapToList(nodes) + + // Add a version constraint + vers := &VersionConstraint{1, 3} + + // Introduce some constraints + depB := nodes["a"].Deps[0] + depB.Constraints = []Constraint{vers} + depC := nodes["a"].Deps[1] + depC.Constraints = []Constraint{vers} + + // Add some versions + nodes["b"].Meta = VersionMeta(0) + nodes["c"].Meta = VersionMeta(4) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err != nil { + t.Fatalf("err: %v", err) + } + + err = g.CheckConstraints() + if err == nil { + t.Fatalf("Expected err") + } + + cErr, ok := err.(*ConstraintError) + if !ok { + t.Fatalf("expected constraint error") + } + + if len(cErr.Violations) != 2 { + t.Fatalf("expected 2 violations: %v", cErr) + } + + if cErr.Violations[0].Error() != "Constraint version between a and b violated: version 0 below minimum 1" { + t.Fatalf("err: %v", cErr.Violations[0]) + } + + if cErr.Violations[1].Error() != "Constraint version between a and c violated: version 4 above maximum 3" { + t.Fatalf("err: %v", cErr.Violations[1]) + } +} + +func TestGraph_Constraint_NoViolation(t *testing.T) { + nodes := ParseNouns(`a -> b +a -> c +b -> d +b -> e +c -> d +c -> e`) + list := NounMapToList(nodes) + + // Add a version constraint + vers := &VersionConstraint{1, 3} + + // Introduce some constraints + depB := nodes["a"].Deps[0] + depB.Constraints = []Constraint{vers} + depC := nodes["a"].Deps[1] + depC.Constraints = []Constraint{vers} + + // Add some versions + nodes["b"].Meta = VersionMeta(2) + nodes["c"].Meta = VersionMeta(3) + + g := &Graph{Name: "Test", Nouns: list} + err := g.Validate() + if err != nil { + t.Fatalf("err: %v", err) + } + + err = g.CheckConstraints() + if err != nil { + t.Fatalf("err: %v", err) + } +} diff --git a/depgraph/noun.go b/depgraph/noun.go new file mode 100644 index 000000000..6f99746ec --- /dev/null +++ b/depgraph/noun.go @@ -0,0 +1,27 @@ +package depgraph + +import ( + "github.com/hashicorp/terraform/digraph" +) + +// Nouns are the key structure of the dependency graph. They can +// be used to represent all objects in the graph. They are linked +// by depedencies. +type Noun struct { + Name string // Opaque name + Meta interface{} + Deps []*Dependency +} + +// Edges returns the out-going edges of a Noun +func (n *Noun) Edges() []digraph.Edge { + edges := make([]digraph.Edge, len(n.Deps)) + for idx, dep := range n.Deps { + edges[idx] = dep + } + return edges +} + +func (n *Noun) String() string { + return n.Name +} diff --git a/digraph/basic.go b/digraph/basic.go new file mode 100644 index 000000000..8dc76838d --- /dev/null +++ b/digraph/basic.go @@ -0,0 +1,89 @@ +package digraph + +import ( + "fmt" + "strings" +) + +// BasicNode is a digraph Node that has a name and out edges +type BasicNode struct { + Name string + NodeEdges []Edge +} + +func (b *BasicNode) Edges() []Edge { + return b.NodeEdges +} + +func (b *BasicNode) AddEdge(edge Edge) { + b.NodeEdges = append(b.NodeEdges, edge) +} + +func (b *BasicNode) String() string { + if b.Name == "" { + return "Node" + } + return fmt.Sprintf("%v", b.Name) +} + +// BasicEdge is a digraph Edge that has a name, head and tail +type BasicEdge struct { + Name string + EdgeHead *BasicNode + EdgeTail *BasicNode +} + +func (b *BasicEdge) Head() Node { + return b.EdgeHead +} + +// Tail returns the end point of the Edge +func (b *BasicEdge) Tail() Node { + return b.EdgeTail +} + +func (b *BasicEdge) String() string { + if b.Name == "" { + return "Edge" + } + return fmt.Sprintf("%v", b.Name) +} + +// ParseBasic is used to parse a string in the format of: +// a -> b ; edge name +// b -> c +// Into a series of basic node and basic edges +func ParseBasic(s string) map[string]*BasicNode { + lines := strings.Split(s, "\n") + nodes := make(map[string]*BasicNode) + for _, line := range lines { + var edgeName string + if idx := strings.Index(line, ";"); idx >= 0 { + edgeName = strings.Trim(line[idx+1:], " \t\r\n") + line = line[:idx] + } + parts := strings.SplitN(line, "->", 2) + if len(parts) != 2 { + continue + } + head_name := strings.Trim(parts[0], " \t\r\n") + tail_name := strings.Trim(parts[1], " \t\r\n") + head := nodes[head_name] + if head == nil { + head = &BasicNode{Name: head_name} + nodes[head_name] = head + } + tail := nodes[tail_name] + if tail == nil { + tail = &BasicNode{Name: tail_name} + nodes[tail_name] = tail + } + edge := &BasicEdge{ + Name: edgeName, + EdgeHead: head, + EdgeTail: tail, + } + head.AddEdge(edge) + } + return nodes +} diff --git a/digraph/basic_test.go b/digraph/basic_test.go new file mode 100644 index 000000000..20584b09b --- /dev/null +++ b/digraph/basic_test.go @@ -0,0 +1,53 @@ +package digraph + +import ( + "fmt" + "testing" +) + +func TestParseBasic(t *testing.T) { + spec := `a -> b ; first +b -> c ; second +b -> d ; third +z -> a` + nodes := ParseBasic(spec) + if len(nodes) != 5 { + t.Fatalf("bad: %v", nodes) + } + + a := nodes["a"] + if a.Name != "a" { + t.Fatalf("bad: %v", a) + } + aEdges := a.Edges() + if len(aEdges) != 1 { + t.Fatalf("bad: %v", a.Edges()) + } + if fmt.Sprintf("%v", aEdges[0]) != "first" { + t.Fatalf("bad: %v", aEdges[0]) + } + + b := nodes["b"] + if len(b.Edges()) != 2 { + t.Fatalf("bad: %v", b.Edges()) + } + + c := nodes["c"] + if len(c.Edges()) != 0 { + t.Fatalf("bad: %v", c.Edges()) + } + + d := nodes["d"] + if len(d.Edges()) != 0 { + t.Fatalf("bad: %v", d.Edges()) + } + + z := nodes["z"] + zEdges := z.Edges() + if len(zEdges) != 1 { + t.Fatalf("bad: %v", z.Edges()) + } + if fmt.Sprintf("%v", zEdges[0]) != "Edge" { + t.Fatalf("bad: %v", zEdges[0]) + } +} diff --git a/digraph/digraph.go b/digraph/digraph.go new file mode 100644 index 000000000..ccf311170 --- /dev/null +++ b/digraph/digraph.go @@ -0,0 +1,34 @@ +package digraph + +// Digraph is used to represent a Directed Graph. This means +// we have a set of nodes, and a set of edges which are directed +// from a source and towards a destination +type Digraph interface { + // Nodes provides all the nodes in the graph + Nodes() []Node + + // Sources provides all the source nodes in the graph + Sources() []Node + + // Sinks provides all the sink nodes in the graph + Sinks() []Node + + // Transpose reverses the edge directions and returns + // a new Digraph + Transpose() Digraph +} + +// Node represents a vertex in a Digraph +type Node interface { + // Edges returns the out edges for a given nod + Edges() []Edge +} + +// Edge represents a directed edge in a Digraph +type Edge interface { + // Head returns the start point of the Edge + Head() Node + + // Tail returns the end point of the Edge + Tail() Node +} diff --git a/digraph/graphviz.go b/digraph/graphviz.go new file mode 100644 index 000000000..788112fec --- /dev/null +++ b/digraph/graphviz.go @@ -0,0 +1,22 @@ +package digraph + +import ( + "fmt" + "io" +) + +// GenerateDot is used to emit a GraphViz compatible definition +// for a directed graph. It can be used to dump a .dot file. +func GenerateDot(nodes []Node, w io.Writer) { + w.Write([]byte("digraph {\n")) + defer w.Write([]byte("}\n")) + for _, n := range nodes { + w.Write([]byte(fmt.Sprintf("\t%s;\n", n))) + for _, edge := range n.Edges() { + target := edge.Tail() + line := fmt.Sprintf("\t%s -> %s [label=\"%s\"];\n", + n, target, edge) + w.Write([]byte(line)) + } + } +} diff --git a/digraph/graphviz_test.go b/digraph/graphviz_test.go new file mode 100644 index 000000000..88939480f --- /dev/null +++ b/digraph/graphviz_test.go @@ -0,0 +1,57 @@ +package digraph + +import ( + "bytes" + "strings" + "testing" +) + +func Test_GenerateDot(t *testing.T) { + nodes := ParseBasic(`a -> b ; foo +a -> c +b -> d +b -> e +`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + buf := bytes.NewBuffer(nil) + GenerateDot(nlist, buf) + + out := string(buf.Bytes()) + if !strings.HasPrefix(out, "digraph {\n") { + t.Fatalf("bad: %v", out) + } + if !strings.HasSuffix(out, "\n}\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\ta;\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\tb;\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\tc;\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\td;\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\te;\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\ta -> b [label=\"foo\"];\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\ta -> c [label=\"Edge\"];\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\tb -> d [label=\"Edge\"];\n") { + t.Fatalf("bad: %v", out) + } + if !strings.Contains(out, "\n\tb -> e [label=\"Edge\"];\n") { + t.Fatalf("bad: %v", out) + } +} diff --git a/digraph/tarjan.go b/digraph/tarjan.go new file mode 100644 index 000000000..2298610ed --- /dev/null +++ b/digraph/tarjan.go @@ -0,0 +1,111 @@ +package digraph + +// sccAcct is used ot pass around accounting information for +// the StronglyConnectedComponents algorithm +type sccAcct struct { + ExcludeSingle bool + NextIndex int + NodeIndex map[Node]int + Stack []Node + SCC [][]Node +} + +// visit assigns an index and pushes a node onto the stack +func (s *sccAcct) visit(n Node) int { + idx := s.NextIndex + s.NodeIndex[n] = idx + s.NextIndex++ + s.push(n) + return idx +} + +// push adds a node to the stack +func (s *sccAcct) push(n Node) { + s.Stack = append(s.Stack, n) +} + +// pop removes a node from the stack +func (s *sccAcct) pop() Node { + n := len(s.Stack) + if n == 0 { + return nil + } + node := s.Stack[n-1] + s.Stack = s.Stack[:n-1] + return node +} + +// inStack checks if a node is in the stack +func (s *sccAcct) inStack(needle Node) bool { + for _, n := range s.Stack { + if n == needle { + return true + } + } + return false +} + +// StronglyConnectedComponents implements Tarjan's algorithm to +// find all the strongly connected components in a graph. This can +// be used to detected any cycles in a graph, as well as which nodes +// partipate in those cycles. excludeSingle is used to exclude strongly +// connected components of size one. +func StronglyConnectedComponents(nodes []Node, excludeSingle bool) [][]Node { + acct := sccAcct{ + ExcludeSingle: excludeSingle, + NextIndex: 1, + NodeIndex: make(map[Node]int, len(nodes)), + } + for _, node := range nodes { + // Recurse on any non-visited nodes + if acct.NodeIndex[node] == 0 { + stronglyConnected(&acct, node) + } + } + return acct.SCC +} + +func stronglyConnected(acct *sccAcct, node Node) int { + // Initial node visit + index := acct.visit(node) + minIdx := index + + for _, edge := range node.Edges() { + target := edge.Tail() + targetIdx := acct.NodeIndex[target] + + // Recurse on successor if not yet visited + if targetIdx == 0 { + minIdx = min(minIdx, stronglyConnected(acct, target)) + + } else if acct.inStack(target) { + // Check if the node is in the stack + minIdx = min(minIdx, targetIdx) + } + } + + // Pop the strongly connected components off the stack if + // this is a root node + if index == minIdx { + var scc []Node + for { + n := acct.pop() + scc = append(scc, n) + if n == node { + break + } + } + if !(acct.ExcludeSingle && len(scc) == 1) { + acct.SCC = append(acct.SCC, scc) + } + } + + return minIdx +} + +func min(a, b int) int { + if a <= b { + return a + } + return b +} diff --git a/digraph/tarjan_test.go b/digraph/tarjan_test.go new file mode 100644 index 000000000..10def8051 --- /dev/null +++ b/digraph/tarjan_test.go @@ -0,0 +1,75 @@ +package digraph + +import ( + "testing" +) + +func TestStronglyConnectedComponents(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +b -> c +c -> b +c -> d +d -> e`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + sccs := StronglyConnectedComponents(nlist, false) + if len(sccs) != 4 { + t.Fatalf("bad: %v", sccs) + } + + sccs = StronglyConnectedComponents(nlist, true) + if len(sccs) != 1 { + t.Fatalf("bad: %v", sccs) + } + + cycle := sccs[0] + if len(cycle) != 2 { + t.Fatalf("bad: %v", sccs) + } + + if cycle[0].(*BasicNode).Name != "c" { + t.Fatalf("bad: %v", cycle) + } + if cycle[1].(*BasicNode).Name != "b" { + t.Fatalf("bad: %v", cycle) + } +} + +func TestStronglyConnectedComponents2(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +b -> d +b -> e +c -> f +c -> g +g -> a +`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + sccs := StronglyConnectedComponents(nlist, true) + if len(sccs) != 1 { + t.Fatalf("bad: %v", sccs) + } + + cycle := sccs[0] + if len(cycle) != 3 { + t.Fatalf("bad: %v", sccs) + } + + if cycle[0].(*BasicNode).Name != "g" { + t.Fatalf("bad: %v", cycle) + } + if cycle[1].(*BasicNode).Name != "c" { + t.Fatalf("bad: %v", cycle) + } + if cycle[2].(*BasicNode).Name != "a" { + t.Fatalf("bad: %v", cycle) + } +} diff --git a/digraph/util.go b/digraph/util.go new file mode 100644 index 000000000..96a09ed82 --- /dev/null +++ b/digraph/util.go @@ -0,0 +1,113 @@ +package digraph + +// DepthFirstWalk performs a depth-first traversal of the nodes +// that can be reached from the initial input set. The callback is +// invoked for each visited node, and may return false to prevent +// vising any children of the current node +func DepthFirstWalk(node Node, cb func(n Node) bool) { + frontier := []Node{node} + seen := make(map[Node]struct{}) + for len(frontier) > 0 { + // Pop the current node + n := len(frontier) + current := frontier[n-1] + frontier = frontier[:n-1] + + // Check for potential cycle + if _, ok := seen[current]; ok { + continue + } + seen[current] = struct{}{} + + // Visit with the callback + if !cb(current) { + continue + } + + // Add any new edges to visit, in reverse order + edges := current.Edges() + for i := len(edges) - 1; i >= 0; i-- { + frontier = append(frontier, edges[i].Tail()) + } + } +} + +// FilterDegree returns only the nodes with the desired +// degree. This can be used with OutDegree or InDegree +func FilterDegree(degree int, degrees map[Node]int) []Node { + var matching []Node + for n, d := range degrees { + if d == degree { + matching = append(matching, n) + } + } + return matching +} + +// InDegree is used to compute the in-degree of nodes +func InDegree(nodes []Node) map[Node]int { + degree := make(map[Node]int, len(nodes)) + for _, n := range nodes { + if _, ok := degree[n]; !ok { + degree[n] = 0 + } + for _, e := range n.Edges() { + degree[e.Tail()]++ + } + } + return degree +} + +// OutDegree is used to compute the in-degree of nodes +func OutDegree(nodes []Node) map[Node]int { + degree := make(map[Node]int, len(nodes)) + for _, n := range nodes { + degree[n] = len(n.Edges()) + } + return degree +} + +// Sinks is used to get the nodes with out-degree of 0 +func Sinks(nodes []Node) []Node { + return FilterDegree(0, OutDegree(nodes)) +} + +// Sources is used to get the nodes with in-degree of 0 +func Sources(nodes []Node) []Node { + return FilterDegree(0, InDegree(nodes)) +} + +// Unreachable starts at a given start node, performs +// a DFS from there, and returns the set of unreachable nodes. +func Unreachable(start Node, nodes []Node) []Node { + // DFS from the start ndoe + frontier := []Node{start} + seen := make(map[Node]struct{}) + for len(frontier) > 0 { + // Pop the current node + n := len(frontier) + current := frontier[n-1] + frontier = frontier[:n-1] + + // Check for potential cycle + if _, ok := seen[current]; ok { + continue + } + seen[current] = struct{}{} + + // Add any new edges to visit, in reverse order + edges := current.Edges() + for i := len(edges) - 1; i >= 0; i-- { + frontier = append(frontier, edges[i].Tail()) + } + } + + // Check for any unseen nodes + var unseen []Node + for _, node := range nodes { + if _, ok := seen[node]; !ok { + unseen = append(unseen, node) + } + } + return unseen +} diff --git a/digraph/util_test.go b/digraph/util_test.go new file mode 100644 index 000000000..e6d359991 --- /dev/null +++ b/digraph/util_test.go @@ -0,0 +1,233 @@ +package digraph + +import ( + "reflect" + "testing" +) + +func TestDepthFirstWalk(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +d -> f +e -> a ; cycle`) + root := nodes["a"] + expected := []string{ + "a", + "b", + "e", + "c", + "d", + "f", + } + index := 0 + DepthFirstWalk(root, func(n Node) bool { + name := n.(*BasicNode).Name + if expected[index] != name { + t.Fatalf("expected: %v, got %v", expected[index], name) + } + index++ + return true + }) +} + +func TestInDegree(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +c -> e +d -> f`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + expected := map[string]int{ + "a": 0, + "b": 1, + "c": 1, + "d": 1, + "e": 2, + "f": 1, + } + indegree := InDegree(nlist) + for n, d := range indegree { + name := n.(*BasicNode).Name + exp := expected[name] + if exp != d { + t.Fatalf("Expected %d for %s, got %d", + exp, name, d) + } + } +} + +func TestOutDegree(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +c -> e +d -> f`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + expected := map[string]int{ + "a": 3, + "b": 1, + "c": 1, + "d": 1, + "e": 0, + "f": 0, + } + outDegree := OutDegree(nlist) + for n, d := range outDegree { + name := n.(*BasicNode).Name + exp := expected[name] + if exp != d { + t.Fatalf("Expected %d for %s, got %d", + exp, name, d) + } + } +} + +func TestSinks(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +c -> e +d -> f`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + sinks := Sinks(nlist) + + var haveE, haveF bool + for _, n := range sinks { + name := n.(*BasicNode).Name + switch name { + case "e": + haveE = true + case "f": + haveF = true + } + } + if !haveE || !haveF { + t.Fatalf("missing sink") + } +} + +func TestSources(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +c -> e +d -> f +x -> y`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + sources := Sources(nlist) + if len(sources) != 2 { + t.Fatalf("bad: %v", sources) + } + + var haveA, haveX bool + for _, n := range sources { + name := n.(*BasicNode).Name + switch name { + case "a": + haveA = true + case "x": + haveX = true + } + } + if !haveA || !haveX { + t.Fatalf("missing source %v %v", haveA, haveX) + } +} + +func TestUnreachable(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +c -> e +d -> f +f -> a +x -> y +y -> z`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + unreached := Unreachable(nodes["a"], nlist) + if len(unreached) != 3 { + t.Fatalf("bad: %v", unreached) + } + + var haveX, haveY, haveZ bool + for _, n := range unreached { + name := n.(*BasicNode).Name + switch name { + case "x": + haveX = true + case "y": + haveY = true + case "z": + haveZ = true + } + } + if !haveX || !haveY || !haveZ { + t.Fatalf("missing %v %v %v", haveX, haveY, haveZ) + } +} + +func TestUnreachable2(t *testing.T) { + nodes := ParseBasic(`a -> b +a -> c +a -> d +b -> e +c -> e +d -> f +f -> a +x -> y +y -> z`) + var nlist []Node + for _, n := range nodes { + nlist = append(nlist, n) + } + + unreached := Unreachable(nodes["x"], nlist) + if len(unreached) != 6 { + t.Fatalf("bad: %v", unreached) + } + + expected := map[string]struct{}{ + "a": struct{}{}, + "b": struct{}{}, + "c": struct{}{}, + "d": struct{}{}, + "e": struct{}{}, + "f": struct{}{}, + } + out := map[string]struct{}{} + for _, n := range unreached { + name := n.(*BasicNode).Name + out[name] = struct{}{} + } + + if !reflect.DeepEqual(out, expected) { + t.Fatalf("bad: %v %v", out, expected) + } +}