From 77c445a8382255a77a9e0d48016b665d41d1d10a Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Thu, 2 Feb 2017 10:03:20 -0800 Subject: [PATCH] dag: Set difference --- dag/set.go | 22 +++++++++++++++++++ dag/set_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 dag/set_test.go diff --git a/dag/set.go b/dag/set.go index d4b29226b..3929c9d0e 100644 --- a/dag/set.go +++ b/dag/set.go @@ -48,6 +48,9 @@ func (s *Set) Include(v interface{}) bool { // Intersection computes the set intersection with other. func (s *Set) Intersection(other *Set) *Set { result := new(Set) + if s == nil { + return result + } if other != nil { for _, v := range s.m { if other.Include(v) { @@ -59,6 +62,25 @@ func (s *Set) Intersection(other *Set) *Set { return result } +// Difference returns a set with the elements that s has but +// other doesn't. +func (s *Set) Difference(other *Set) *Set { + result := new(Set) + if s != nil { + for k, v := range s.m { + var ok bool + if other != nil { + _, ok = other.m[k] + } + if !ok { + result.Add(v) + } + } + } + + return result +} + // Len is the number of items in the set. func (s *Set) Len() int { if s == nil { diff --git a/dag/set_test.go b/dag/set_test.go new file mode 100644 index 000000000..8aeae7073 --- /dev/null +++ b/dag/set_test.go @@ -0,0 +1,56 @@ +package dag + +import ( + "fmt" + "testing" +) + +func TestSetDifference(t *testing.T) { + cases := []struct { + Name string + A, B []interface{} + Expected []interface{} + }{ + { + "same", + []interface{}{1, 2, 3}, + []interface{}{3, 1, 2}, + []interface{}{}, + }, + + { + "A has extra elements", + []interface{}{1, 2, 3}, + []interface{}{3, 2}, + []interface{}{1}, + }, + + { + "B has extra elements", + []interface{}{1, 2, 3}, + []interface{}{3, 2, 1, 4}, + []interface{}{}, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%d-%s", i, tc.Name), func(t *testing.T) { + var one, two, expected Set + for _, v := range tc.A { + one.Add(v) + } + for _, v := range tc.B { + two.Add(v) + } + for _, v := range tc.Expected { + expected.Add(v) + } + + actual := one.Difference(&two) + match := actual.Intersection(&expected) + if match.Len() != expected.Len() { + t.Fatalf("bad: %#v", actual.List()) + } + }) + } +}