From 83b098344b8ba521537a2f729d2febb36b0d0f04 Mon Sep 17 00:00:00 2001 From: Martin Atkins Date: Wed, 18 Oct 2017 08:45:52 -0700 Subject: [PATCH] svchost/disco: add credentials, if available, to disco requests Although service discovery metadata is usually not sensitive, a service host may wish to produce different results depending on the requesting user, such as if users are migrating between two different implementations that are both running concurrently for some period. --- svchost/disco/disco.go | 23 ++++++++++++++++++++++- svchost/disco/disco_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go index f3622bbfc..3bd9307e0 100644 --- a/svchost/disco/disco.go +++ b/svchost/disco/disco.go @@ -19,6 +19,7 @@ import ( cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/terraform/svchost" + "github.com/hashicorp/terraform/svchost/auth" "github.com/hashicorp/terraform/terraform" ) @@ -37,12 +38,22 @@ var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during test // for the same information. type Disco struct { hostCache map[svchost.Hostname]Host + credsSrc auth.CredentialsSource } func NewDisco() *Disco { return &Disco{} } +// SetCredentialsSource provides a credentials source that will be used to +// add credentials to outgoing discovery requests, where available. +// +// If this method is never called, no outgoing discovery requests will have +// credentials. +func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) { + d.credsSrc = src +} + // Discover runs the discovery protocol against the given hostname (which must // already have been validated and prepared with svchost.ForComparison) and // returns an object describing the services available at that host. @@ -96,7 +107,6 @@ func (d *Disco) discover(host svchost.Hostname) Host { var header = http.Header{} header.Set("User-Agent", userAgent) - // TODO: look up credentials and add them to the header if we have them req := &http.Request{ Method: "GET", @@ -104,6 +114,17 @@ func (d *Disco) discover(host svchost.Hostname) Host { Header: header, } + if d.credsSrc != nil { + creds, err := d.credsSrc.ForHost(host) + if err == nil { + if creds != nil { + creds.PrepareRequest(req) // alters req to include credentials + } + } else { + log.Printf("[WARNING] Failed to get credentials for %s: %s (ignoring)", host, err) + } + } + log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) ret := Host{ diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go index d514d1f3f..d0492bb12 100644 --- a/svchost/disco/disco_test.go +++ b/svchost/disco/disco_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/hashicorp/terraform/svchost" + "github.com/hashicorp/terraform/svchost/auth" ) func TestMain(m *testing.M) { @@ -89,6 +90,34 @@ func TestDiscover(t *testing.T) { t.Fatalf("wrong result %q; want %q", got, want) } }) + t.Run("with credentials", func(t *testing.T) { + var authHeaderText string + portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { + resp := []byte(`{}`) + authHeaderText = r.Header.Get("Authorization") + w.Header().Add("Content-Type", "application/json") + w.Header().Add("Content-Length", strconv.Itoa(len(resp))) + w.Write(resp) + }) + defer close() + + givenHost := "localhost" + portStr + host, err := svchost.ForComparison(givenHost) + if err != nil { + t.Fatalf("test server hostname is invalid: %s", err) + } + + d := NewDisco() + d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{ + host: map[string]interface{}{ + "token": "abc123", + }, + })) + d.Discover(host) + if got, want := authHeaderText, "Bearer abc123"; got != want { + t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want) + } + }) t.Run("not JSON", func(t *testing.T) { portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) { resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)