diff --git a/internal/initwd/module_install.go b/internal/initwd/module_install.go index cbfd8e98c..38a44c262 100644 --- a/internal/initwd/module_install.go +++ b/internal/initwd/module_install.go @@ -438,8 +438,8 @@ func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest, log.Printf("[ERROR] %s from %s %s: %s", key, addr, latestMatch, err) diags = diags.Append(tfdiags.Sourceless( tfdiags.Error, - "Invalid response from remote module registry", - fmt.Sprintf("The remote registry at %s failed to return a download URL for %s %s.", hostname, addr, latestMatch), + "Error accessing remote module registry", + fmt.Sprintf("Failed to retrieve a download URL for %s %s from %s: %s", addr, latestMatch, hostname, err), )) return nil, nil, diags } diff --git a/registry/client.go b/registry/client.go index e8f7ac111..dcbb0debb 100644 --- a/registry/client.go +++ b/registry/client.go @@ -7,10 +7,13 @@ import ( "log" "net/http" "net/url" + "os" "path" + "strconv" "strings" "time" + "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/terraform-svchost" "github.com/hashicorp/terraform-svchost/disco" "github.com/hashicorp/terraform/httpclient" @@ -25,18 +28,34 @@ const ( requestTimeout = 10 * time.Second modulesServiceID = "modules.v1" providersServiceID = "providers.v1" + + // registryDiscoveryRetryEnvName is the name of the environment variable that + // can be configured to customize number of retries for module and provider + // discovery requests with the remote registry. + registryDiscoveryRetryEnvName = "TF_REGISTRY_DISCOVERY_RETRY" + defaultRetry = 1 ) +var discoveryRetry int + var tfVersion = version.String() +func init() { + configureDiscoveryRetry() +} + // Client provides methods to query Terraform Registries. type Client struct { // this is the client to be used for all requests. - client *http.Client + client *retryablehttp.Client // services is a required *disco.Disco, which may have services and // credentials pre-loaded. services *disco.Disco + + // retry is the number of retries the client will attempt for each request + // if it runs into a transient failure with the remote registry. + retry int } // NewClient returns a new initialized registry client. @@ -49,13 +68,18 @@ func NewClient(services *disco.Disco, client *http.Client) *Client { client = httpclient.New() client.Timeout = requestTimeout } + retryableClient := retryablehttp.NewClient() + retryableClient.HTTPClient = client + retryableClient.RetryMax = discoveryRetry + retryableClient.RequestLogHook = requestLogHook + retryableClient.ErrorHandler = maxRetryErrorHandler - services.Transport = client.Transport + services.Transport = retryableClient.HTTPClient.Transport services.SetUserAgent(httpclient.TerraformUserAgent(version.String())) return &Client{ - client: client, + client: retryableClient, services: services, } } @@ -93,12 +117,12 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions log.Printf("[DEBUG] fetching module versions from %q", service) - req, err := http.NewRequest("GET", service.String(), nil) + req, err := retryablehttp.NewRequest("GET", service.String(), nil) if err != nil { return nil, err } - c.addRequestCreds(host, req) + c.addRequestCreds(host, req.Request) req.Header.Set(xTerraformVersion, tfVersion) resp, err := c.client.Do(req) @@ -170,12 +194,12 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string, log.Printf("[DEBUG] looking up module location from %q", download) - req, err := http.NewRequest("GET", download.String(), nil) + req, err := retryablehttp.NewRequest("GET", download.String(), nil) if err != nil { return "", err } - c.addRequestCreds(host, req) + c.addRequestCreds(host, req.Request) req.Header.Set(xTerraformVersion, tfVersion) resp, err := c.client.Do(req) @@ -250,12 +274,12 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) ( log.Printf("[DEBUG] fetching provider versions from %q", service) - req, err := http.NewRequest("GET", service.String(), nil) + req, err := retryablehttp.NewRequest("GET", service.String(), nil) if err != nil { return nil, err } - c.addRequestCreds(host, req) + c.addRequestCreds(host, req.Request) req.Header.Set(xTerraformVersion, tfVersion) resp, err := c.client.Do(req) @@ -310,12 +334,12 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v log.Printf("[DEBUG] fetching provider location from %q", service) - req, err := http.NewRequest("GET", service.String(), nil) + req, err := retryablehttp.NewRequest("GET", service.String(), nil) if err != nil { return nil, err } - c.addRequestCreds(host, req) + c.addRequestCreds(host, req.Request) req.Header.Set(xTerraformVersion, tfVersion) resp, err := c.client.Do(req) @@ -343,3 +367,37 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v return &loc, nil } + +func configureDiscoveryRetry() { + discoveryRetry = defaultRetry + + if v := os.Getenv(registryDiscoveryRetryEnvName); v != "" { + retry, err := strconv.Atoi(v) + if err == nil && retry > 0 { + discoveryRetry = retry + } + } +} + +func requestLogHook(logger retryablehttp.Logger, req *http.Request, i int) { + if i > 0 { + logger.Printf("[INFO] Previous request to the remote registry failed, attempting retry.") + } +} + +func maxRetryErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) { + // Close the body per library instructions + if resp != nil { + resp.Body.Close() + } + + var errMsg string + if err != nil { + errMsg = fmt.Sprintf(" %s", err) + } + if numTries > 1 { + return resp, fmt.Errorf("the request failed after %d attempts, please try again later: %d%s", + numTries, resp.StatusCode, errMsg) + } + return resp, fmt.Errorf("the request failed, please try again later: %d%s", resp.StatusCode, errMsg) +} diff --git a/registry/client_test.go b/registry/client_test.go index 105205f94..f3a2bbcc3 100644 --- a/registry/client_test.go +++ b/registry/client_test.go @@ -1,7 +1,9 @@ package registry import ( + "context" "fmt" + "net/http" "os" "strings" "testing" @@ -14,6 +16,42 @@ import ( tfversion "github.com/hashicorp/terraform/version" ) +func TestConfigureDiscoveryRetry(t *testing.T) { + t.Run("default retry", func(t *testing.T) { + if discoveryRetry != defaultRetry { + t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry) + } + + rc := NewClient(nil, nil) + if rc.client.RetryMax != defaultRetry { + t.Fatalf("expected client retry %q, got %q", + defaultRetry, rc.client.RetryMax) + } + }) + + t.Run("configured retry", func(t *testing.T) { + defer func() { + os.Setenv(registryDiscoveryRetryEnvName, + os.Getenv(registryDiscoveryRetryEnvName)) + discoveryRetry = defaultRetry + }() + os.Setenv(registryDiscoveryRetryEnvName, "2") + + configureDiscoveryRetry() + expected := 2 + if discoveryRetry != expected { + t.Fatalf("expected retry %q, got %q", + expected, discoveryRetry) + } + + rc := NewClient(nil, nil) + if rc.client.RetryMax != expected { + t.Fatalf("expected client retry %q, got %q", + expected, rc.client.RetryMax) + } + }) +} + func TestLookupModuleVersions(t *testing.T) { server := test.Registry() defer server.Close() @@ -179,20 +217,31 @@ func TestAccLookupModuleVersions(t *testing.T) { } } -// the error should reference the config source exatly, not the discovered path. +// the error should reference the config source exactly, not the discovered path. func TestLookupLookupModuleError(t *testing.T) { server := test.Registry() defer server.Close() client := NewClient(test.Disco(server), nil) - // this should not be found in teh registry + // this should not be found in the registry src := "bad/local/path" mod, err := regsrc.ParseModuleSource(src) if err != nil { t.Fatal(err) } + // Instrument CheckRetry to make sure 404s are not retried + retries := 0 + oldCheck := client.client.CheckRetry + client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { + if retries > 0 { + t.Fatal("retried after module not found") + } + retries++ + return oldCheck(ctx, resp, err) + } + _, err = client.ModuleLocation(mod, "0.2.0") if err == nil { t.Fatal("expected error") @@ -204,6 +253,31 @@ func TestLookupLookupModuleError(t *testing.T) { } } +func TestLookupModuleRetryError(t *testing.T) { + server := test.RegistryRetryableErrorsServer() + defer server.Close() + + client := NewClient(test.Disco(server), nil) + + src := "example.com/test-versions/name/provider" + modsrc, err := regsrc.ParseModuleSource(src) + if err != nil { + t.Fatal(err) + } + resp, err := client.ModuleVersions(modsrc) + if err == nil { + t.Fatal("expected requests to exceed retry", err) + } + if resp != nil { + t.Fatal("unexpected response", *resp) + } + + // verify maxRetryErrorHandler handler returned the error + if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") { + t.Fatal("unexpected error, got:", err) + } +} + func TestLookupProviderVersions(t *testing.T) { server := test.Registry() defer server.Close() diff --git a/registry/test/mock_registry.go b/registry/test/mock_registry.go index e1b6249e3..924365f2b 100644 --- a/registry/test/mock_registry.go +++ b/registry/test/mock_registry.go @@ -363,3 +363,16 @@ func mockRegHandler() http.Handler { func Registry() *httptest.Server { return httptest.NewServer(mockRegHandler()) } + +// RegistryRetryableErrorsServer returns an httptest server that mocks out the +// registry API to return 502 errors. +func RegistryRetryableErrorsServer() *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/v1/modules/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "mocked server error", http.StatusBadGateway) + }) + mux.HandleFunc("/v1/providers/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "mocked server error", http.StatusBadGateway) + }) + return httptest.NewServer(mux) +}