Add a `CredentialsForHost` method to disco.Disco

By adding this method you now only have to pass a `*disco.Disco` object around in order to do discovery and use any configured credentials for the discovered hosts.

Of course you can also still pass around both a `*disco.Disco` and a `auth.CredentialsSource` object if there is a need or a reason for that!
This commit is contained in:
Sander van Harmelen 2018-07-05 21:28:29 +02:00
parent 495d1ea350
commit 179b32d426
17 changed files with 81 additions and 95 deletions

View File

@ -117,7 +117,7 @@ func testModule(t *testing.T, name string) *module.Tree {
t.Fatalf("err: %s", err)
}
s := module.NewStorage(tempDir(t), nil, nil)
s := module.NewStorage(tempDir(t), nil)
s.Mode = module.GetModeGet
if err := mod.Load(s); err != nil {
t.Fatalf("err: %s", err)

View File

@ -129,7 +129,7 @@ func (c *InitCommand) Run(args []string) int {
)))
header = true
s := module.NewStorage("", c.Services, c.Credentials)
s := module.NewStorage("", c.Services)
if err := s.GetModule(path, src); err != nil {
c.Ui.Error(fmt.Sprintf("Error copying source module: %s", err))
return 1

View File

@ -25,7 +25,6 @@ import (
"github.com/hashicorp/terraform/helper/experiment"
"github.com/hashicorp/terraform/helper/variables"
"github.com/hashicorp/terraform/helper/wrappedstreams"
"github.com/hashicorp/terraform/svchost/auth"
"github.com/hashicorp/terraform/svchost/disco"
"github.com/hashicorp/terraform/terraform"
"github.com/hashicorp/terraform/tfdiags"
@ -51,10 +50,6 @@ type Meta struct {
// "terraform-native' services running at a specific user-facing hostname.
Services *disco.Disco
// Credentials provides access to credentials for "terraform-native"
// services, which are accessed by a service hostname.
Credentials auth.CredentialsSource
// RunningInAutomation indicates that commands are being run by an
// automated system rather than directly at a command prompt.
//
@ -410,7 +405,7 @@ func (m *Meta) flagSet(n string) *flag.FlagSet {
// moduleStorage returns the module.Storage implementation used to store
// modules for commands.
func (m *Meta) moduleStorage(root string, mode module.GetMode) *module.Storage {
s := module.NewStorage(filepath.Join(root, "modules"), m.Services, m.Credentials)
s := module.NewStorage(filepath.Join(root, "modules"), m.Services)
s.Ui = m.Ui
s.Mode = mode
return s

View File

@ -30,15 +30,12 @@ const (
OutputPrefix = "o:"
)
func initCommands(config *Config) {
func initCommands(config *Config, services *disco.Disco) {
var inAutomation bool
if v := os.Getenv(runningInAutomationEnvName); v != "" {
inAutomation = true
}
credsSrc := credentialsSource(config)
services := disco.NewDisco()
services.SetCredentialsSource(credsSrc)
for userHost, hostConfig := range config.Hosts {
host, err := svchost.ForComparison(userHost)
if err != nil {
@ -57,8 +54,7 @@ func initCommands(config *Config) {
PluginOverrides: &PluginOverrides,
Ui: Ui,
Services: services,
Credentials: credsSrc,
Services: services,
RunningInAutomation: inAutomation,
PluginCacheDir: config.PluginCacheDir,

View File

@ -44,5 +44,5 @@ func testConfig(t *testing.T, n string) *config.Config {
func testStorage(t *testing.T, d *disco.Disco) *Storage {
t.Helper()
return NewStorage(tempDir(t), d, nil)
return NewStorage(tempDir(t), d)
}

View File

@ -11,7 +11,6 @@ import (
getter "github.com/hashicorp/go-getter"
"github.com/hashicorp/terraform/registry"
"github.com/hashicorp/terraform/registry/regsrc"
"github.com/hashicorp/terraform/svchost/auth"
"github.com/hashicorp/terraform/svchost/disco"
"github.com/mitchellh/cli"
)
@ -64,14 +63,10 @@ type Storage struct {
// StorageDir is the full path to the directory where all modules will be
// stored.
StorageDir string
// Services is a required *disco.Disco, which may have services and
// credentials pre-loaded.
Services *disco.Disco
// Creds optionally provides credentials for communicating with service
// providers.
Creds auth.CredentialsSource
// Ui is an optional cli.Ui for user output
Ui cli.Ui
// Mode is the GetMode that will be used for various operations.
Mode GetMode
@ -79,8 +74,8 @@ type Storage struct {
}
// NewStorage returns a new initialized Storage object.
func NewStorage(dir string, services *disco.Disco, creds auth.CredentialsSource) *Storage {
regClient := registry.NewClient(services, creds, nil)
func NewStorage(dir string, services *disco.Disco) *Storage {
regClient := registry.NewClient(services, nil)
return &Storage{
StorageDir: dir,

View File

@ -22,7 +22,7 @@ func TestGetModule(t *testing.T) {
t.Fatal(err)
}
defer os.RemoveAll(td)
storage := NewStorage(td, disco, nil)
storage := NewStorage(td, disco)
// this module exists in a test fixture, and is known by the test.Registry
// relative to our cwd.
@ -139,7 +139,7 @@ func TestAccRegistryDiscover(t *testing.T) {
t.Fatal(err)
}
s := NewStorage("/tmp", nil, nil)
s := NewStorage("/tmp", nil)
loc, err := s.registry.Location(module, "")
if err != nil {
t.Fatal(err)

View File

@ -5,7 +5,6 @@ import (
"github.com/hashicorp/terraform/configs"
"github.com/hashicorp/terraform/registry"
"github.com/hashicorp/terraform/svchost/auth"
"github.com/hashicorp/terraform/svchost/disco"
"github.com/spf13/afero"
)
@ -39,10 +38,6 @@ type Config struct {
// not supported, which should be true only in specialized circumstances
// such as in tests.
Services *disco.Disco
// Creds is a credentials store for communicating with remote module
// registry endpoints. If this is nil then no credentials will be used.
Creds auth.CredentialsSource
}
// NewLoader creates and returns a loader that reads configuration from the
@ -54,7 +49,7 @@ type Config struct {
func NewLoader(config *Config) (*Loader, error) {
fs := afero.NewOsFs()
parser := configs.NewParser(fs)
reg := registry.NewClient(config.Services, config.Creds, nil)
reg := registry.NewClient(config.Services, nil)
ret := &Loader{
parser: parser,
@ -63,7 +58,6 @@ func NewLoader(config *Config) (*Loader, error) {
CanInstall: true,
Dir: config.ModulesDir,
Services: config.Services,
Creds: config.Creds,
Registry: reg,
},
}

View File

@ -2,7 +2,6 @@ package configload
import (
"github.com/hashicorp/terraform/registry"
"github.com/hashicorp/terraform/svchost/auth"
"github.com/hashicorp/terraform/svchost/disco"
"github.com/spf13/afero"
)
@ -25,9 +24,6 @@ type moduleMgr struct {
// cached discovery information.
Services *disco.Disco
// Creds provides optional credentials for communicating with service hosts.
Creds auth.CredentialsSource
// Registry is a client for the module registry protocol, which is used
// when a module is requested from a registry source.
Registry *registry.Client

View File

@ -16,6 +16,7 @@ import (
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/terraform/command/format"
"github.com/hashicorp/terraform/helper/logging"
"github.com/hashicorp/terraform/svchost/disco"
"github.com/hashicorp/terraform/terraform"
"github.com/mattn/go-colorable"
"github.com/mattn/go-shellwords"
@ -144,7 +145,9 @@ func wrappedMain() int {
// In tests, Commands may already be set to provide mock commands
if Commands == nil {
initCommands(config)
credsSrc := credentialsSource(config)
services := disco.NewWithCredentialsSource(credsSrc)
initCommands(config, services)
}
// Run checkpoint

View File

@ -15,7 +15,6 @@ import (
"github.com/hashicorp/terraform/registry/regsrc"
"github.com/hashicorp/terraform/registry/response"
"github.com/hashicorp/terraform/svchost"
"github.com/hashicorp/terraform/svchost/auth"
"github.com/hashicorp/terraform/svchost/disco"
"github.com/hashicorp/terraform/version"
)
@ -37,20 +36,14 @@ type Client struct {
// services is a required *disco.Disco, which may have services and
// credentials pre-loaded.
services *disco.Disco
// Creds optionally provides credentials for communicating with service
// providers.
creds auth.CredentialsSource
}
// NewClient returns a new initialized registry client.
func NewClient(services *disco.Disco, creds auth.CredentialsSource, client *http.Client) *Client {
func NewClient(services *disco.Disco, client *http.Client) *Client {
if services == nil {
services = disco.NewDisco()
services = disco.New()
}
services.SetCredentialsSource(creds)
if client == nil {
client = httpclient.New()
client.Timeout = requestTimeout
@ -61,7 +54,6 @@ func NewClient(services *disco.Disco, creds auth.CredentialsSource, client *http
return &Client{
client: client,
services: services,
creds: creds,
}
}
@ -138,11 +130,7 @@ func (c *Client) Versions(module *regsrc.Module) (*response.ModuleVersions, erro
}
func (c *Client) addRequestCreds(host svchost.Hostname, req *http.Request) {
if c.creds == nil {
return
}
creds, err := c.creds.ForHost(host)
creds, err := c.services.CredentialsForHost(host)
if err != nil {
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
return

View File

@ -15,7 +15,7 @@ func TestLookupModuleVersions(t *testing.T) {
server := test.Registry()
defer server.Close()
client := NewClient(test.Disco(server), nil, nil)
client := NewClient(test.Disco(server), nil)
// test with and without a hostname
for _, src := range []string{
@ -59,7 +59,7 @@ func TestInvalidRegistry(t *testing.T) {
server := test.Registry()
defer server.Close()
client := NewClient(test.Disco(server), nil, nil)
client := NewClient(test.Disco(server), nil)
src := "non-existent.localhost.localdomain/test-versions/name/provider"
modsrc, err := regsrc.ParseModuleSource(src)
@ -76,7 +76,7 @@ func TestRegistryAuth(t *testing.T) {
server := test.Registry()
defer server.Close()
client := NewClient(test.Disco(server), nil, nil)
client := NewClient(test.Disco(server), nil)
src := "private/name/provider"
mod, err := regsrc.ParseModuleSource(src)
@ -84,6 +84,18 @@ func TestRegistryAuth(t *testing.T) {
t.Fatal(err)
}
_, err = client.Versions(mod)
if err != nil {
t.Fatal(err)
}
_, err = client.Location(mod, "1.0.0")
if err != nil {
t.Fatal(err)
}
// Also test without a credentials source
client.services.SetCredentialsSource(nil)
// both should fail without auth
_, err = client.Versions(mod)
if err == nil {
@ -93,24 +105,13 @@ func TestRegistryAuth(t *testing.T) {
if err == nil {
t.Fatal("expected error")
}
client = NewClient(test.Disco(server), test.Credentials, nil)
_, err = client.Versions(mod)
if err != nil {
t.Fatal(err)
}
_, err = client.Location(mod, "1.0.0")
if err != nil {
t.Fatal(err)
}
}
func TestLookupModuleLocationRelative(t *testing.T) {
server := test.Registry()
defer server.Close()
client := NewClient(test.Disco(server), nil, nil)
client := NewClient(test.Disco(server), nil)
src := "relative/foo/bar"
mod, err := regsrc.ParseModuleSource(src)
@ -133,7 +134,7 @@ func TestAccLookupModuleVersions(t *testing.T) {
if os.Getenv("TF_ACC") == "" {
t.Skip()
}
regDisco := disco.NewDisco()
regDisco := disco.New()
// test with and without a hostname
for _, src := range []string{
@ -145,7 +146,7 @@ func TestAccLookupModuleVersions(t *testing.T) {
t.Fatal(err)
}
s := NewClient(regDisco, nil, nil)
s := NewClient(regDisco, nil)
resp, err := s.Versions(modsrc)
if err != nil {
t.Fatal(err)
@ -179,7 +180,7 @@ func TestLookupLookupModuleError(t *testing.T) {
server := test.Registry()
defer server.Close()
client := NewClient(test.Disco(server), nil, nil)
client := NewClient(test.Disco(server), nil)
// this should not be found in teh registry
src := "bad/local/path"

View File

@ -27,7 +27,7 @@ func Disco(s *httptest.Server) *disco.Disco {
// TODO: add specific tests to enumerate both possibilities.
"modules.v1": fmt.Sprintf("%s/v1/modules", s.URL),
}
d := disco.NewDisco()
d := disco.NewWithCredentialsSource(credsSrc)
d.ForceHostServices(svchost.Hostname("registry.terraform.io"), services)
d.ForceHostServices(svchost.Hostname("localhost"), services)
@ -48,8 +48,8 @@ const (
)
var (
regHost = svchost.Hostname(regsrc.PublicRegistryHost.Normalized())
Credentials = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
regHost = svchost.Hostname(regsrc.PublicRegistryHost.Normalized())
credsSrc = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
regHost: {"token": testCred},
})
)

View File

@ -42,6 +42,9 @@ type HostCredentials interface {
// receiving credentials. The usual behavior of this method is to
// add some sort of Authorization header to the request.
PrepareRequest(req *http.Request)
// Token returns the authentication token.
Token() string
}
// ForHost iterates over the contained CredentialsSource objects and

View File

@ -18,3 +18,8 @@ func (tc HostCredentialsToken) PrepareRequest(req *http.Request) {
}
req.Header.Set("Authorization", "Bearer "+string(tc))
}
// Token returns the authentication token.
func (tc HostCredentialsToken) Token() string {
return string(tc)
}

View File

@ -42,9 +42,15 @@ type Disco struct {
Transport http.RoundTripper
}
// NewDisco returns a new initialized Disco object.
func NewDisco() *Disco {
return &Disco{}
// New returns a new initialized discovery object.
func New() *Disco {
return NewWithCredentialsSource(nil)
}
// NewWithCredentialsSource returns a new discovery object initialized with
// the given credentials source.
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
return &Disco{credsSrc: credsSrc}
}
// SetCredentialsSource provides a credentials source that will be used to
@ -56,6 +62,15 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
d.credsSrc = src
}
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
// credentials available for the host, and a nil HostCredentials if it does not.
func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) {
if d.credsSrc == nil {
return nil, nil
}
return d.credsSrc.ForHost(host)
}
// ForceHostServices provides a pre-defined set of services for a given
// host, which prevents the receiver from attempting network-based discovery
// for the given host. Instead, the given services map will be returned
@ -145,15 +160,10 @@ func (d *Disco) discover(host svchost.Hostname) Host {
URL: discoURL,
}
if d.credsSrc != nil {
creds, err := d.credsSrc.ForHost(host)
if err == nil {
if creds != nil {
creds.PrepareRequest(req) // alters req to include credentials
}
} else {
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
}
if creds, err := d.CredentialsForHost(host); err != nil {
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
} else if creds != nil {
creds.PrepareRequest(req) // alters req to include credentials
}
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)

View File

@ -45,7 +45,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
gotURL := discovered.ServiceURL("thingy.v1")
if gotURL == nil {
@ -80,7 +80,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
gotURL := discovered.ServiceURL("wotsit.v2")
if gotURL == nil {
@ -107,7 +107,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
host: map[string]interface{}{
"token": "abc123",
@ -124,7 +124,7 @@ func TestDiscover(t *testing.T) {
"wotsit.v2": "/foo",
}
d := NewDisco()
d := New()
d.ForceHostServices(svchost.Hostname("example.com"), forced)
givenHost := "example.com"
@ -167,7 +167,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
// result should be empty, which we can verify only by reaching into
@ -190,7 +190,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
// result should be empty, which we can verify only by reaching into
@ -217,7 +217,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
if discovered.services == nil {
@ -236,7 +236,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
// result should be empty, which we can verify only by reaching into
@ -267,7 +267,7 @@ func TestDiscover(t *testing.T) {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d := New()
discovered := d.Discover(host)
gotURL := discovered.ServiceURL("thingy.v1")