diff --git a/helper/shadow/keyed_value.go b/helper/shadow/keyed_value.go index e685e6c3d..fe48cc761 100644 --- a/helper/shadow/keyed_value.go +++ b/helper/shadow/keyed_value.go @@ -24,8 +24,9 @@ func (w *KeyedValue) Close() error { w.closed = true // For all waiters, complete with ErrClosed - for _, w := range w.waiters { - w.SetValue(ErrClosed) + for k, val := range w.waiters { + val.SetValue(ErrClosed) + delete(w.waiters, k) } return nil @@ -50,6 +51,11 @@ func (w *KeyedValue) WaitForChange(k string) interface{} { w.lock.Lock() w.once.Do(w.init) + // If we're closed, we're closed + if w.closed { + return ErrClosed + } + // Check for an active waiter. If there isn't one, make it val := w.waiters[k] if val == nil { diff --git a/helper/shadow/keyed_value_test.go b/helper/shadow/keyed_value_test.go index 56d368190..098c54a59 100644 --- a/helper/shadow/keyed_value_test.go +++ b/helper/shadow/keyed_value_test.go @@ -223,3 +223,38 @@ func TestKeyedValueWaitForChange_initial(t *testing.T) { t.Fatalf("bad: %#v", val) } } + +func TestKeyedValueWaitForChange_closed(t *testing.T) { + var v KeyedValue + + // Start reading this should be blocking + valueCh := make(chan interface{}) + go func() { + valueCh <- v.WaitForChange("foo") + }() + + // We should not get the value + select { + case <-valueCh: + t.Fatal("shouldn't receive value") + case <-time.After(10 * time.Millisecond): + } + + // Close + v.Close() + + // Verify + val := <-valueCh + if val != ErrClosed { + t.Fatalf("bad: %#v", val) + } + + // Set a value + v.SetValue("foo", 42) + + // Try again + val = v.WaitForChange("foo") + if val != ErrClosed { + t.Fatalf("bad: %#v", val) + } +}