Merge pull request #1081 from hashicorp/f-trans-reduct

Perform Transitive Reduction on Graph
This commit is contained in:
Mitchell Hashimoto 2015-03-02 08:33:03 -08:00
commit 58cc129caa
8 changed files with 216 additions and 4 deletions

View File

@ -40,6 +40,44 @@ func (g *AcyclicGraph) Root() (Vertex, error) {
return roots[0], nil
}
// TransitiveReduction performs the transitive reduction of graph g in place.
// The transitive reduction of a graph is a graph with as few edges as
// possible with the same reachability as the original graph. This means
// that if there are three nodes A => B => C, and A connects to both
// B and C, and B connects to C, then the transitive reduction is the
// same graph with only a single edge between A and B, and a single edge
// between B and C.
//
// The graph must be valid for this operation to behave properly. If
// Validate() returns an error, the behavior is undefined and the results
// will likely be unexpected.
//
// Complexity: O(V(V+E)), or asymptotically O(VE)
func (g *AcyclicGraph) TransitiveReduction() {
// For each vertex u in graph g, do a DFS starting from each vertex
// v such that the edge (u,v) exists (v is a direct descendant of u).
//
// For each v-prime reachable from v, remove the edge (u, v-prime).
for _, u := range g.Vertices() {
uTargets := g.DownEdges(u)
vs := make([]Vertex, uTargets.Len())
for i, vRaw := range uTargets.List() {
vs[i] = vRaw.(Vertex)
}
g.depthFirstWalk(vs, func(v Vertex) error {
shared := uTargets.Intersection(g.DownEdges(v))
for _, raw := range shared.List() {
vPrime := raw.(Vertex)
g.RemoveEdge(BasicEdge(u, vPrime))
}
return nil
})
}
}
// Validate validates the DAG. A DAG is valid if it has a single root
// with no cycles.
func (g *AcyclicGraph) Validate() error {
@ -161,3 +199,37 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) error {
<-doneCh
return errs
}
// depthFirstWalk does a depth-first walk of the graph starting from
// the vertices in start. This is not exported now but it would make sense
// to export this publicly at some point.
func (g *AcyclicGraph) depthFirstWalk(start []Vertex, cb WalkFunc) error {
seen := make(map[Vertex]struct{})
frontier := make([]Vertex, len(start))
copy(frontier, start)
for len(frontier) > 0 {
// Pop the current vertex
n := len(frontier)
current := frontier[n-1]
frontier = frontier[:n-1]
// Check if we've seen this already and return...
if _, ok := seen[current]; ok {
continue
}
seen[current] = struct{}{}
// Visit the current node
if err := cb(current); err != nil {
return err
}
// Visit targets of this in reverse order.
targets := g.DownEdges(current).List()
for i := len(targets) - 1; i >= 0; i-- {
frontier = append(frontier, targets[i].(Vertex))
}
}
return nil
}

View File

@ -3,6 +3,7 @@ package dag
import (
"fmt"
"reflect"
"strings"
"sync"
"testing"
)
@ -48,6 +49,44 @@ func TestAcyclicGraphRoot_multiple(t *testing.T) {
}
}
func TestAyclicGraphTransReduction(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(2, 3))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestAyclicGraphTransReduction_more(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Add(4)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(1, 3))
g.Connect(BasicEdge(1, 4))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(2, 4))
g.Connect(BasicEdge(3, 4))
g.TransitiveReduction()
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphTransReductionMoreStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
func TestAcyclicGraphValidate(t *testing.T) {
var g AcyclicGraph
g.Add(1)
@ -156,3 +195,21 @@ func TestAcyclicGraphWalk_error(t *testing.T) {
t.Fatalf("bad: %#v", visits)
}
const testGraphTransReductionStr = `
1
2
2
3
3
`
const testGraphTransReductionMoreStr = `
1
2
2
3
3
4
4
`

View File

@ -36,6 +36,20 @@ func (s *Set) Include(v interface{}) bool {
return ok
}
// Intersection computes the set intersection with other.
func (s *Set) Intersection(other *Set) *Set {
result := new(Set)
if other != nil {
for _, v := range s.m {
if other.Include(v) {
result.Add(v)
}
}
}
return result
}
// Len is the number of items in the set.
func (s *Set) Len() int {
if s == nil {

View File

@ -109,5 +109,9 @@ func (b *BuiltinGraphBuilder) Steps() []GraphTransformer {
// Make sure we create one root
&RootTransformer{},
// Perform the transitive reduction to make our graph a bit
// more sane if possible (it usually is possible).
&TransitiveReductionTransformer{},
}
}

View File

@ -125,14 +125,10 @@ const testBasicGraphBuilderStr = `
const testBuiltinGraphBuilderBasicStr = `
aws_instance.db
aws_instance.db (destroy tainted)
provider.aws
aws_instance.db (destroy tainted)
aws_instance.web (destroy tainted)
provider.aws
aws_instance.web
aws_instance.db
aws_instance.web (destroy tainted)
provider.aws
aws_instance.web (destroy tainted)
provider.aws
provider.aws

View File

@ -0,0 +1,10 @@
resource "aws_instance" "A" {}
resource "aws_instance" "B" {
A = "${aws_instance.A.id}"
}
resource "aws_instance" "C" {
A = "${aws_instance.A.id}"
B = "${aws_instance.B.id}"
}

View File

@ -0,0 +1,20 @@
package terraform
// TransitiveReductionTransformer is a GraphTransformer that performs
// finds the transitive reduction of the graph. For a definition of
// transitive reduction, see Wikipedia.
type TransitiveReductionTransformer struct{}
func (t *TransitiveReductionTransformer) Transform(g *Graph) error {
// If the graph isn't valid, skip the transitive reduction.
// We don't error here because Terraform itself handles graph
// validation in a better way, or we assume it does.
if err := g.Validate(); err != nil {
return nil
}
// Do it
g.TransitiveReduction()
return nil
}

View File

@ -0,0 +1,39 @@
package terraform
import (
"strings"
"testing"
)
func TestTransitiveReductionTransformer(t *testing.T) {
mod := testModule(t, "transform-trans-reduce-basic")
g := Graph{Path: RootModulePath}
{
tf := &ConfigTransformer{Module: mod}
if err := tf.Transform(&g); err != nil {
t.Fatalf("err: %s", err)
}
}
{
transform := &TransitiveReductionTransformer{}
if err := transform.Transform(&g); err != nil {
t.Fatalf("err: %s", err)
}
}
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testTransformTransReduceBasicStr)
if actual != expected {
t.Fatalf("bad:\n\n%s", actual)
}
}
const testTransformTransReduceBasicStr = `
aws_instance.A
aws_instance.B
aws_instance.A
aws_instance.C
aws_instance.B
`