diff --git a/dag/graph.go b/dag/graph.go index 41d766f40..2572096ed 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -11,8 +11,8 @@ import ( type Graph struct { vertices *Set edges *Set - downEdges map[Vertex]*Set - upEdges map[Vertex]*Set + downEdges map[interface{}]*Set + upEdges map[interface{}]*Set once sync.Once } @@ -110,10 +110,10 @@ func (g *Graph) RemoveEdge(edge Edge) { g.edges.Delete(edge) // Delete the up/down edges - if s, ok := g.downEdges[edge.Source()]; ok { + if s, ok := g.downEdges[hashcode(edge.Source())]; ok { s.Delete(edge.Target()) } - if s, ok := g.upEdges[edge.Target()]; ok { + if s, ok := g.upEdges[hashcode(edge.Target())]; ok { s.Delete(edge.Source()) } } @@ -121,13 +121,13 @@ func (g *Graph) RemoveEdge(edge Edge) { // DownEdges returns the outward edges from the source Vertex v. func (g *Graph) DownEdges(v Vertex) *Set { g.once.Do(g.init) - return g.downEdges[v] + return g.downEdges[hashcode(v)] } // UpEdges returns the inward edges to the destination Vertex v. func (g *Graph) UpEdges(v Vertex) *Set { g.once.Do(g.init) - return g.upEdges[v] + return g.upEdges[hashcode(v)] } // Connect adds an edge with the given source and target. This is safe to @@ -139,9 +139,11 @@ func (g *Graph) Connect(edge Edge) { source := edge.Source() target := edge.Target() + sourceCode := hashcode(source) + targetCode := hashcode(target) // Do we have this already? If so, don't add it again. - if s, ok := g.downEdges[source]; ok && s.Include(target) { + if s, ok := g.downEdges[sourceCode]; ok && s.Include(target) { return } @@ -149,18 +151,18 @@ func (g *Graph) Connect(edge Edge) { g.edges.Add(edge) // Add the down edge - s, ok := g.downEdges[source] + s, ok := g.downEdges[sourceCode] if !ok { s = new(Set) - g.downEdges[source] = s + g.downEdges[sourceCode] = s } s.Add(target) // Add the up edge - s, ok = g.upEdges[target] + s, ok = g.upEdges[targetCode] if !ok { s = new(Set) - g.upEdges[target] = s + g.upEdges[targetCode] = s } s.Add(source) } @@ -184,7 +186,7 @@ func (g *Graph) String() string { // Write each node in order... for _, name := range names { v := mapping[name] - targets := g.downEdges[v] + targets := g.downEdges[hashcode(v)] buf.WriteString(fmt.Sprintf("%s\n", name)) @@ -207,8 +209,8 @@ func (g *Graph) String() string { func (g *Graph) init() { g.vertices = new(Set) g.edges = new(Set) - g.downEdges = make(map[Vertex]*Set) - g.upEdges = make(map[Vertex]*Set) + g.downEdges = make(map[interface{}]*Set) + g.upEdges = make(map[interface{}]*Set) } // VertexName returns the name of a vertex. diff --git a/dag/graph_test.go b/dag/graph_test.go index 8dd05e95e..eb3e40b3a 100644 --- a/dag/graph_test.go +++ b/dag/graph_test.go @@ -1,6 +1,7 @@ package dag import ( + "fmt" "strings" "testing" ) @@ -79,6 +80,36 @@ func TestGraph_replaceSelf(t *testing.T) { } } +// This tests that connecting edges works based on custom Hashcode +// implementations for uniqueness. +func TestGraph_hashcode(t *testing.T) { + var g Graph + g.Add(&hashVertex{code: 1}) + g.Add(&hashVertex{code: 2}) + g.Add(&hashVertex{code: 3}) + g.Connect(BasicEdge( + &hashVertex{code: 1}, + &hashVertex{code: 3})) + + actual := strings.TrimSpace(g.String()) + expected := strings.TrimSpace(testGraphBasicStr) + if actual != expected { + t.Fatalf("bad: %s", actual) + } +} + +type hashVertex struct { + code interface{} +} + +func (v *hashVertex) Hashcode() interface{} { + return v.code +} + +func (v *hashVertex) Name() string { + return fmt.Sprintf("%#v", v.code) +} + const testGraphBasicStr = ` 1 3 diff --git a/dag/set.go b/dag/set.go index 9cc0d98e2..d4b29226b 100644 --- a/dag/set.go +++ b/dag/set.go @@ -17,22 +17,31 @@ type Hashable interface { Hashcode() interface{} } +// hashcode returns the hashcode used for set elements. +func hashcode(v interface{}) interface{} { + if h, ok := v.(Hashable); ok { + return h.Hashcode() + } + + return v +} + // Add adds an item to the set func (s *Set) Add(v interface{}) { s.once.Do(s.init) - s.m[s.code(v)] = v + s.m[hashcode(v)] = v } // Delete removes an item from the set. func (s *Set) Delete(v interface{}) { s.once.Do(s.init) - delete(s.m, s.code(v)) + delete(s.m, hashcode(v)) } // Include returns true/false of whether a value is in the set. func (s *Set) Include(v interface{}) bool { s.once.Do(s.init) - _, ok := s.m[s.code(v)] + _, ok := s.m[hashcode(v)] return ok } @@ -73,14 +82,6 @@ func (s *Set) List() []interface{} { return r } -func (s *Set) code(v interface{}) interface{} { - if h, ok := v.(Hashable); ok { - return h.Hashcode() - } - - return v -} - func (s *Set) init() { s.m = make(map[interface{}]interface{}) }