command: Make provider installation interruptible

In earlier commits we started to make the installation codepath
context-aware so that it could be canceled in the event of a SIGINT, but
we didn't complete wiring that through the API of the getproviders
package.

Here we make the getproviders.Source interface methods, along with some
other functions that can make network requests, take a context.Context
argument and act appropriately if that context is cancelled.

The main providercache.Installer.EnsureProviderVersions method now also
has some context-awareness so that it can abort its work early if its
context reports any sort of error. That avoids waiting for the process
to wind through all of the remaining iterations of the various loops,
logging each request failure separately, and instead returns just
a single aggregate "canceled" error.

We can then set things up in the "terraform init" and
"terraform providers mirror" commands so that the context will be
cancelled if we get an interrupt signal, allowing provider installation
to abort early while still atomically completing any local-side effects
that may have started.
This commit is contained in:
Martin Atkins 2020-09-28 17:13:32 -07:00
parent f0ccee854c
commit 0b734a2803
23 changed files with 258 additions and 93 deletions

View File

@ -288,9 +288,9 @@ func (c *InitCommand) Run(args []string) int {
}
// Now that we have loaded all modules, check the module tree for missing providers.
providersOutput, providerDiags := c.getProviders(config, state, flagUpgrade, flagPluginPath)
providersOutput, providersAbort, providerDiags := c.getProviders(config, state, flagUpgrade, flagPluginPath)
diags = diags.Append(providerDiags)
if providerDiags.HasErrors() {
if providersAbort || providerDiags.HasErrors() {
c.showDiagnostics(diags)
return 1
}
@ -419,13 +419,13 @@ the backend configuration is present and valid.
// Load the complete module tree, and fetch any missing providers.
// This method outputs its own Ui.
func (c *InitCommand) getProviders(config *configs.Config, state *states.State, upgrade bool, pluginDirs []string) (output bool, diags tfdiags.Diagnostics) {
func (c *InitCommand) getProviders(config *configs.Config, state *states.State, upgrade bool, pluginDirs []string) (output, abort bool, diags tfdiags.Diagnostics) {
// First we'll collect all the provider dependencies we can see in the
// configuration and the state.
reqs, moreDiags := config.ProviderRequirements()
diags = diags.Append(moreDiags)
if moreDiags.HasErrors() {
return false, diags
return false, true, diags
}
stateReqs := make(getproviders.Requirements, 0)
if state != nil {
@ -574,6 +574,11 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
))
}
case getproviders.ErrRequestCanceled:
// We don't attribute cancellation to any particular operation,
// but rather just emit a single general message about it at
// the end, by checking ctx.Err().
default:
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
@ -637,6 +642,10 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
),
))
}
case getproviders.ErrRequestCanceled:
// We don't attribute cancellation to any particular operation,
// but rather just emit a single general message about it at
// the end, by checking ctx.Err().
default:
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
@ -688,10 +697,16 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
if upgrade {
mode = providercache.InstallUpgrades
}
// TODO: Use a context that will be cancelled when the Terraform
// process receives SIGINT.
ctx := evts.OnContext(context.TODO())
// Installation can be aborted by interruption signals
ctx, done := c.InterruptibleContext()
defer done()
ctx = evts.OnContext(ctx)
selected, err := inst.EnsureProviderVersions(ctx, reqs, mode)
if ctx.Err() == context.Canceled {
c.showDiagnostics(diags)
c.Ui.Error("Provider installation was canceled by an interrupt signal.")
return true, true, diags
}
if err != nil {
// Build a map of provider address to modules using the provider,
// so that we can later show diagnostics about affected modules
@ -707,7 +722,7 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
source := c.providerInstallSource()
for provider, fetchErr := range missingProviderErrors {
addr := addrs.NewLegacyProvider(provider.Type)
p, redirect, err := getproviders.LookupLegacyProvider(addr, source)
p, redirect, err := getproviders.LookupLegacyProvider(ctx, addr, source)
if err == nil {
if redirect.IsZero() {
foundProviders[provider] = p
@ -718,7 +733,7 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Failed to install provider",
fmt.Sprintf("Error while installing %s: %s", provider.ForDisplay(), fetchErr),
fmt.Sprintf("Error while installing %s: %s.", provider.ForDisplay(), fetchErr),
))
}
}
@ -837,7 +852,7 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
diags = diags.Append(err)
}
return true, diags
return true, true, diags
}
// If any providers have "floating" versions (completely unconstrained)
@ -864,7 +879,7 @@ func (c *InitCommand) getProviders(config *configs.Config, state *states.State,
}
}
return true, diags
return true, false, diags
}
func (c *InitCommand) populateProviderToReqs(reqs map[addrs.Provider][]*configs.ModuleRequirements, node *configs.ModuleRequirements) {

View File

@ -1298,6 +1298,54 @@ func TestInit_providerSource(t *testing.T) {
}
}
func TestInit_cancel(t *testing.T) {
// This test runs `terraform init` as if SIGINT (or similar on other
// platforms) were sent to it, testing that it is interruptible.
td := tempDir(t)
configDirName := "init-required-providers"
copy.CopyDir(testFixturePath(configDirName), filepath.Join(td, configDirName))
defer os.RemoveAll(td)
defer testChdir(t, td)()
providerSource, closeSrc := newMockProviderSource(t, map[string][]string{
"test": []string{"1.2.3", "1.2.4"},
"test-beta": []string{"1.2.4"},
"source": []string{"1.2.2", "1.2.3", "1.2.1"},
})
defer closeSrc()
// our shutdown channel is pre-closed so init will exit as soon as it
// starts a cancelable portion of the process.
shutdownCh := make(chan struct{})
close(shutdownCh)
ui := cli.NewMockUi()
m := Meta{
testingOverrides: metaOverridesForProvider(testProvider()),
Ui: ui,
ProviderSource: providerSource,
ShutdownCh: shutdownCh,
}
c := &InitCommand{
Meta: m,
}
args := []string{configDirName}
if code := c.Run(args); code == 0 {
t.Fatalf("succeeded; wanted error")
}
// Currently the first operation that is cancelable is provider
// installation, so our error message comes from there. If we
// make the earlier steps cancelable in future then it'd be
// expected for this particular message to change.
if got, want := ui.ErrorWriter.String(), `Provider installation was canceled by an interrupt signal`; !strings.Contains(got, want) {
t.Fatalf("wrong error message\nshould contain: %s\ngot:\n%s", want, got)
}
}
func TestInit_getUpgradePlugins(t *testing.T) {
// Create a temporary working directory that is empty
td := tempDir(t)

View File

@ -289,6 +289,33 @@ func (m *Meta) StdinPiped() bool {
return fi.Mode()&os.ModeNamedPipe != 0
}
// InterruptibleContext returns a context.Context that will be cancelled
// if the process is interrupted by a platform-specific interrupt signal.
//
// As usual with cancelable contexts, the caller must always call the given
// cancel function once all operations are complete in order to make sure
// that the context resources will still be freed even if there is no
// interruption.
func (m *Meta) InterruptibleContext() (context.Context, context.CancelFunc) {
base := context.Background()
if m.ShutdownCh == nil {
// If we're running in a unit testing context without a shutdown
// channel populated then we'll return an uncancelable channel.
return base, func() {}
}
ctx, cancel := context.WithCancel(base)
go func() {
select {
case <-m.ShutdownCh:
cancel()
case <-ctx.Done():
// finished without being interrupted
}
}()
return ctx, cancel
}
// RunOperation executes the given operation on the given backend, blocking
// until that operation completes or is interrupted, and then returns
// the RunningOperation object representing the completed or

View File

@ -113,6 +113,8 @@ func (c *ProvidersMirrorCommand) Run(args []string) int {
// infrequently to update a mirror, so it doesn't need to optimize away
// fetches of packages that might already be present.
ctx, cancel := c.InterruptibleContext()
defer cancel()
for provider, constraints := range reqs {
if provider.IsBuiltIn() {
c.Ui.Output(fmt.Sprintf("- Skipping %s because it is built in to Terraform CLI", provider.ForDisplay()))
@ -123,7 +125,7 @@ func (c *ProvidersMirrorCommand) Run(args []string) int {
// First we'll look for the latest version that matches the given
// constraint, which we'll then try to mirror for each target platform.
acceptable := versions.MeetingConstraints(constraints)
avail, _, err := source.AvailableVersions(provider)
avail, _, err := source.AvailableVersions(ctx, provider)
candidates := avail.Filter(acceptable)
if err == nil && len(candidates) == 0 {
err = fmt.Errorf("no releases match the given constraints %s", constraintsStr)
@ -144,7 +146,7 @@ func (c *ProvidersMirrorCommand) Run(args []string) int {
}
for _, platform := range platforms {
c.Ui.Output(fmt.Sprintf(" - Downloading package for %s...", platform.String()))
meta, err := source.PackageMeta(provider, selected, platform)
meta, err := source.PackageMeta(ctx, provider, selected, platform)
if err != nil {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,

View File

@ -210,6 +210,19 @@ func (err ErrQueryFailed) Unwrap() error {
return err.Wrapped
}
// ErrRequestCancelled is an error type used to indicate that an operation
// failed due to being cancelled via the given context.Context object.
//
// This error type doesn't include information about what was cancelled,
// because the expected treatment of this error type is to quickly abort and
// exit with minimal ceremony.
type ErrRequestCanceled struct {
}
func (err ErrRequestCanceled) Error() string {
return "request canceled"
}
// ErrIsNotExist returns true if and only if the given error is one of the
// errors from this package that represents an affirmative response that a
// requested object does not exist.

View File

@ -1,6 +1,8 @@
package getproviders
import (
"context"
"github.com/hashicorp/terraform/addrs"
)
@ -28,7 +30,7 @@ func NewFilesystemMirrorSource(baseDir string) *FilesystemMirrorSource {
// AvailableVersions scans the directory structure under the source's base
// directory for locally-mirrored packages for the given provider, returning
// a list of version numbers for the providers it found.
func (s *FilesystemMirrorSource) AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error) {
func (s *FilesystemMirrorSource) AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error) {
// s.allPackages is populated if scanAllVersions succeeds
err := s.scanAllVersions()
if err != nil {
@ -53,7 +55,7 @@ func (s *FilesystemMirrorSource) AvailableVersions(provider addrs.Provider) (Ver
// PackageMeta checks to see if the source's base directory contains a
// local copy of the distribution package for the given provider version on
// the given target, and returns the metadata about it if so.
func (s *FilesystemMirrorSource) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (s *FilesystemMirrorSource) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
// s.allPackages is populated if scanAllVersions succeeds
err := s.scanAllVersions()
if err != nil {

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"testing"
"github.com/apparentlymart/go-versions/versions"
@ -104,7 +105,7 @@ func TestFilesystemMirrorSourceAllAvailablePackages_invalid(t *testing.T) {
func TestFilesystemMirrorSourceAvailableVersions(t *testing.T) {
source := NewFilesystemMirrorSource("testdata/filesystem-mirror")
got, _, err := source.AvailableVersions(nullProvider)
got, _, err := source.AvailableVersions(context.Background(), nullProvider)
if err != nil {
t.Fatal(err)
}
@ -123,7 +124,10 @@ func TestFilesystemMirrorSourcePackageMeta(t *testing.T) {
t.Run("available platform", func(t *testing.T) {
source := NewFilesystemMirrorSource("testdata/filesystem-mirror")
got, err := source.PackageMeta(
nullProvider, versions.MustParseVersion("2.0.0"), Platform{"linux", "amd64"},
context.Background(),
nullProvider,
versions.MustParseVersion("2.0.0"),
Platform{"linux", "amd64"},
)
if err != nil {
t.Fatal(err)
@ -150,7 +154,10 @@ func TestFilesystemMirrorSourcePackageMeta(t *testing.T) {
// We'll request a version that does exist in the fixture directory,
// but for a platform that isn't supported.
_, err := source.PackageMeta(
nullProvider, versions.MustParseVersion("2.0.0"), Platform{"nonexist", "nonexist"},
context.Background(),
nullProvider,
versions.MustParseVersion("2.0.0"),
Platform{"nonexist", "nonexist"},
)
if err == nil {

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"encoding/json"
"fmt"
"io"
@ -82,7 +83,7 @@ func newHTTPMirrorSourceWithHTTPClient(baseURL *url.URL, creds svcauth.Credentia
// AvailableVersions retrieves the available versions for the given provider
// from the object's underlying HTTP mirror service.
func (s *HTTPMirrorSource) AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error) {
func (s *HTTPMirrorSource) AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error) {
log.Printf("[DEBUG] Querying available versions of provider %s at network mirror %s", provider.String(), s.baseURL.String())
endpointPath := path.Join(
@ -92,7 +93,7 @@ func (s *HTTPMirrorSource) AvailableVersions(provider addrs.Provider) (VersionLi
"index.json",
)
statusCode, body, finalURL, err := s.get(endpointPath)
statusCode, body, finalURL, err := s.get(ctx, endpointPath)
defer func() {
if body != nil {
body.Close()
@ -146,7 +147,7 @@ func (s *HTTPMirrorSource) AvailableVersions(provider addrs.Provider) (VersionLi
// PackageMeta retrieves metadata for the requested provider package
// from the object's underlying HTTP mirror service.
func (s *HTTPMirrorSource) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (s *HTTPMirrorSource) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
log.Printf("[DEBUG] Finding package URL for %s v%s on %s via network mirror %s", provider.String(), version.String(), target.String(), s.baseURL.String())
endpointPath := path.Join(
@ -156,7 +157,7 @@ func (s *HTTPMirrorSource) PackageMeta(provider addrs.Provider, version Version,
version.String()+".json",
)
statusCode, body, finalURL, err := s.get(endpointPath)
statusCode, body, finalURL, err := s.get(ctx, endpointPath)
defer func() {
if body != nil {
body.Close()
@ -287,7 +288,7 @@ func (s *HTTPMirrorSource) mirrorHostCredentials() (svcauth.HostCredentials, err
//
// If the "finalURL" return value is not empty then it's the URL that actually
// produced the returned response, possibly after following some redirects.
func (s *HTTPMirrorSource) get(relativePath string) (statusCode int, body io.ReadCloser, finalURL *url.URL, error error) {
func (s *HTTPMirrorSource) get(ctx context.Context, relativePath string) (statusCode int, body io.ReadCloser, finalURL *url.URL, error error) {
endpointPath, err := url.Parse(relativePath)
if err != nil {
// Should never happen because the caller should validate all of the
@ -300,6 +301,7 @@ func (s *HTTPMirrorSource) get(relativePath string) (statusCode int, body io.Rea
if err != nil {
return 0, nil, endpointURL, err
}
req = req.WithContext(ctx)
req.Request.Header.Set(terraformVersionHeader, version.String())
creds, err := s.mirrorHostCredentials()
if err != nil {
@ -361,6 +363,11 @@ func (s *HTTPMirrorSource) get(relativePath string) (statusCode int, body io.Rea
}
func (s *HTTPMirrorSource) errQueryFailed(provider addrs.Provider, err error) error {
if err == context.Canceled {
// This one has a special error type so that callers can
// handle it in a different way.
return ErrRequestCanceled{}
}
return ErrQueryFailed{
Provider: provider,
Wrapped: err,

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -41,7 +42,7 @@ func TestHTTPMirrorSource(t *testing.T) {
tosPlatform := Platform{OS: "tos", Arch: "m68k"}
t.Run("AvailableVersions for provider that exists", func(t *testing.T) {
got, _, err := source.AvailableVersions(existingProvider)
got, _, err := source.AvailableVersions(context.Background(), existingProvider)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -55,7 +56,7 @@ func TestHTTPMirrorSource(t *testing.T) {
}
})
t.Run("AvailableVersions for provider that doesn't exist", func(t *testing.T) {
_, _, err := source.AvailableVersions(missingProvider)
_, _, err := source.AvailableVersions(context.Background(), missingProvider)
switch err := err.(type) {
case ErrProviderNotFound:
if got, want := err.Provider, missingProvider; got != want {
@ -67,7 +68,7 @@ func TestHTTPMirrorSource(t *testing.T) {
})
t.Run("AvailableVersions without required credentials", func(t *testing.T) {
unauthSource := newHTTPMirrorSourceWithHTTPClient(baseURL, nil, httpClient)
_, _, err := unauthSource.AvailableVersions(existingProvider)
_, _, err := unauthSource.AvailableVersions(context.Background(), existingProvider)
switch err := err.(type) {
case ErrUnauthorized:
if got, want := string(err.Hostname), baseURL.Host; got != want {
@ -78,7 +79,7 @@ func TestHTTPMirrorSource(t *testing.T) {
}
})
t.Run("AvailableVersions when the response is a server error", func(t *testing.T) {
_, _, err := source.AvailableVersions(failingProvider)
_, _, err := source.AvailableVersions(context.Background(), failingProvider)
switch err := err.(type) {
case ErrQueryFailed:
if got, want := err.Provider, failingProvider; got != want {
@ -92,7 +93,7 @@ func TestHTTPMirrorSource(t *testing.T) {
}
})
t.Run("AvailableVersions for provider that redirects", func(t *testing.T) {
got, _, err := source.AvailableVersions(redirectingProvider)
got, _, err := source.AvailableVersions(context.Background(), redirectingProvider)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -104,14 +105,14 @@ func TestHTTPMirrorSource(t *testing.T) {
}
})
t.Run("AvailableVersions for provider that redirects too much", func(t *testing.T) {
_, _, err := source.AvailableVersions(redirectLoopProvider)
_, _, err := source.AvailableVersions(context.Background(), redirectLoopProvider)
if err == nil {
t.Fatalf("succeeded; expected error")
}
})
t.Run("PackageMeta for a version that exists and has a hash", func(t *testing.T) {
version := MustParseVersion("1.0.0")
got, err := source.PackageMeta(existingProvider, version, tosPlatform)
got, err := source.PackageMeta(context.Background(), existingProvider, version, tosPlatform)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -140,7 +141,7 @@ func TestHTTPMirrorSource(t *testing.T) {
})
t.Run("PackageMeta for a version that exists and has no hash", func(t *testing.T) {
version := MustParseVersion("1.0.1")
got, err := source.PackageMeta(existingProvider, version, tosPlatform)
got, err := source.PackageMeta(context.Background(), existingProvider, version, tosPlatform)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -159,7 +160,7 @@ func TestHTTPMirrorSource(t *testing.T) {
})
t.Run("PackageMeta for a version that exists but has no archives", func(t *testing.T) {
version := MustParseVersion("1.0.2-beta.1")
_, err := source.PackageMeta(existingProvider, version, tosPlatform)
_, err := source.PackageMeta(context.Background(), existingProvider, version, tosPlatform)
switch err := err.(type) {
case ErrPlatformNotSupported:
if got, want := err.Provider, existingProvider; got != want {
@ -177,7 +178,7 @@ func TestHTTPMirrorSource(t *testing.T) {
})
t.Run("PackageMeta with redirect to a version that exists", func(t *testing.T) {
version := MustParseVersion("1.0.0")
got, err := source.PackageMeta(redirectingProvider, version, tosPlatform)
got, err := source.PackageMeta(context.Background(), redirectingProvider, version, tosPlatform)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -198,7 +199,7 @@ func TestHTTPMirrorSource(t *testing.T) {
})
t.Run("PackageMeta when the response is a server error", func(t *testing.T) {
version := MustParseVersion("1.0.0")
_, err := source.PackageMeta(failingProvider, version, tosPlatform)
_, err := source.PackageMeta(context.Background(), failingProvider, version, tosPlatform)
switch err := err.(type) {
case ErrQueryFailed:
if got, want := err.Provider, failingProvider; got != want {

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"fmt"
svchost "github.com/hashicorp/terraform-svchost"
@ -25,7 +26,7 @@ import (
// configurations that don't include explicit provider source addresses. New
// configurations should not rely on it, and this fallback mechanism is
// likely to be removed altogether in a future Terraform version.
func LookupLegacyProvider(addr addrs.Provider, source Source) (addrs.Provider, addrs.Provider, error) {
func LookupLegacyProvider(ctx context.Context, addr addrs.Provider, source Source) (addrs.Provider, addrs.Provider, error) {
if addr.Namespace != "-" {
return addr, addrs.Provider{}, nil
}
@ -48,7 +49,7 @@ func LookupLegacyProvider(addr addrs.Provider, source Source) (addrs.Provider, a
return addrs.Provider{}, addrs.Provider{}, fmt.Errorf("unqualified provider type %q cannot be resolved because direct installation from %s is disabled in the CLI configuration; declare an explicit provider namespace for this provider", addr.Type, addr.Hostname)
}
defaultNamespace, redirectNamespace, err := regSource.LookupLegacyProviderNamespace(addr.Hostname, addr.Type)
defaultNamespace, redirectNamespace, err := regSource.LookupLegacyProviderNamespace(ctx, addr.Hostname, addr.Type)
if err != nil {
return addrs.Provider{}, addrs.Provider{}, err
}

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"strings"
"testing"
@ -12,6 +13,7 @@ func TestLookupLegacyProvider(t *testing.T) {
defer close()
got, gotMoved, err := LookupLegacyProvider(
context.Background(),
addrs.NewLegacyProvider("legacy"),
source,
)
@ -37,6 +39,7 @@ func TestLookupLegacyProvider_moved(t *testing.T) {
defer close()
got, gotMoved, err := LookupLegacyProvider(
context.Background(),
addrs.NewLegacyProvider("moved"),
source,
)
@ -67,6 +70,7 @@ func TestLookupLegacyProvider_invalidResponse(t *testing.T) {
defer close()
got, _, err := LookupLegacyProvider(
context.Background(),
addrs.NewLegacyProvider("invalid"),
source,
)
@ -84,6 +88,7 @@ func TestLookupLegacyProvider_unexpectedTypeChange(t *testing.T) {
defer close()
got, _, err := LookupLegacyProvider(
context.Background(),
addrs.NewLegacyProvider("changetype"),
source,
)

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"sync"
"github.com/hashicorp/terraform/addrs"
@ -56,7 +57,7 @@ func NewMemoizeSource(underlying Source) *MemoizeSource {
// AvailableVersions requests the available versions from the underlying source
// and caches them before returning them, or on subsequent calls returns the
// result directly from the cache.
func (s *MemoizeSource) AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error) {
func (s *MemoizeSource) AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -64,7 +65,7 @@ func (s *MemoizeSource) AvailableVersions(provider addrs.Provider) (VersionList,
return existing.VersionList, nil, existing.Err
}
ret, warnings, err := s.underlying.AvailableVersions(provider)
ret, warnings, err := s.underlying.AvailableVersions(ctx, provider)
s.availableVersions[provider] = memoizeAvailableVersionsRet{
VersionList: ret,
Err: err,
@ -76,7 +77,7 @@ func (s *MemoizeSource) AvailableVersions(provider addrs.Provider) (VersionList,
// PackageMeta requests package metadata from the underlying source and caches
// the result before returning it, or on subsequent calls returns the result
// directly from the cache.
func (s *MemoizeSource) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (s *MemoizeSource) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -89,7 +90,7 @@ func (s *MemoizeSource) PackageMeta(provider addrs.Provider, version Version, ta
return existing.PackageMeta, existing.Err
}
ret, err := s.underlying.PackageMeta(provider, version, target)
ret, err := s.underlying.PackageMeta(ctx, provider, version, target)
s.packageMetas[key] = memoizePackageMetaRet{
PackageMeta: ret,
Err: err,

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
@ -20,7 +21,7 @@ func TestMemoizeSource(t *testing.T) {
mock := NewMockSource([]PackageMeta{meta}, nil)
source := NewMemoizeSource(mock)
got, _, err := source.AvailableVersions(provider)
got, _, err := source.AvailableVersions(context.Background(), provider)
want := VersionList{version}
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -29,7 +30,7 @@ func TestMemoizeSource(t *testing.T) {
t.Fatalf("wrong result from first call to AvailableVersions\n%s", diff)
}
got, _, err = source.AvailableVersions(provider)
got, _, err = source.AvailableVersions(context.Background(), provider)
want = VersionList{version}
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -38,12 +39,12 @@ func TestMemoizeSource(t *testing.T) {
t.Fatalf("wrong result from second call to AvailableVersions\n%s", diff)
}
_, _, err = source.AvailableVersions(nonexistProvider)
_, _, err = source.AvailableVersions(context.Background(), nonexistProvider)
if want, ok := err.(ErrRegistryProviderNotKnown); !ok {
t.Fatalf("wrong error type from nonexist call:\ngot: %T\nwant: %T", err, want)
}
got, _, err = source.AvailableVersions(provider)
got, _, err = source.AvailableVersions(context.Background(), provider)
want = VersionList{version}
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -70,7 +71,7 @@ func TestMemoizeSource(t *testing.T) {
mock := NewMockSource([]PackageMeta{meta}, map[addrs.Provider]Warnings{warnProvider: {"WARNING!"}})
source := NewMemoizeSource(mock)
got, warns, err := source.AvailableVersions(warnProvider)
got, warns, err := source.AvailableVersions(context.Background(), warnProvider)
want := VersionList{version}
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -90,7 +91,7 @@ func TestMemoizeSource(t *testing.T) {
mock := NewMockSource([]PackageMeta{meta}, nil)
source := NewMemoizeSource(mock)
got, err := source.PackageMeta(provider, version, platform)
got, err := source.PackageMeta(context.Background(), provider, version, platform)
want := meta
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -99,7 +100,7 @@ func TestMemoizeSource(t *testing.T) {
t.Fatalf("wrong result from first call to PackageMeta\n%s", diff)
}
got, err = source.PackageMeta(provider, version, platform)
got, err = source.PackageMeta(context.Background(), provider, version, platform)
want = meta
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -108,16 +109,16 @@ func TestMemoizeSource(t *testing.T) {
t.Fatalf("wrong result from second call to PackageMeta\n%s", diff)
}
_, err = source.PackageMeta(nonexistProvider, version, platform)
_, err = source.PackageMeta(context.Background(), nonexistProvider, version, platform)
if want, ok := err.(ErrPlatformNotSupported); !ok {
t.Fatalf("wrong error type from nonexist provider call:\ngot: %T\nwant: %T", err, want)
}
_, err = source.PackageMeta(provider, version, nonexistPlatform)
_, err = source.PackageMeta(context.Background(), provider, version, nonexistPlatform)
if want, ok := err.(ErrPlatformNotSupported); !ok {
t.Fatalf("wrong error type from nonexist platform call:\ngot: %T\nwant: %T", err, want)
}
got, err = source.PackageMeta(provider, version, platform)
got, err = source.PackageMeta(context.Background(), provider, version, platform)
want = meta
if err != nil {
t.Fatalf("unexpected error: %s", err)
@ -143,11 +144,11 @@ func TestMemoizeSource(t *testing.T) {
mock := NewMockSource([]PackageMeta{meta}, nil)
source := NewMemoizeSource(mock)
_, _, err := source.AvailableVersions(nonexistProvider)
_, _, err := source.AvailableVersions(context.Background(), nonexistProvider)
if want, ok := err.(ErrRegistryProviderNotKnown); !ok {
t.Fatalf("wrong error type from first call:\ngot: %T\nwant: %T", err, want)
}
_, _, err = source.AvailableVersions(nonexistProvider)
_, _, err = source.AvailableVersions(context.Background(), nonexistProvider)
if want, ok := err.(ErrRegistryProviderNotKnown); !ok {
t.Fatalf("wrong error type from second call:\ngot: %T\nwant: %T", err, want)
}
@ -165,11 +166,11 @@ func TestMemoizeSource(t *testing.T) {
mock := NewMockSource([]PackageMeta{meta}, nil)
source := NewMemoizeSource(mock)
_, err := source.PackageMeta(nonexistProvider, version, platform)
_, err := source.PackageMeta(context.Background(), nonexistProvider, version, platform)
if want, ok := err.(ErrPlatformNotSupported); !ok {
t.Fatalf("wrong error type from first call:\ngot: %T\nwant: %T", err, want)
}
_, err = source.PackageMeta(nonexistProvider, version, platform)
_, err = source.PackageMeta(context.Background(), nonexistProvider, version, platform)
if want, ok := err.(ErrPlatformNotSupported); !ok {
t.Fatalf("wrong error type from second call:\ngot: %T\nwant: %T", err, want)
}

View File

@ -2,6 +2,7 @@ package getproviders
import (
"archive/zip"
"context"
"crypto/sha256"
"fmt"
"io"
@ -42,7 +43,7 @@ func NewMockSource(packages []PackageMeta, warns map[addrs.Provider]Warnings) *M
// AvailableVersions returns all of the versions of the given provider that
// are available in the fixed set of packages that were passed to
// NewMockSource when creating the receiving source.
func (s *MockSource) AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error) {
func (s *MockSource) AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error) {
s.calls = append(s.calls, []interface{}{"AvailableVersions", provider})
var ret VersionList
for _, pkg := range s.packages {
@ -78,7 +79,7 @@ func (s *MockSource) AvailableVersions(provider addrs.Provider) (VersionList, Wa
// always return the first one in the list, which may not match the behavior
// of other sources in an equivalent situation because it's a degenerate case
// with undefined results.
func (s *MockSource) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (s *MockSource) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
s.calls = append(s.calls, []interface{}{"PackageMeta", provider, version, target})
for _, pkg := range s.packages {

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"fmt"
"strings"
@ -28,7 +29,7 @@ var _ Source = MultiSource(nil)
// AvailableVersions retrieves all of the versions of the given provider
// that are available across all of the underlying selectors, while respecting
// each selector's matching patterns.
func (s MultiSource) AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error) {
func (s MultiSource) AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error) {
if len(s) == 0 { // Easy case: there can be no available versions
return nil, nil, nil
}
@ -42,7 +43,7 @@ func (s MultiSource) AvailableVersions(provider addrs.Provider) (VersionList, Wa
if !selector.CanHandleProvider(provider) {
continue // doesn't match the given patterns
}
thisSourceVersions, warningsResp, err := selector.Source.AvailableVersions(provider)
thisSourceVersions, warningsResp, err := selector.Source.AvailableVersions(ctx, provider)
switch err.(type) {
case nil:
// okay
@ -80,7 +81,7 @@ func (s MultiSource) AvailableVersions(provider addrs.Provider) (VersionList, Wa
// PackageMeta retrieves the package metadata for the requested provider package
// from the first selector that indicates availability of it.
func (s MultiSource) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (s MultiSource) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
if len(s) == 0 { // Easy case: no providers exist at all
return PackageMeta{}, ErrProviderNotFound{provider, s.sourcesForProvider(provider)}
}
@ -89,7 +90,7 @@ func (s MultiSource) PackageMeta(provider addrs.Provider, version Version, targe
if !selector.CanHandleProvider(provider) {
continue // doesn't match the given patterns
}
meta, err := selector.Source.PackageMeta(provider, version, target)
meta, err := selector.Source.PackageMeta(ctx, provider, version, target)
switch err.(type) {
case nil:
return meta, nil

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
@ -63,7 +64,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
// AvailableVersions produces the union of all versions available
// across all of the sources.
got, _, err := multi.AvailableVersions(addrs.NewDefaultProvider("foo"))
got, _, err := multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("foo"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -76,7 +77,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
t.Errorf("wrong result\n%s", diff)
}
_, _, err = multi.AvailableVersions(addrs.NewDefaultProvider("baz"))
_, _, err = multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("baz"))
if want, ok := err.(ErrRegistryProviderNotKnown); !ok {
t.Fatalf("wrong error type:\ngot: %T\nwant: %T", err, want)
}
@ -130,7 +131,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
},
}
got, _, err := multi.AvailableVersions(addrs.NewDefaultProvider("foo"))
got, _, err := multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("foo"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -142,7 +143,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
t.Errorf("wrong result\n%s", diff)
}
got, _, err = multi.AvailableVersions(addrs.NewDefaultProvider("bar"))
got, _, err = multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("bar"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -154,7 +155,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
t.Errorf("wrong result\n%s", diff)
}
_, _, err = multi.AvailableVersions(addrs.NewDefaultProvider("baz"))
_, _, err = multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("baz"))
if want, ok := err.(ErrRegistryProviderNotKnown); !ok {
t.Fatalf("wrong error type:\ngot: %T\nwant: %T", err, want)
}
@ -168,7 +169,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
{Source: s2},
}
_, _, err := multi.AvailableVersions(addrs.NewDefaultProvider("foo"))
_, _, err := multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("foo"))
if err == nil {
t.Fatal("expected error, got success")
}
@ -213,7 +214,7 @@ func TestMultiSourceAvailableVersions(t *testing.T) {
// AvailableVersions produces the union of all versions available
// across all of the sources.
got, warns, err := multi.AvailableVersions(addrs.NewDefaultProvider("bar"))
got, warns, err := multi.AvailableVersions(context.Background(), addrs.NewDefaultProvider("bar"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@ -293,6 +294,7 @@ func TestMultiSourcePackageMeta(t *testing.T) {
t.Run("only in s1", func(t *testing.T) {
got, err := multi.PackageMeta(
context.Background(),
addrs.NewDefaultProvider("foo"),
MustParseVersion("1.0.0"),
platform2,
@ -307,6 +309,7 @@ func TestMultiSourcePackageMeta(t *testing.T) {
})
t.Run("only in s2", func(t *testing.T) {
got, err := multi.PackageMeta(
context.Background(),
addrs.NewDefaultProvider("foo"),
MustParseVersion("1.2.0"),
platform1,
@ -321,6 +324,7 @@ func TestMultiSourcePackageMeta(t *testing.T) {
})
t.Run("in both", func(t *testing.T) {
got, err := multi.PackageMeta(
context.Background(),
addrs.NewDefaultProvider("foo"),
MustParseVersion("1.0.0"),
platform1,
@ -342,6 +346,7 @@ func TestMultiSourcePackageMeta(t *testing.T) {
})
t.Run("in neither", func(t *testing.T) {
_, err := multi.PackageMeta(
context.Background(),
addrs.NewDefaultProvider("nonexist"),
MustParseVersion("1.0.0"),
platform1,

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
@ -97,7 +98,7 @@ func newRegistryClient(baseURL *url.URL, creds svcauth.HostCredentials) *registr
// 404 Not Found to indicate that the namespace or provider type are not known,
// ErrUnauthorized if the registry responds with 401 or 403 status codes, or
// ErrQueryFailed for any other protocol or operational problem.
func (c *registryClient) ProviderVersions(addr addrs.Provider) (map[string][]string, []string, error) {
func (c *registryClient) ProviderVersions(ctx context.Context, addr addrs.Provider) (map[string][]string, []string, error) {
endpointPath, err := url.Parse(path.Join(addr.Namespace, addr.Type, "versions"))
if err != nil {
// Should never happen because we're constructing this from
@ -109,6 +110,7 @@ func (c *registryClient) ProviderVersions(addr addrs.Provider) (map[string][]str
if err != nil {
return nil, nil, err
}
req = req.WithContext(ctx)
c.addHeadersToRequest(req.Request)
resp, err := c.httpClient.Do(req)
@ -170,7 +172,7 @@ func (c *registryClient) ProviderVersions(addr addrs.Provider) (map[string][]str
// supported by this version of terraform.
// - ErrUnauthorized if the registry responds with 401 or 403 status codes
// - ErrQueryFailed for any other operational problem.
func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (c *registryClient) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
endpointPath, err := url.Parse(path.Join(
provider.Namespace,
provider.Type,
@ -190,6 +192,7 @@ func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, t
if err != nil {
return PackageMeta{}, err
}
req = req.WithContext(ctx)
c.addHeadersToRequest(req.Request)
resp, err := c.httpClient.Do(req)
@ -266,7 +269,7 @@ func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, t
if match == false {
// If the protocol version is not supported, try to find the closest
// matching version.
closest, err := c.findClosestProtocolCompatibleVersion(provider, version)
closest, err := c.findClosestProtocolCompatibleVersion(ctx, provider, version)
if err != nil {
return PackageMeta{}, err
}
@ -363,9 +366,9 @@ func (c *registryClient) PackageMeta(provider addrs.Provider, version Version, t
}
// findClosestProtocolCompatibleVersion searches for the provider version with the closest protocol match.
func (c *registryClient) findClosestProtocolCompatibleVersion(provider addrs.Provider, version Version) (Version, error) {
func (c *registryClient) findClosestProtocolCompatibleVersion(ctx context.Context, provider addrs.Provider, version Version) (Version, error) {
var match Version
available, _, err := c.ProviderVersions(provider)
available, _, err := c.ProviderVersions(ctx, provider)
if err != nil {
return UnspecifiedVersion, err
}
@ -412,7 +415,7 @@ FindMatch:
// This method exists only to allow compatibility with unqualified names
// in older configurations. New configurations should be written so as not to
// depend on it.
func (c *registryClient) LegacyProviderDefaultNamespace(typeName string) (string, string, error) {
func (c *registryClient) LegacyProviderDefaultNamespace(ctx context.Context, typeName string) (string, string, error) {
endpointPath, err := url.Parse(path.Join("-", typeName, "versions"))
if err != nil {
// Should never happen because we're constructing this from
@ -425,6 +428,7 @@ func (c *registryClient) LegacyProviderDefaultNamespace(typeName string) (string
if err != nil {
return "", "", err
}
req = req.WithContext(ctx)
c.addHeadersToRequest(req.Request)
// This is just to give us something to return in error messages. It's
@ -493,6 +497,11 @@ func (c *registryClient) addHeadersToRequest(req *http.Request) {
}
func (c *registryClient) errQueryFailed(provider addrs.Provider, err error) error {
if err == context.Canceled {
// This one has a special error type so that callers can
// handle it in a different way.
return ErrRequestCanceled{}
}
return ErrQueryFailed{
Provider: provider,
Wrapped: err,

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"encoding/json"
"fmt"
"log"
@ -340,7 +341,7 @@ func TestProviderVersions(t *testing.T) {
t.Fatal(err)
}
gotVersions, _, err := client.ProviderVersions(test.provider)
gotVersions, _, err := client.ProviderVersions(context.Background(), test.provider)
if err != nil {
if test.wantErr == "" {
@ -419,7 +420,7 @@ func TestFindClosestProtocolCompatibleVersion(t *testing.T) {
t.Fatal(err)
}
got, err := client.findClosestProtocolCompatibleVersion(test.provider, test.version)
got, err := client.findClosestProtocolCompatibleVersion(context.Background(), test.provider, test.version)
if err != nil {
if test.wantErr == "" {

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"fmt"
svchost "github.com/hashicorp/terraform-svchost"
@ -32,13 +33,13 @@ func NewRegistrySource(services *disco.Disco) *RegistrySource {
// ErrHostNoProviders, ErrHostUnreachable, ErrUnauthenticated,
// ErrProviderNotKnown, or ErrQueryFailed. Callers must be defensive and
// expect errors of other types too, to allow for future expansion.
func (s *RegistrySource) AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error) {
func (s *RegistrySource) AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error) {
client, err := s.registryClient(provider.Hostname)
if err != nil {
return nil, nil, err
}
versionsResponse, warnings, err := client.ProviderVersions(provider)
versionsResponse, warnings, err := client.ProviderVersions(ctx, provider)
if err != nil {
return nil, nil, err
}
@ -94,13 +95,13 @@ func (s *RegistrySource) AvailableVersions(provider addrs.Provider) (VersionList
// ErrHostNoProviders, ErrHostUnreachable, ErrUnauthenticated,
// ErrPlatformNotSupported, or ErrQueryFailed. Callers must be defensive and
// expect errors of other types too, to allow for future expansion.
func (s *RegistrySource) PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
func (s *RegistrySource) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
client, err := s.registryClient(provider.Hostname)
if err != nil {
return PackageMeta{}, err
}
return client.PackageMeta(provider, version, target)
return client.PackageMeta(ctx, provider, version, target)
}
// LookupLegacyProviderNamespace is a special method available only on
@ -118,12 +119,12 @@ func (s *RegistrySource) PackageMeta(provider addrs.Provider, version Version, t
// in older configurations. New configurations should be written so as not to
// depend on it, and this fallback mechanism will likely be removed altogether
// in a future Terraform version.
func (s *RegistrySource) LookupLegacyProviderNamespace(hostname svchost.Hostname, typeName string) (string, string, error) {
func (s *RegistrySource) LookupLegacyProviderNamespace(ctx context.Context, hostname svchost.Hostname, typeName string) (string, string, error) {
client, err := s.registryClient(hostname)
if err != nil {
return "", "", err
}
return client.LegacyProviderDefaultNamespace(typeName)
return client.LegacyProviderDefaultNamespace(ctx, typeName)
}
func (s *RegistrySource) registryClient(hostname svchost.Hostname) (*registryClient, error) {

View File

@ -1,6 +1,7 @@
package getproviders
import (
"context"
"fmt"
"regexp"
"strings"
@ -59,7 +60,7 @@ func TestSourceAvailableVersions(t *testing.T) {
for _, test := range tests {
t.Run(test.provider, func(t *testing.T) {
provider := addrs.MustParseProviderSourceString(test.provider)
gotVersions, _, err := source.AvailableVersions(provider)
gotVersions, _, err := source.AvailableVersions(context.Background(), provider)
if err != nil {
if test.wantErr == "" {
@ -95,7 +96,7 @@ func TestSourceAvailableVersions_warnings(t *testing.T) {
defer close()
provider := addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions")
_, warnings, err := source.AvailableVersions(provider)
_, warnings, err := source.AvailableVersions(context.Background(), provider)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
@ -209,7 +210,7 @@ func TestSourcePackageMeta(t *testing.T) {
version := versions.MustParseVersion(test.version)
got, err := source.PackageMeta(providerAddr, version, Platform{test.os, test.arch})
got, err := source.PackageMeta(context.Background(), providerAddr, version, Platform{test.os, test.arch})
if err != nil {
if test.wantErr == "" {

View File

@ -1,13 +1,15 @@
package getproviders
import (
"context"
"github.com/hashicorp/terraform/addrs"
)
// A Source can query a particular source for information about providers
// that are available to install.
type Source interface {
AvailableVersions(provider addrs.Provider) (VersionList, Warnings, error)
PackageMeta(provider addrs.Provider, version Version, target Platform) (PackageMeta, error)
AvailableVersions(ctx context.Context, provider addrs.Provider) (VersionList, Warnings, error)
PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error)
ForDisplay(provider addrs.Provider) string
}

View File

@ -130,11 +130,6 @@ func (i *Installer) SetUnmanagedProviderTypes(types map[addrs.Provider]struct{})
// in the final returned error value so callers should show either one or the
// other, and not both.
func (i *Installer) EnsureProviderVersions(ctx context.Context, reqs getproviders.Requirements, mode InstallMode) (getproviders.Selections, error) {
// FIXME: Currently the context isn't actually propagated into all of the
// other functions we call here, because they are not context-aware.
// Anything that could be making network requests here should take a
// context and ideally respond to the cancellation of that context.
errs := map[addrs.Provider]error{}
evts := installerEventsForContext(ctx)
@ -234,10 +229,17 @@ MightNeedProvider:
need := map[addrs.Provider]getproviders.Version{}
NeedProvider:
for provider, acceptableVersions := range mightNeed {
if err := ctx.Err(); err != nil {
// If our context has been cancelled or reached a timeout then
// we'll abort early, because subsequent operations against
// that context will fail immediately anyway.
return nil, err
}
if cb := evts.QueryPackagesBegin; cb != nil {
cb(provider, reqs[provider])
}
available, warnings, err := i.source.AvailableVersions(provider)
available, warnings, err := i.source.AvailableVersions(ctx, provider)
if err != nil {
// TODO: Consider retrying a few times for certain types of
// source errors that seem likely to be transient.
@ -277,6 +279,13 @@ NeedProvider:
authResults := map[addrs.Provider]*getproviders.PackageAuthenticationResult{} // record auth results for all successfully fetched providers
targetPlatform := i.targetDir.targetPlatform // we inherit this to behave correctly in unit tests
for provider, version := range need {
if err := ctx.Err(); err != nil {
// If our context has been cancelled or reached a timeout then
// we'll abort early, because subsequent operations against
// that context will fail immediately anyway.
return nil, err
}
if i.globalCacheDir != nil {
// Step 3a: If our global cache already has this version available then
// we'll just link it in.
@ -318,7 +327,7 @@ NeedProvider:
if cb := evts.FetchPackageMeta; cb != nil {
cb(provider, version)
}
meta, err := i.source.PackageMeta(provider, version, targetPlatform)
meta, err := i.source.PackageMeta(ctx, provider, version, targetPlatform)
if err != nil {
errs[provider] = err
if cb := evts.FetchPackageFailure; cb != nil {

View File

@ -40,6 +40,11 @@ func installFromHTTPURL(ctx context.Context, meta getproviders.PackageMeta, targ
}
resp, err := httpClient.Do(req)
if err != nil {
if ctx.Err() == context.Canceled {
// "context canceled" is not a user-friendly error message,
// so we'll return a more appropriate one here.
return nil, fmt.Errorf("provider download was interrupted")
}
return nil, err
}
defer resp.Body.Close()