diff --git a/communicator/communicator.go b/communicator/communicator.go index 5fa2631a4..3749a9f98 100644 --- a/communicator/communicator.go +++ b/communicator/communicator.go @@ -1,8 +1,11 @@ package communicator import ( + "context" "fmt" "io" + "log" + "sync/atomic" "time" "github.com/hashicorp/terraform/communicator/remote" @@ -51,3 +54,93 @@ func New(s *terraform.InstanceState) (Communicator, error) { return nil, fmt.Errorf("connection type '%s' not supported", connType) } } + +// maxBackoffDealy is the maximum delay between retry attempts +var maxBackoffDelay = 10 * time.Second +var initialBackoffDelay = time.Second + +type Fatal interface { + FatalError() error +} + +func Retry(ctx context.Context, f func() error) error { + // container for atomic error value + type errWrap struct { + E error + } + + // Try the function in a goroutine + var errVal atomic.Value + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + + delay := time.Duration(0) + for { + // If our context ended, we want to exit right away. + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + + // Try the function call + err := f() + + // return if we have no error, or a FatalError + done := false + switch e := err.(type) { + case nil: + done = true + case Fatal: + err = e.FatalError() + done = true + } + + errVal.Store(&errWrap{err}) + + if done { + return + } + + log.Printf("[WARN] retryable error: %v", err) + + delay *= 2 + + if delay == 0 { + delay = initialBackoffDelay + } + + if delay > maxBackoffDelay { + delay = maxBackoffDelay + } + + log.Printf("[INFO] sleeping for %s", delay) + } + }() + + // Wait for completion + select { + case <-ctx.Done(): + case <-doneCh: + } + + var lastErr error + // Check if we got an error executing + if ev, ok := errVal.Load().(errWrap); ok { + lastErr = ev.E + } + + // Check if we have a context error to check if we're interrupted or timeout + switch ctx.Err() { + case context.Canceled: + return fmt.Errorf("interrupted - last error: %v", lastErr) + case context.DeadlineExceeded: + return fmt.Errorf("timeout - last error: %v", lastErr) + } + + if lastErr != nil { + return lastErr + } + return nil +} diff --git a/communicator/communicator_test.go b/communicator/communicator_test.go index 33a91cd6f..659222421 100644 --- a/communicator/communicator_test.go +++ b/communicator/communicator_test.go @@ -1,7 +1,12 @@ package communicator import ( + "context" + "errors" + "io" + "net" "testing" + "time" "github.com/hashicorp/terraform/terraform" ) @@ -28,3 +33,66 @@ func TestCommunicator_new(t *testing.T) { t.Fatalf("err: %v", err) } } +func TestRetryFunc(t *testing.T) { + origMax := maxBackoffDelay + maxBackoffDelay = time.Second + origStart := initialBackoffDelay + initialBackoffDelay = 10 * time.Millisecond + + defer func() { + maxBackoffDelay = origMax + initialBackoffDelay = origStart + }() + + // succeed on the third try + errs := []error{io.EOF, &net.OpError{Err: errors.New("ERROR")}, nil} + count := 0 + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := Retry(ctx, func() error { + if count >= len(errs) { + return errors.New("failed to stop after nil error") + } + + err := errs[count] + count++ + + return err + }) + + if count != 3 { + t.Fatal("retry func should have been called 3 times") + } + + if err != nil { + t.Fatal(err) + } +} + +func TestRetryFuncBackoff(t *testing.T) { + origMax := maxBackoffDelay + maxBackoffDelay = time.Second + origStart := initialBackoffDelay + initialBackoffDelay = 100 * time.Millisecond + + defer func() { + maxBackoffDelay = origMax + initialBackoffDelay = origStart + }() + + count := 0 + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + Retry(ctx, func() error { + count++ + return io.EOF + }) + + if count > 4 { + t.Fatalf("retry func failed to backoff. called %d times", count) + } +}