svchost/disco: add credentials, if available, to disco requests

Although service discovery metadata is usually not sensitive, a service
host may wish to produce different results depending on the requesting
user, such as if users are migrating between two different implementations
that are both running concurrently for some period.
This commit is contained in:
Martin Atkins 2017-10-18 08:45:52 -07:00
parent fcff4cbc95
commit 83b098344b
2 changed files with 51 additions and 1 deletions

View File

@ -19,6 +19,7 @@ import (
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/terraform/svchost"
"github.com/hashicorp/terraform/svchost/auth"
"github.com/hashicorp/terraform/terraform"
)
@ -37,12 +38,22 @@ var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during test
// for the same information.
type Disco struct {
hostCache map[svchost.Hostname]Host
credsSrc auth.CredentialsSource
}
func NewDisco() *Disco {
return &Disco{}
}
// SetCredentialsSource provides a credentials source that will be used to
// add credentials to outgoing discovery requests, where available.
//
// If this method is never called, no outgoing discovery requests will have
// credentials.
func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
d.credsSrc = src
}
// Discover runs the discovery protocol against the given hostname (which must
// already have been validated and prepared with svchost.ForComparison) and
// returns an object describing the services available at that host.
@ -96,7 +107,6 @@ func (d *Disco) discover(host svchost.Hostname) Host {
var header = http.Header{}
header.Set("User-Agent", userAgent)
// TODO: look up credentials and add them to the header if we have them
req := &http.Request{
Method: "GET",
@ -104,6 +114,17 @@ func (d *Disco) discover(host svchost.Hostname) Host {
Header: header,
}
if d.credsSrc != nil {
creds, err := d.credsSrc.ForHost(host)
if err == nil {
if creds != nil {
creds.PrepareRequest(req) // alters req to include credentials
}
} else {
log.Printf("[WARNING] Failed to get credentials for %s: %s (ignoring)", host, err)
}
}
log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
ret := Host{

View File

@ -10,6 +10,7 @@ import (
"testing"
"github.com/hashicorp/terraform/svchost"
"github.com/hashicorp/terraform/svchost/auth"
)
func TestMain(m *testing.M) {
@ -89,6 +90,34 @@ func TestDiscover(t *testing.T) {
t.Fatalf("wrong result %q; want %q", got, want)
}
})
t.Run("with credentials", func(t *testing.T) {
var authHeaderText string
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{}`)
authHeaderText = r.Header.Get("Authorization")
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := NewDisco()
d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
host: map[string]interface{}{
"token": "abc123",
},
}))
d.Discover(host)
if got, want := authHeaderText, "Bearer abc123"; got != want {
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
}
})
t.Run("not JSON", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)