dag: use hashcodes to as map key to edge sets

This commit is contained in:
Mitchell Hashimoto 2015-10-27 11:58:34 -07:00
parent f1c6673e1b
commit 05794199af
3 changed files with 59 additions and 25 deletions

View File

@ -11,8 +11,8 @@ import (
type Graph struct {
vertices *Set
edges *Set
downEdges map[Vertex]*Set
upEdges map[Vertex]*Set
downEdges map[interface{}]*Set
upEdges map[interface{}]*Set
once sync.Once
}
@ -110,10 +110,10 @@ func (g *Graph) RemoveEdge(edge Edge) {
g.edges.Delete(edge)
// Delete the up/down edges
if s, ok := g.downEdges[edge.Source()]; ok {
if s, ok := g.downEdges[hashcode(edge.Source())]; ok {
s.Delete(edge.Target())
}
if s, ok := g.upEdges[edge.Target()]; ok {
if s, ok := g.upEdges[hashcode(edge.Target())]; ok {
s.Delete(edge.Source())
}
}
@ -121,13 +121,13 @@ func (g *Graph) RemoveEdge(edge Edge) {
// DownEdges returns the outward edges from the source Vertex v.
func (g *Graph) DownEdges(v Vertex) *Set {
g.once.Do(g.init)
return g.downEdges[v]
return g.downEdges[hashcode(v)]
}
// UpEdges returns the inward edges to the destination Vertex v.
func (g *Graph) UpEdges(v Vertex) *Set {
g.once.Do(g.init)
return g.upEdges[v]
return g.upEdges[hashcode(v)]
}
// Connect adds an edge with the given source and target. This is safe to
@ -139,9 +139,11 @@ func (g *Graph) Connect(edge Edge) {
source := edge.Source()
target := edge.Target()
sourceCode := hashcode(source)
targetCode := hashcode(target)
// Do we have this already? If so, don't add it again.
if s, ok := g.downEdges[source]; ok && s.Include(target) {
if s, ok := g.downEdges[sourceCode]; ok && s.Include(target) {
return
}
@ -149,18 +151,18 @@ func (g *Graph) Connect(edge Edge) {
g.edges.Add(edge)
// Add the down edge
s, ok := g.downEdges[source]
s, ok := g.downEdges[sourceCode]
if !ok {
s = new(Set)
g.downEdges[source] = s
g.downEdges[sourceCode] = s
}
s.Add(target)
// Add the up edge
s, ok = g.upEdges[target]
s, ok = g.upEdges[targetCode]
if !ok {
s = new(Set)
g.upEdges[target] = s
g.upEdges[targetCode] = s
}
s.Add(source)
}
@ -184,7 +186,7 @@ func (g *Graph) String() string {
// Write each node in order...
for _, name := range names {
v := mapping[name]
targets := g.downEdges[v]
targets := g.downEdges[hashcode(v)]
buf.WriteString(fmt.Sprintf("%s\n", name))
@ -207,8 +209,8 @@ func (g *Graph) String() string {
func (g *Graph) init() {
g.vertices = new(Set)
g.edges = new(Set)
g.downEdges = make(map[Vertex]*Set)
g.upEdges = make(map[Vertex]*Set)
g.downEdges = make(map[interface{}]*Set)
g.upEdges = make(map[interface{}]*Set)
}
// VertexName returns the name of a vertex.

View File

@ -1,6 +1,7 @@
package dag
import (
"fmt"
"strings"
"testing"
)
@ -79,6 +80,36 @@ func TestGraph_replaceSelf(t *testing.T) {
}
}
// This tests that connecting edges works based on custom Hashcode
// implementations for uniqueness.
func TestGraph_hashcode(t *testing.T) {
var g Graph
g.Add(&hashVertex{code: 1})
g.Add(&hashVertex{code: 2})
g.Add(&hashVertex{code: 3})
g.Connect(BasicEdge(
&hashVertex{code: 1},
&hashVertex{code: 3}))
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphBasicStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
type hashVertex struct {
code interface{}
}
func (v *hashVertex) Hashcode() interface{} {
return v.code
}
func (v *hashVertex) Name() string {
return fmt.Sprintf("%#v", v.code)
}
const testGraphBasicStr = `
1
3

View File

@ -17,22 +17,31 @@ type Hashable interface {
Hashcode() interface{}
}
// hashcode returns the hashcode used for set elements.
func hashcode(v interface{}) interface{} {
if h, ok := v.(Hashable); ok {
return h.Hashcode()
}
return v
}
// Add adds an item to the set
func (s *Set) Add(v interface{}) {
s.once.Do(s.init)
s.m[s.code(v)] = v
s.m[hashcode(v)] = v
}
// Delete removes an item from the set.
func (s *Set) Delete(v interface{}) {
s.once.Do(s.init)
delete(s.m, s.code(v))
delete(s.m, hashcode(v))
}
// Include returns true/false of whether a value is in the set.
func (s *Set) Include(v interface{}) bool {
s.once.Do(s.init)
_, ok := s.m[s.code(v)]
_, ok := s.m[hashcode(v)]
return ok
}
@ -73,14 +82,6 @@ func (s *Set) List() []interface{} {
return r
}
func (s *Set) code(v interface{}) interface{} {
if h, ok := v.(Hashable); ok {
return h.Hashcode()
}
return v
}
func (s *Set) init() {
s.m = make(map[interface{}]interface{})
}