dag: TransitiveReduction

This commit is contained in:
Mitchell Hashimoto 2015-02-27 19:12:19 -08:00
parent abd68c2c87
commit ed2075e384
3 changed files with 145 additions and 0 deletions

View File

@ -40,6 +40,46 @@ func (g *AcyclicGraph) Root() (Vertex, error) {
return roots[0], nil
}
// TransitiveReduction performs the transitive reduction of graph g in place.
// The transitive reduction of a graph is a graph with as few edges as
// possible with the same reachability as the original graph. This means
// that if there are three nodes A => B => C, and A connects to both
// B and C, and B connects to C, then the transitive reduction is the
// same graph with only a single edge between A and B, and a single edge
// between B and C.
//
// The graph must be valid for this operation to behave properly. If
// Validate() returns an error, the behavior is undefined and the results
// will likely be unexpected.
//
// Complexity: O(V(V+E)), or asymptotically O(VE)
func (g *AcyclicGraph) TransitiveReduction() {
// Find the root-like objects to start a depthFirstWalk from.
frontier := make([]Vertex, 0, 5)
for _, v := range g.Vertices() {
if g.UpEdges(v).Len() == 0 {
frontier = append(frontier, v)
}
}
// Do a DFS
g.depthFirstWalk(frontier, func(v Vertex) error {
parents := g.UpEdges(v).List()
targets := g.DownEdges(v)
for _, rawParent := range parents {
parent := rawParent.(Vertex)
shared := g.DownEdges(parent).Intersection(targets)
for _, rawTarget := range shared.List() {
target := rawTarget.(Vertex)
g.RemoveEdge(BasicEdge(parent, target))
}
}
return nil
})
}
// Validate validates the DAG. A DAG is valid if it has a single root
// with no cycles.
func (g *AcyclicGraph) Validate() error {
@ -161,3 +201,37 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) error {
<-doneCh
return errs
}
// depthFirstWalk does a depth-first walk of the graph starting from
// the vertices in start. This is not exported now but it would make sense
// to export this publicly at some point.
func (g *AcyclicGraph) depthFirstWalk(start []Vertex, cb WalkFunc) error {
seen := make(map[Vertex]struct{})
frontier := make([]Vertex, len(start))
copy(frontier, start)
for len(frontier) > 0 {
// Pop the current vertex
n := len(frontier)
current := frontier[n-1]
frontier = frontier[:n-1]
// Check if we've seen this already and return...
if _, ok := seen[current]; ok {
continue
}
seen[current] = struct{}{}
// Visit the current node
if err := cb(current); err != nil {
return err
}
// Visit targets of this in reverse order.
targets := g.DownEdges(current).List()
for i := len(targets) - 1; i >= 0; i-- {
frontier = append(frontier, targets[i].(Vertex))
}
}
return nil
}

View File

@ -3,6 +3,7 @@ package dag
import (
"fmt"
"reflect"
"strings"
"sync"
"testing"
)
@ -48,6 +49,44 @@ func TestAcyclicGraphRoot_multiple(t *testing.T) {
}
}
func TestAyclicGraphTransReduction(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(2, 3))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestAyclicGraphTransReduction_more(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(1, 4))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(2, 4))
g.Connect(BasicEdge(3, 4))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionMoreStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestAcyclicGraphValidate(t *testing.T) {
var g AcyclicGraph
g.Add(1)
@ -156,3 +195,21 @@ func TestAcyclicGraphWalk_error(t *testing.T) {
t.Fatalf("bad: %#v", visits)
}
const testGraphTransReductionStr = `
1
2
2
3
3
`
const testGraphTransReductionMoreStr = `
1
2
2
3
3
4
4
`

View File

@ -36,6 +36,20 @@ func (s *Set) Include(v interface{}) bool {
return ok
}
// Intersection computes the set intersection with other.
func (s *Set) Intersection(other *Set) *Set {
result := new(Set)
if other != nil {
for _, v := range s.m {
if other.Include(v) {
result.Add(v)
}
}
}
return result
}
// Len is the number of items in the set.
func (s *Set) Len() int {
if s == nil {