command: make module installation interruptible

Earlier work to make "terraform init" interruptible made the getproviders
package context-aware in order to allow provider installation to be cancelled.

Here we make a similar change for module installation, which is now also
cancellable with SIGINT. This involves plumbing context through initwd and
getmodules. Functions which can make network requests now include a context
parameter whose cancellation cancels those requests.

Since the module installation code is shared, "terraform get" is now
also interruptible during module installation.
This commit is contained in:
kmoe 2021-11-01 20:09:16 +00:00 committed by Katy Moe
parent ba4b6652fa
commit 40ec62c139
No known key found for this signature in database
GPG Key ID: 8C3780F6DCDDA885
19 changed files with 207 additions and 75 deletions

View File

@ -2,6 +2,7 @@ package command
import ( import (
"bytes" "bytes"
"context"
"crypto/md5" "crypto/md5"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -188,7 +189,7 @@ func testModuleWithSnapshot(t *testing.T, name string) (*configs.Config, *config
// sources only this ultimately just records all of the module paths // sources only this ultimately just records all of the module paths
// in a JSON file so that we can load them below. // in a JSON file so that we can load them below.
inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil)) inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil))
_, instDiags := inst.InstallModules(dir, true, initwd.ModuleInstallHooksImpl{}) _, instDiags := inst.InstallModules(context.Background(), dir, true, initwd.ModuleInstallHooksImpl{})
if instDiags.HasErrors() { if instDiags.HasErrors() {
t.Fatal(instDiags.Err()) t.Fatal(instDiags.Err())
} }

View File

