diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go index f3622bbfc..3bd9307e0 100644 --- a/svchost/disco/disco.go +++ b/svchost/disco/disco.go @@ -19,6 +19,7 @@ import ( cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/terraform/svchost" + "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/terraform" ) @@ -37,12 +38,22 @@ var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during test // for the same information. type Disco struct { hostCache map[svchost.Hostname]Host + credsSrc auth.CredentialsSource } func NewDisco() *Disco { return &Disco{} } +// SetCredentialsSource provides a credentials source that will be used to +// add credentials to outgoing discovery requests, where available. +// +// If this method is never called, no outgoing discovery requests will have +// credentials. +func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) { + d.credsSrc = src +} + // Discover runs the discovery protocol against the given hostname (which must // already have been validated and prepared with svchost.ForComparison) and // returns an object describing the services available at that host. @@ -96,7 +107,6 @@ func (d *Disco) discover(host svchost.Hostname) Host { var header = http.Header{} header.Set("User-Agent", userAgent) - // TODO: look up credentials and add them to the header if we have them req := &http.Request{ Method: "GET", @@ -104,6 +114,17 @@ func (d *Disco) discover(host svchost.Hostname) Host { Header: header, } + if d.credsSrc != nil { + creds, err := d.credsSrc.ForHost(host) + if err == nil { + if creds != nil { + creds.PrepareRequest(req) // alters req to include credentials + } + } else { + log.Printf("[WARNING] Failed to get credentials for %s: %s (ignoring)", host, err) + } + } + log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) ret := Host{ diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go index d514d1f3f..d0492bb12 100644 --- a/svchost/disco/disco_test.go +++ b/svchost/disco/disco_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/hashicorp/terraform/svchost" + "github.com/hashicorp/terraform/svchost/auth" ) func TestMain(m *testing.M) { @@ -89,6 +90,34 @@ func TestDiscover(t *testing.T) { t.Fatalf("wrong result %q; want %q", got, want) } }) + t.Run("with credentials", func(t *testing.T) { + var authHeaderText string + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(`{}`) + authHeaderText = r.Header.Get("Authorization") + w.Header().Add("Content-Type", "application/json") + w.Header().Add("Content-Length", strconv.Itoa(len(resp))) + w.Write(resp) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{ + host: map[string]interface{}{ + "token": "abc123", + }, + })) + d.Discover(host) + if got, want := authHeaderText, "Bearer abc123"; got != want { + t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want) + } + }) t.Run("not JSON", func(t *testing.T) { portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)