@ -33,9 +33,9 @@ func (c *GetCommand) Run(args []string) int {
path = c.normalizePath(path) path = c.normalizePath(path)
diags := getModules(&c.Meta, path, update) abort, diags := getModules(&c.Meta, path, update)
c.showDiagnostics(diags) c.showDiagnostics(diags)
if diags.HasErrors() { if abort || diags.HasErrors() {
return 1 return 1
} }
@ -73,7 +73,7 @@ func (c *GetCommand) Synopsis() string {
return "Install or upgrade remote Terraform modules" return "Install or upgrade remote Terraform modules"
} }
func getModules(m *Meta, path string, upgrade bool) tfdiags.Diagnostics { func getModules(m *Meta, path string, upgrade bool) (abort bool, diags tfdiags.Diagnostics) {
hooks := uiModuleInstallHooks{ hooks := uiModuleInstallHooks{
Ui: m.Ui, Ui: m.Ui,
ShowLocalPaths: true, ShowLocalPaths: true,

View File

@ -79,3 +79,35 @@ func TestGet_update(t *testing.T) {
t.Fatalf("doesn't look like get: %s", output) t.Fatalf("doesn't look like get: %s", output)
} }
} }
func TestGet_cancel(t *testing.T) {
// This test runs `terraform get` as if SIGINT (or similar on other
// platforms) were sent to it, testing that it is interruptible.
wd := tempWorkingDirFixture(t, "init-registry-module")
defer testChdir(t, wd.RootModuleDir())()
// Our shutdown channel is pre-closed so init will exit as soon as it
// starts a cancelable portion of the process.
shutdownCh := make(chan struct{})
close(shutdownCh)
ui := cli.NewMockUi()
c := &GetCommand{
Meta: Meta{
testingOverrides: metaOverridesForProvider(testProvider()),
Ui: ui,
WorkingDir: wd,
ShutdownCh: shutdownCh,
},
}
args := []string{}
if code := c.Run(args); code == 0 {
t.Fatalf("succeeded; wanted error\n%s", ui.OutputWriter.String())
}
if got, want := ui.ErrorWriter.String(), `Module installation was canceled by an interrupt signal`; !strings.Contains(got, want) {
t.Fatalf("wrong error message\nshould contain: %s\ngot:\n%s", want, got)
}
}

View File

@ -115,9 +115,9 @@ func (c *InitCommand) Run(args []string) int {
ShowLocalPaths: false, // since they are in a weird location for init ShowLocalPaths: false, // since they are in a weird location for init
} }
initDiags := c.initDirFromModule(path, src, hooks) initDirFromModuleAbort, initDirFromModuleDiags := c.initDirFromModule(path, src, hooks)
diags = diags.Append(initDiags) diags = diags.Append(initDirFromModuleDiags)
if initDiags.HasErrors() { if initDirFromModuleAbort || initDirFromModuleDiags.HasErrors() {
c.showDiagnostics(diags) c.showDiagnostics(diags)
return 1 return 1
} }
@ -174,9 +174,9 @@ func (c *InitCommand) Run(args []string) int {
} }
if flagGet { if flagGet {
modsOutput, modsDiags := c.getModules(path, rootModEarly, flagUpgrade) modsOutput, modsAbort, modsDiags := c.getModules(path, rootModEarly, flagUpgrade)
diags = diags.Append(modsDiags) diags = diags.Append(modsDiags)
if modsDiags.HasErrors() { if modsAbort || modsDiags.HasErrors() {
c.showDiagnostics(diags) c.showDiagnostics(diags)
return 1 return 1
} }
@ -326,10 +326,10 @@ func (c *InitCommand) Run(args []string) int {
return 0 return 0
} }
func (c *InitCommand) getModules(path string, earlyRoot *tfconfig.Module, upgrade bool) (output bool, diags tfdiags.Diagnostics) { func (c *InitCommand) getModules(path string, earlyRoot *tfconfig.Module, upgrade bool) (output bool, abort bool, diags tfdiags.Diagnostics) {
if len(earlyRoot.ModuleCalls) == 0 { if len(earlyRoot.ModuleCalls) == 0 {
// Nothing to do // Nothing to do
return false, nil return false, false, nil
} }
if upgrade { if upgrade {
@ -342,8 +342,12 @@ func (c *InitCommand) getModules(path string, earlyRoot *tfconfig.Module, upgrad
Ui: c.Ui, Ui: c.Ui,
ShowLocalPaths: true, ShowLocalPaths: true,
} }
instDiags := c.installModules(path, upgrade, hooks)
diags = diags.Append(instDiags) installAbort, installDiags := c.installModules(path, upgrade, hooks)
diags = diags.Append(installDiags)
if installAbort || installDiags.HasErrors() {
return true, true, diags
}
// Since module installer has modified the module manifest on disk, we need // Since module installer has modified the module manifest on disk, we need
// to refresh the cache of it in the loader. // to refresh the cache of it in the loader.
@ -358,7 +362,7 @@ func (c *InitCommand) getModules(path string, earlyRoot *tfconfig.Module, upgrad
} }
} }
return true, diags return true, false, diags
} }
func (c *InitCommand) initCloud(root *configs.Module) (be backend.Backend, output bool, diags tfdiags.Diagnostics) { func (c *InitCommand) initCloud(root *configs.Module) (be backend.Backend, output bool, diags tfdiags.Diagnostics) {

View File

@ -1417,7 +1417,45 @@ func TestInit_providerSource(t *testing.T) {
} }
} }
func TestInit_cancel(t *testing.T) { func TestInit_cancelModules(t *testing.T) {
// This test runs `terraform init` as if SIGINT (or similar on other
// platforms) were sent to it, testing that it is interruptible.
td := tempDir(t)
testCopyDir(t, testFixturePath("init-registry-module"), td)
defer os.RemoveAll(td)
defer testChdir(t, td)()
// Our shutdown channel is pre-closed so init will exit as soon as it
// starts a cancelable portion of the process.
shutdownCh := make(chan struct{})
close(shutdownCh)
ui := cli.NewMockUi()
view, _ := testView(t)
m := Meta{
testingOverrides: metaOverridesForProvider(testProvider()),
Ui: ui,
View: view,
ShutdownCh: shutdownCh,
}
c := &InitCommand{
Meta: m,
}
args := []string{}
if code := c.Run(args); code == 0 {
t.Fatalf("succeeded; wanted error\n%s", ui.OutputWriter.String())
}
if got, want := ui.ErrorWriter.String(), `Module installation was canceled by an interrupt signal`; !strings.Contains(got, want) {
t.Fatalf("wrong error message\nshould contain: %s\ngot:\n%s", want, got)
}
}
func TestInit_cancelProviders(t *testing.T) {
// This test runs `terraform init` as if SIGINT (or similar on other // This test runs `terraform init` as if SIGINT (or similar on other
// platforms) were sent to it, testing that it is interruptible. // platforms) were sent to it, testing that it is interruptible.

View File

@ -164,26 +164,39 @@ func (m *Meta) loadHCLFile(filename string) (hcl.Body, tfdiags.Diagnostics) {
} }
// installModules reads a root module from the given directory and attempts // installModules reads a root module from the given directory and attempts
// recursively install all of its descendent modules. // recursively to install all of its descendent modules.
// //
// The given hooks object will be notified of installation progress, which // The given hooks object will be notified of installation progress, which
// can then be relayed to the end-user. The moduleUiInstallHooks type in // can then be relayed to the end-user. The uiModuleInstallHooks type in
// this package has a reasonable implementation for displaying notifications // this package has a reasonable implementation for displaying notifications
// via a provided cli.Ui. // via a provided cli.Ui.
func (m *Meta) installModules(rootDir string, upgrade bool, hooks initwd.ModuleInstallHooks) tfdiags.Diagnostics { func (m *Meta) installModules(rootDir string, upgrade bool, hooks initwd.ModuleInstallHooks) (abort bool, diags tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
rootDir = m.normalizePath(rootDir) rootDir = m.normalizePath(rootDir)
err := os.MkdirAll(m.modulesDir(), os.ModePerm) err := os.MkdirAll(m.modulesDir(), os.ModePerm)
if err != nil { if err != nil {
diags = diags.Append(fmt.Errorf("failed to create local modules directory: %s", err)) diags = diags.Append(fmt.Errorf("failed to create local modules directory: %s", err))
return diags return true, diags
} }
// FIXME: KEM: does returning the abort here change behaviour in a particular error
// case?
inst := m.moduleInstaller() inst := m.moduleInstaller()
_, moreDiags := inst.InstallModules(rootDir, upgrade, hooks)
// Installation can be aborted by interruption signals
ctx, done := m.InterruptibleContext()
defer done()
_, moreDiags := inst.InstallModules(ctx, rootDir, upgrade, hooks)
diags = diags.Append(moreDiags) diags = diags.Append(moreDiags)
return diags
if ctx.Err() == context.Canceled {
m.showDiagnostics(diags)
m.Ui.Error("Module installation was canceled by an interrupt signal.")
return true, diags
}
return false, diags
} }
// initDirFromModule initializes the given directory (which should be // initDirFromModule initializes the given directory (which should be
@ -192,15 +205,23 @@ func (m *Meta) installModules(rootDir string, upgrade bool, hooks initwd.ModuleI
// //
// Internally this runs similar steps to installModules. // Internally this runs similar steps to installModules.
// The given hooks object will be notified of installation progress, which // The given hooks object will be notified of installation progress, which
// can then be relayed to the end-user. The moduleUiInstallHooks type in // can then be relayed to the end-user. The uiModuleInstallHooks type in
// this package has a reasonable implementation for displaying notifications // this package has a reasonable implementation for displaying notifications
// via a provided cli.Ui. // via a provided cli.Ui.
func (m *Meta) initDirFromModule(targetDir string, addr string, hooks initwd.ModuleInstallHooks) tfdiags.Diagnostics { func (m *Meta) initDirFromModule(targetDir string, addr string, hooks initwd.ModuleInstallHooks) (abort bool, diags tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics // Installation can be aborted by interruption signals
ctx, done := m.InterruptibleContext()
defer done()
targetDir = m.normalizePath(targetDir) targetDir = m.normalizePath(targetDir)
moreDiags := initwd.DirFromModule(targetDir, m.modulesDir(), addr, m.registryClient(), hooks) moreDiags := initwd.DirFromModule(ctx, targetDir, m.modulesDir(), addr, m.registryClient(), hooks)
diags = diags.Append(moreDiags) diags = diags.Append(moreDiags)
return diags if ctx.Err() == context.Canceled {
m.showDiagnostics(diags)
m.Ui.Error("Module initialization was canceled by an interrupt signal.")
return true, diags
}
return false, diags
} }
// inputForSchema uses interactive prompts to try to populate any // inputForSchema uses interactive prompts to try to populate any

View File

@ -249,7 +249,7 @@ func (c *TestCommand) prepareSuiteDir(ctx context.Context, suiteName string) (te
os.MkdirAll(suiteDirs.ModulesDir, 0755) // if this fails then we'll ignore it and let InstallModules below fail instead os.MkdirAll(suiteDirs.ModulesDir, 0755) // if this fails then we'll ignore it and let InstallModules below fail instead
reg := c.registryClient() reg := c.registryClient()
moduleInst := initwd.NewModuleInstaller(suiteDirs.ModulesDir, reg) moduleInst := initwd.NewModuleInstaller(suiteDirs.ModulesDir, reg)
_, moreDiags := moduleInst.InstallModules(configDir, true, nil) _, moreDiags := moduleInst.InstallModules(ctx, configDir, true, nil)
diags = diags.Append(moreDiags) diags = diags.Append(moreDiags)
if diags.HasErrors() { if diags.HasErrors() {
return suiteDirs, diags return suiteDirs, diags

View File

@ -0,0 +1,4 @@
module "foo" {
source = "registry.does.not.exist/example_corp/foo/bar"
version = "0.1.0"
}

View File

@ -1,6 +1,7 @@
package getmodules package getmodules
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -119,7 +120,7 @@ type reusingGetter map[string]string
// end-user-actionable error messages. At this time we do not have any // end-user-actionable error messages. At this time we do not have any
// reasonable way to improve these error messages at this layer because // reasonable way to improve these error messages at this layer because
// the underlying errors are not separately recognizable. // the underlying errors are not separately recognizable.
func (g reusingGetter) getWithGoGetter(instPath, packageAddr string) error { func (g reusingGetter) getWithGoGetter(ctx context.Context, instPath, packageAddr string) error {
var err error var err error
if prevDir, exists := g[packageAddr]; exists { if prevDir, exists := g[packageAddr]; exists {
@ -144,6 +145,7 @@ func (g reusingGetter) getWithGoGetter(instPath, packageAddr string) error {
Detectors: goGetterNoDetectors, // our caller should've already done detection Detectors: goGetterNoDetectors, // our caller should've already done detection
Decompressors: goGetterDecompressors, Decompressors: goGetterDecompressors,
Getters: goGetterGetters, Getters: goGetterGetters,
Ctx: ctx,
} }
err = client.Get() err = client.Get()
if err != nil { if err != nil {

View File

@ -1,5 +1,9 @@
package getmodules package getmodules
import (
"context"
)
// PackageFetcher is a low-level utility for fetching remote module packages // PackageFetcher is a low-level utility for fetching remote module packages
// into local filesystem directories in preparation for use by higher-level // into local filesystem directories in preparation for use by higher-level
// module installer functionality implemented elsewhere. // module installer functionality implemented elsewhere.
@ -35,6 +39,6 @@ func NewPackageFetcher() *PackageFetcher {
// a module source address which includes a subdirectory portion then the // a module source address which includes a subdirectory portion then the
// caller must resolve that itself, possibly with the help of the // caller must resolve that itself, possibly with the help of the
// getmodules.SplitPackageSubdir and getmodules.ExpandSubdirGlobs functions. // getmodules.SplitPackageSubdir and getmodules.ExpandSubdirGlobs functions.
func (f *PackageFetcher) FetchPackage(instDir string, packageAddr string) error { func (f *PackageFetcher) FetchPackage(ctx context.Context, instDir string, packageAddr string) error {
return f.getter.getWithGoGetter(instDir, packageAddr) return f.getter.getWithGoGetter(ctx, instDir, packageAddr)
} }

View File

@ -1,6 +1,7 @@
package initwd package initwd
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -41,7 +42,7 @@ const initFromModuleRootKeyPrefix = initFromModuleRootCallName + "."
// references using ../ from that module to be unresolvable. Error diagnostics // references using ../ from that module to be unresolvable. Error diagnostics
// are produced in that case, to prompt the user to rewrite the source strings // are produced in that case, to prompt the user to rewrite the source strings
// to be absolute references to the original remote module. // to be absolute references to the original remote module.
func DirFromModule(rootDir, modulesDir, sourceAddr string, reg *registry.Client, hooks ModuleInstallHooks) tfdiags.Diagnostics { func DirFromModule(ctx context.Context, rootDir, modulesDir, sourceAddr string, reg *registry.Client, hooks ModuleInstallHooks) tfdiags.Diagnostics {
var diags tfdiags.Diagnostics var diags tfdiags.Diagnostics
// The way this function works is pretty ugly, but we accept it because // The way this function works is pretty ugly, but we accept it because
@ -144,7 +145,7 @@ func DirFromModule(rootDir, modulesDir, sourceAddr string, reg *registry.Client,
Wrapped: hooks, Wrapped: hooks,
} }
fetcher := getmodules.NewPackageFetcher() fetcher := getmodules.NewPackageFetcher()
_, instDiags := inst.installDescendentModules(fakeRootModule, rootDir, instManifest, true, wrapHooks, fetcher) _, instDiags := inst.installDescendentModules(ctx, fakeRootModule, rootDir, instManifest, true, wrapHooks, fetcher)
diags = append(diags, instDiags...) diags = append(diags, instDiags...)
if instDiags.HasErrors() { if instDiags.HasErrors() {
return diags return diags

View File

@ -1,6 +1,7 @@
package initwd package initwd
import ( import (
"context"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -38,7 +39,7 @@ func TestDirFromModule_registry(t *testing.T) {
hooks := &testInstallHooks{} hooks := &testInstallHooks{}
reg := registry.NewClient(nil, nil) reg := registry.NewClient(nil, nil)
diags := DirFromModule(dir, modsDir, "hashicorp/module-installer-acctest/aws//examples/main", reg, hooks) diags := DirFromModule(context.Background(), dir, modsDir, "hashicorp/module-installer-acctest/aws//examples/main", reg, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
v := version.Must(version.NewVersion("0.0.2")) v := version.Must(version.NewVersion("0.0.2"))
@ -154,7 +155,7 @@ func TestDirFromModule_submodules(t *testing.T) {
} }
modInstallDir := filepath.Join(dir, ".terraform/modules") modInstallDir := filepath.Join(dir, ".terraform/modules")
diags := DirFromModule(dir, modInstallDir, fromModuleDir, nil, hooks) diags := DirFromModule(context.Background(), dir, modInstallDir, fromModuleDir, nil, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
wantCalls := []testInstallHookCall{ wantCalls := []testInstallHookCall{
{ {
@ -248,7 +249,7 @@ func TestDirFromModule_rel_submodules(t *testing.T) {
modInstallDir := ".terraform/modules" modInstallDir := ".terraform/modules"
sourceDir := "../local-modules" sourceDir := "../local-modules"
diags := DirFromModule(".", modInstallDir, sourceDir, nil, hooks) diags := DirFromModule(context.Background(), ".", modInstallDir, sourceDir, nil, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
wantCalls := []testInstallHookCall{ wantCalls := []testInstallHookCall{
{ {

View File

@ -1,6 +1,8 @@
package initwd package initwd
import ( import (
"context"
"errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -75,7 +77,7 @@ func NewModuleInstaller(modsDir string, reg *registry.Client) *ModuleInstaller {
// If successful (the returned diagnostics contains no errors) then the // If successful (the returned diagnostics contains no errors) then the
// first return value is the early configuration tree that was constructed by // first return value is the early configuration tree that was constructed by
// the installation process. // the installation process.
func (i *ModuleInstaller) InstallModules(rootDir string, upgrade bool, hooks ModuleInstallHooks) (*earlyconfig.Config, tfdiags.Diagnostics) { func (i *ModuleInstaller) InstallModules(ctx context.Context, rootDir string, upgrade bool, hooks ModuleInstallHooks) (*earlyconfig.Config, tfdiags.Diagnostics) {
log.Printf("[TRACE] ModuleInstaller: installing child modules for %s into %s", rootDir, i.modsDir) log.Printf("[TRACE] ModuleInstaller: installing child modules for %s into %s", rootDir, i.modsDir)
rootMod, diags := earlyconfig.LoadModule(rootDir) rootMod, diags := earlyconfig.LoadModule(rootDir)
@ -94,13 +96,13 @@ func (i *ModuleInstaller) InstallModules(rootDir string, upgrade bool, hooks Mod
} }
fetcher := getmodules.NewPackageFetcher() fetcher := getmodules.NewPackageFetcher()
cfg, instDiags := i.installDescendentModules(rootMod, rootDir, manifest, upgrade, hooks, fetcher) cfg, instDiags := i.installDescendentModules(ctx, rootMod, rootDir, manifest, upgrade, hooks, fetcher)
diags = append(diags, instDiags...) diags = append(diags, instDiags...)
return cfg, diags return cfg, diags
} }
func (i *ModuleInstaller) installDescendentModules(rootMod *tfconfig.Module, rootDir string, manifest modsdir.Manifest, upgrade bool, hooks ModuleInstallHooks, fetcher *getmodules.PackageFetcher) (*earlyconfig.Config, tfdiags.Diagnostics) { func (i *ModuleInstaller) installDescendentModules(ctx context.Context, rootMod *tfconfig.Module, rootDir string, manifest modsdir.Manifest, upgrade bool, hooks ModuleInstallHooks, fetcher *getmodules.PackageFetcher) (*earlyconfig.Config, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics var diags tfdiags.Diagnostics
if hooks == nil { if hooks == nil {
@ -209,13 +211,13 @@ func (i *ModuleInstaller) installDescendentModules(rootMod *tfconfig.Module, roo
case addrs.ModuleSourceRegistry: case addrs.ModuleSourceRegistry:
log.Printf("[TRACE] ModuleInstaller: %s is a registry module at %s", key, addr.String()) log.Printf("[TRACE] ModuleInstaller: %s is a registry module at %s", key, addr.String())
mod, v, mDiags := i.installRegistryModule(req, key, instPath, addr, manifest, hooks, fetcher) mod, v, mDiags := i.installRegistryModule(ctx, req, key, instPath, addr, manifest, hooks, fetcher)
diags = append(diags, mDiags...) diags = append(diags, mDiags...)
return mod, v, diags return mod, v, diags
case addrs.ModuleSourceRemote: case addrs.ModuleSourceRemote:
log.Printf("[TRACE] ModuleInstaller: %s address %q will be handled by go-getter", key, addr.String()) log.Printf("[TRACE] ModuleInstaller: %s address %q will be handled by go-getter", key, addr.String())
mod, mDiags := i.installGoGetterModule(req, key, instPath, manifest, hooks, fetcher) mod, mDiags := i.installGoGetterModule(ctx, req, key, instPath, manifest, hooks, fetcher)
diags = append(diags, mDiags...) diags = append(diags, mDiags...)
return mod, nil, diags return mod, nil, diags
@ -301,7 +303,7 @@ func (i *ModuleInstaller) installLocalModule(req *earlyconfig.ModuleRequest, key
return mod, diags return mod, diags
} }
func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest, key string, instPath string, addr addrs.ModuleSourceRegistry, manifest modsdir.Manifest, hooks ModuleInstallHooks, fetcher *getmodules.PackageFetcher) (*tfconfig.Module, *version.Version, tfdiags.Diagnostics) { func (i *ModuleInstaller) installRegistryModule(ctx context.Context, req *earlyconfig.ModuleRequest, key string, instPath string, addr addrs.ModuleSourceRegistry, manifest modsdir.Manifest, hooks ModuleInstallHooks, fetcher *getmodules.PackageFetcher) (*tfconfig.Module, *version.Version, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics var diags tfdiags.Diagnostics
hostname := addr.PackageAddr.Host hostname := addr.PackageAddr.Host
@ -324,7 +326,7 @@ func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest,
} else { } else {
var err error var err error
log.Printf("[DEBUG] %s listing available versions of %s at %s", key, addr, hostname) log.Printf("[DEBUG] %s listing available versions of %s at %s", key, addr, hostname)
resp, err = reg.ModuleVersions(regsrcAddr) resp, err = reg.ModuleVersions(ctx, regsrcAddr)
if err != nil { if err != nil {
if registry.IsModuleNotFound(err) { if registry.IsModuleNotFound(err) {
diags = diags.Append(tfdiags.Sourceless( diags = diags.Append(tfdiags.Sourceless(
@ -332,6 +334,12 @@ func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest,
"Module not found", "Module not found",
fmt.Sprintf("Module %q (from %s:%d) cannot be found in the module registry at %s.", req.Name, req.CallPos.Filename, req.CallPos.Line, hostname), fmt.Sprintf("Module %q (from %s:%d) cannot be found in the module registry at %s.", req.Name, req.CallPos.Filename, req.CallPos.Line, hostname),
)) ))
} else if errors.Is(err, context.Canceled) {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Module installation was interrupted",
fmt.Sprintf("Received interrupt signal while retrieving available versions for module %q.", req.Name),
))
} else { } else {
diags = diags.Append(tfdiags.Sourceless( diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error, tfdiags.Error,
@ -423,7 +431,7 @@ func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest,
// first check the cache for the download URL // first check the cache for the download URL
moduleAddr := moduleVersion{module: packageAddr, version: latestMatch.String()} moduleAddr := moduleVersion{module: packageAddr, version: latestMatch.String()}
if _, exists := i.registryPackageSources[moduleAddr]; !exists { if _, exists := i.registryPackageSources[moduleAddr]; !exists {
realAddrRaw, err := reg.ModuleLocation(regsrcAddr, latestMatch.String()) realAddrRaw, err := reg.ModuleLocation(ctx, regsrcAddr, latestMatch.String())
if err != nil { if err != nil {
log.Printf("[ERROR] %s from %s %s: %s", key, addr, latestMatch, err) log.Printf("[ERROR] %s from %s %s: %s", key, addr, latestMatch, err)
diags = diags.Append(tfdiags.Sourceless( diags = diags.Append(tfdiags.Sourceless(
@ -463,7 +471,15 @@ func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest,
log.Printf("[TRACE] ModuleInstaller: %s %s %s is available at %q", key, packageAddr, latestMatch, dlAddr.PackageAddr) log.Printf("[TRACE] ModuleInstaller: %s %s %s is available at %q", key, packageAddr, latestMatch, dlAddr.PackageAddr)
err := fetcher.FetchPackage(instPath, dlAddr.PackageAddr.String()) err := fetcher.FetchPackage(ctx, instPath, dlAddr.PackageAddr.String())
if errors.Is(err, context.Canceled) {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Module download was interrupted",
fmt.Sprintf("Interrupt signal received when downloading module %s.", addr),
))
return nil, nil, diags
}
if err != nil { if err != nil {
// Errors returned by go-getter have very inconsistent quality as // Errors returned by go-getter have very inconsistent quality as
// end-user error messages, but for now we're accepting that because // end-user error messages, but for now we're accepting that because
@ -519,7 +535,7 @@ func (i *ModuleInstaller) installRegistryModule(req *earlyconfig.ModuleRequest,
return mod, latestMatch, diags return mod, latestMatch, diags
} }
func (i *ModuleInstaller) installGoGetterModule(req *earlyconfig.ModuleRequest, key string, instPath string, manifest modsdir.Manifest, hooks ModuleInstallHooks, fetcher *getmodules.PackageFetcher) (*tfconfig.Module, tfdiags.Diagnostics) { func (i *ModuleInstaller) installGoGetterModule(ctx context.Context, req *earlyconfig.ModuleRequest, key string, instPath string, manifest modsdir.Manifest, hooks ModuleInstallHooks, fetcher *getmodules.PackageFetcher) (*tfconfig.Module, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics var diags tfdiags.Diagnostics
// Report up to the caller that we're about to start downloading. // Report up to the caller that we're about to start downloading.
@ -536,7 +552,7 @@ func (i *ModuleInstaller) installGoGetterModule(req *earlyconfig.ModuleRequest,
return nil, diags return nil, diags
} }
err := fetcher.FetchPackage(instPath, packageAddr.String()) err := fetcher.FetchPackage(ctx, instPath, packageAddr.String())
if err != nil { if err != nil {
// go-getter generates a poor error for an invalid relative path, so // go-getter generates a poor error for an invalid relative path, so
// we'll detect that case and generate a better one. // we'll detect that case and generate a better one.

View File

@ -2,6 +2,7 @@ package initwd
import ( import (
"bytes" "bytes"
"context"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -39,7 +40,7 @@ func TestModuleInstaller(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
wantCalls := []testInstallHookCall{ wantCalls := []testInstallHookCall{
@ -100,7 +101,7 @@ func TestModuleInstaller_error(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
if !diags.HasErrors() { if !diags.HasErrors() {
t.Fatal("expected error") t.Fatal("expected error")
@ -135,7 +136,7 @@ func TestModuleInstaller_packageEscapeError(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
if !diags.HasErrors() { if !diags.HasErrors() {
t.Fatal("expected error") t.Fatal("expected error")
@ -170,7 +171,7 @@ func TestModuleInstaller_explicitPackageBoundary(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
if diags.HasErrors() { if diags.HasErrors() {
t.Fatalf("unexpected errors\n%s", diags.Err().Error()) t.Fatalf("unexpected errors\n%s", diags.Err().Error())
@ -186,7 +187,7 @@ func TestModuleInstaller_invalid_version_constraint_error(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
if !diags.HasErrors() { if !diags.HasErrors() {
t.Fatal("expected error") t.Fatal("expected error")
@ -204,7 +205,7 @@ func TestModuleInstaller_invalidVersionConstraintGetter(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
if !diags.HasErrors() { if !diags.HasErrors() {
t.Fatal("expected error") t.Fatal("expected error")
@ -222,7 +223,7 @@ func TestModuleInstaller_invalidVersionConstraintLocal(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
if !diags.HasErrors() { if !diags.HasErrors() {
t.Fatal("expected error") t.Fatal("expected error")
@ -240,7 +241,7 @@ func TestModuleInstaller_symlink(t *testing.T) {
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, nil) inst := NewModuleInstaller(modulesDir, nil)
_, diags := inst.InstallModules(".", false, hooks) _, diags := inst.InstallModules(context.Background(), ".", false, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
wantCalls := []testInstallHookCall{ wantCalls := []testInstallHookCall{
@ -313,7 +314,7 @@ func TestLoaderInstallModules_registry(t *testing.T) {
hooks := &testInstallHooks{} hooks := &testInstallHooks{}
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, registry.NewClient(nil, nil)) inst := NewModuleInstaller(modulesDir, registry.NewClient(nil, nil))
_, diags := inst.InstallModules(dir, false, hooks) _, diags := inst.InstallModules(context.Background(), dir, false, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
v := version.Must(version.NewVersion("0.0.1")) v := version.Must(version.NewVersion("0.0.1"))
@ -473,7 +474,7 @@ func TestLoaderInstallModules_goGetter(t *testing.T) {
hooks := &testInstallHooks{} hooks := &testInstallHooks{}
modulesDir := filepath.Join(dir, ".terraform/modules") modulesDir := filepath.Join(dir, ".terraform/modules")
inst := NewModuleInstaller(modulesDir, registry.NewClient(nil, nil)) inst := NewModuleInstaller(modulesDir, registry.NewClient(nil, nil))
_, diags := inst.InstallModules(dir, false, hooks) _, diags := inst.InstallModules(context.Background(), dir, false, hooks)
assertNoDiagnostics(t, diags) assertNoDiagnostics(t, diags)
wantCalls := []testInstallHookCall{ wantCalls := []testInstallHookCall{

View File

@ -1,6 +1,7 @@
package initwd package initwd
import ( import (
"context"
"testing" "testing"
"github.com/hashicorp/terraform/internal/configs" "github.com/hashicorp/terraform/internal/configs"
@ -35,7 +36,7 @@ func LoadConfigForTests(t *testing.T, rootDir string) (*configs.Config, *configl
loader, cleanup := configload.NewLoaderForTests(t) loader, cleanup := configload.NewLoaderForTests(t)
inst := NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil)) inst := NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil))
_, moreDiags := inst.InstallModules(rootDir, true, ModuleInstallHooksImpl{}) _, moreDiags := inst.InstallModules(context.Background(), rootDir, true, ModuleInstallHooksImpl{})
diags = diags.Append(moreDiags) diags = diags.Append(moreDiags)
if diags.HasErrors() { if diags.HasErrors() {
cleanup() cleanup()

View File

@ -1,6 +1,7 @@
package refactoring package refactoring
import ( import (
"context"
"strings" "strings"
"testing" "testing"
@ -419,7 +420,7 @@ func loadRefactoringFixture(t *testing.T, dir string) (*configs.Config, instance
defer cleanup() defer cleanup()
inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil)) inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil))
_, instDiags := inst.InstallModules(dir, true, initwd.ModuleInstallHooksImpl{}) _, instDiags := inst.InstallModules(context.Background(), dir, true, initwd.ModuleInstallHooksImpl{})
if instDiags.HasErrors() { if instDiags.HasErrors() {
t.Fatal(instDiags.Err()) t.Fatal(instDiags.Err())
} }

View File

@ -1,6 +1,7 @@
package registry package registry
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -109,7 +110,7 @@ func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, er
} }
// ModuleVersions queries the registry for a module, and returns the available versions. // ModuleVersions queries the registry for a module, and returns the available versions.
func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions, error) { func (c *Client) ModuleVersions(ctx context.Context, module *regsrc.Module) (*response.ModuleVersions, error) {
host, err := module.SvcHost() host, err := module.SvcHost()
if err != nil { if err != nil {
return nil, err return nil, err
@ -133,6 +134,7 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions
if err != nil { if err != nil {
return nil, err return nil, err
} }
req = req.WithContext(ctx)
c.addRequestCreds(host, req.Request) c.addRequestCreds(host, req.Request)
req.Header.Set(xTerraformVersion, tfVersion) req.Header.Set(xTerraformVersion, tfVersion)
@ -182,7 +184,7 @@ func (c *Client) addRequestCreds(host svchost.Hostname, req *http.Request) {
// ModuleLocation find the download location for a specific version module. // ModuleLocation find the download location for a specific version module.
// This returns a string, because the final location may contain special go-getter syntax. // This returns a string, because the final location may contain special go-getter syntax.
func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string, error) { func (c *Client) ModuleLocation(ctx context.Context, module *regsrc.Module, version string) (string, error) {
host, err := module.SvcHost() host, err := module.SvcHost()
if err != nil { if err != nil {
return "", err return "", err
@ -211,6 +213,8 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string,
return "", err return "", err
} }
req = req.WithContext(ctx)
c.addRequestCreds(host, req.Request) c.addRequestCreds(host, req.Request)
req.Header.Set(xTerraformVersion, tfVersion) req.Header.Set(xTerraformVersion, tfVersion)

View File

@ -103,7 +103,7 @@ func TestLookupModuleVersions(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp, err := client.ModuleVersions(modsrc) resp, err := client.ModuleVersions(context.Background(), modsrc)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -143,7 +143,7 @@ func TestInvalidRegistry(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if _, err := client.ModuleVersions(modsrc); err == nil { if _, err := client.ModuleVersions(context.Background(), modsrc); err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
} }
@ -160,11 +160,11 @@ func TestRegistryAuth(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = client.ModuleVersions(mod) _, err = client.ModuleVersions(context.Background(), mod)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = client.ModuleLocation(mod, "1.0.0") _, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -173,11 +173,11 @@ func TestRegistryAuth(t *testing.T) {
client.services.SetCredentialsSource(nil) client.services.SetCredentialsSource(nil)
// both should fail without auth // both should fail without auth
_, err = client.ModuleVersions(mod) _, err = client.ModuleVersions(context.Background(), mod)
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
_, err = client.ModuleLocation(mod, "1.0.0") _, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
@ -195,7 +195,7 @@ func TestLookupModuleLocationRelative(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
got, err := client.ModuleLocation(mod, "0.2.0") got, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -224,7 +224,7 @@ func TestAccLookupModuleVersions(t *testing.T) {
} }
s := NewClient(regDisco, nil) s := NewClient(regDisco, nil)
resp, err := s.ModuleVersions(modsrc) resp, err := s.ModuleVersions(context.Background(), modsrc)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -277,7 +277,7 @@ func TestLookupLookupModuleError(t *testing.T) {
return oldCheck(ctx, resp, err) return oldCheck(ctx, resp, err)
} }
_, err = client.ModuleLocation(mod, "0.2.0") _, err = client.ModuleLocation(context.Background(), mod, "0.2.0")
if err == nil { if err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
@ -299,7 +299,7 @@ func TestLookupModuleRetryError(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
resp, err := client.ModuleVersions(modsrc) resp, err := client.ModuleVersions(context.Background(), modsrc)
if err == nil { if err == nil {
t.Fatal("expected requests to exceed retry", err) t.Fatal("expected requests to exceed retry", err)
} }
@ -328,7 +328,7 @@ func TestLookupModuleNoRetryError(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
resp, err := client.ModuleVersions(modsrc) resp, err := client.ModuleVersions(context.Background(), modsrc)
if err == nil { if err == nil {
t.Fatal("expected request to fail", err) t.Fatal("expected request to fail", err)
} }
@ -354,7 +354,7 @@ func TestLookupModuleNetworkError(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
resp, err := client.ModuleVersions(modsrc) resp, err := client.ModuleVersions(context.Background(), modsrc)
if err == nil { if err == nil {
t.Fatal("expected request to fail", err) t.Fatal("expected request to fail", err)
} }

View File

@ -1,6 +1,7 @@
package terraform package terraform
import ( import (
"context"
"flag" "flag"
"io" "io"
"io/ioutil" "io/ioutil"
@ -61,7 +62,7 @@ func testModuleWithSnapshot(t *testing.T, name string) (*configs.Config, *config
// sources only this ultimately just records all of the module paths // sources only this ultimately just records all of the module paths
// in a JSON file so that we can load them below. // in a JSON file so that we can load them below.
inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil)) inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil))
_, instDiags := inst.InstallModules(dir, true, initwd.ModuleInstallHooksImpl{}) _, instDiags := inst.InstallModules(context.Background(), dir, true, initwd.ModuleInstallHooksImpl{})
if instDiags.HasErrors() { if instDiags.HasErrors() {
t.Fatal(instDiags.Err()) t.Fatal(instDiags.Err())
} }
@ -119,7 +120,7 @@ func testModuleInline(t *testing.T, sources map[string]string) *configs.Config {
// sources only this ultimately just records all of the module paths // sources only this ultimately just records all of the module paths
// in a JSON file so that we can load them below. // in a JSON file so that we can load them below.
inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil)) inst := initwd.NewModuleInstaller(loader.ModulesDir(), registry.NewClient(nil, nil))
_, instDiags := inst.InstallModules(cfgPath, true, initwd.ModuleInstallHooksImpl{}) _, instDiags := inst.InstallModules(context.Background(), cfgPath, true, initwd.ModuleInstallHooksImpl{})
if instDiags.HasErrors() { if instDiags.HasErrors() {
t.Fatal(instDiags.Err()) t.Fatal(instDiags.Err())
} }