diff --git a/backend/remote/backend.go b/backend/remote/backend.go index 17bb4c912..97b2a148e 100644 --- a/backend/remote/backend.go +++ b/backend/remote/backend.go @@ -302,9 +302,9 @@ func (b *Remote) discover(hostname string) (*url.URL, error) { if err != nil { return nil, err } - service := b.services.DiscoverServiceURL(host, serviceID) - if service == nil { - return nil, fmt.Errorf("host %s does not provide a remote backend API", host) + service, err := b.services.DiscoverServiceURL(host, serviceID) + if err != nil { + return nil, err } return service, nil } diff --git a/backend/remote/backend_test.go b/backend/remote/backend_test.go index db476a231..8aa888cc0 100644 --- a/backend/remote/backend_test.go +++ b/backend/remote/backend_test.go @@ -56,7 +56,7 @@ func TestRemote_config(t *testing.T) { "prefix": cty.NullVal(cty.String), }), }), - confErr: "Host nonexisting.local does not provide a remote backend API", + confErr: "Failed to request discovery document", }, "with_a_name": { config: cty.ObjectVal(map[string]cty.Value{ @@ -112,8 +112,8 @@ func TestRemote_config(t *testing.T) { // Validate valDiags := b.ValidateConfig(tc.config) - if (valDiags.Err() == nil && tc.valErr != "") || - (valDiags.Err() != nil && !strings.Contains(valDiags.Err().Error(), tc.valErr)) { + if (valDiags.Err() != nil || tc.valErr != "") && + (valDiags.Err() == nil || !strings.Contains(valDiags.Err().Error(), tc.valErr)) { t.Fatalf("%s: unexpected validation result: %v", name, valDiags.Err()) } diff --git a/registry/client.go b/registry/client.go index cdd33dc9e..0b90790d4 100644 --- a/registry/client.go +++ b/registry/client.go @@ -59,15 +59,15 @@ func NewClient(services *disco.Disco, client *http.Client) *Client { } // Discover queries the host, and returns the url for the registry. -func (c *Client) Discover(host svchost.Hostname, serviceID string) *url.URL { - service := c.services.DiscoverServiceURL(host, serviceID) - if service == nil { - return nil +func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) { + service, err := c.services.DiscoverServiceURL(host, serviceID) + if err != nil { + return nil, err } if !strings.HasSuffix(service.Path, "/") { service.Path += "/" } - return service + return service, nil } // ModuleVersions queries the registry for a module, and returns the available versions. @@ -77,9 +77,9 @@ func (c *Client) ModuleVersions(module *regsrc.Module) (*response.ModuleVersions return nil, err } - service := c.Discover(host, modulesServiceID) - if service == nil { - return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} + service, err := c.Discover(host, modulesServiceID) + if err != nil { + return nil, err } p, err := url.Parse(path.Join(module.Module(), "versions")) @@ -150,9 +150,9 @@ func (c *Client) ModuleLocation(module *regsrc.Module, version string) (string, return "", err } - service := c.Discover(host, modulesServiceID) - if service == nil { - return "", &errServiceNotProvided{host: host.ForDisplay(), service: "modules"} + service, err := c.Discover(host, modulesServiceID) + if err != nil { + return "", err } var p *url.URL @@ -234,9 +234,9 @@ func (c *Client) TerraformProviderVersions(provider *regsrc.TerraformProvider) ( return nil, err } - service := c.Discover(host, providersServiceID) - if service == nil { - return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} + service, err := c.Discover(host, providersServiceID) + if err != nil { + return nil, err } p, err := url.Parse(path.Join(provider.TerraformProvider(), "versions")) @@ -288,9 +288,9 @@ func (c *Client) TerraformProviderLocation(provider *regsrc.TerraformProvider, v return nil, err } - service := c.Discover(host, providersServiceID) - if service == nil { - return nil, &errServiceNotProvided{host: host.ForDisplay(), service: "providers"} + service, err := c.Discover(host, providersServiceID) + if err != nil { + return nil, err } p, err := url.Parse(path.Join( diff --git a/registry/errors.go b/registry/errors.go index 6d6dc95d4..cdde48221 100644 --- a/registry/errors.go +++ b/registry/errors.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/hashicorp/terraform/registry/regsrc" + "github.com/hashicorp/terraform/svchost/disco" ) type errModuleNotFound struct { @@ -42,15 +43,6 @@ func IsProviderNotFound(err error) bool { // error. This allows callers to recognize this particular error condition // as distinct from operational errors such as poor network connectivity. func IsServiceNotProvided(err error) bool { - _, ok := err.(*errServiceNotProvided) + _, ok := err.(*disco.ErrServiceNotProvided) return ok } - -type errServiceNotProvided struct { - host string - service string -} - -func (e *errServiceNotProvided) Error() string { - return fmt.Sprintf("host %s does not provide %s", e.host, e.service) -} diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go index 7fc49da9c..42a2dc4cd 100644 --- a/svchost/disco/disco.go +++ b/svchost/disco/disco.go @@ -8,6 +8,7 @@ package disco import ( "encoding/json" "errors" + "fmt" "io" "io/ioutil" "log" @@ -22,19 +23,27 @@ import ( ) const ( - discoPath = "/.well-known/terraform.json" - maxRedirects = 3 // arbitrary-but-small number to prevent runaway redirect loops - discoTimeout = 11 * time.Second // arbitrary-but-small time limit to prevent UI "hangs" during discovery - maxDiscoDocBytes = 1 * 1024 * 1024 // 1MB - to prevent abusive services from using loads of our memory + // Fixed path to the discovery manifest. + discoPath = "/.well-known/terraform.json" + + // Arbitrary-but-small number to prevent runaway redirect loops. + maxRedirects = 3 + + // Arbitrary-but-small time limit to prevent UI "hangs" during discovery. + discoTimeout = 11 * time.Second + + // 1MB - to prevent abusive services from using loads of our memory. + maxDiscoDocBytes = 1 * 1024 * 1024 ) -var httpTransport = cleanhttp.DefaultPooledTransport() // overridden during tests, to skip TLS verification +// httpTransport is overridden during tests, to skip TLS verification. +var httpTransport = cleanhttp.DefaultPooledTransport() // Disco is the main type in this package, which allows discovery on given // hostnames and caches the results by hostname to avoid repeated requests // for the same information. type Disco struct { - hostCache map[svchost.Hostname]Host + hostCache map[svchost.Hostname]*Host credsSrc auth.CredentialsSource // Transport is a custom http.RoundTripper to use. @@ -50,7 +59,10 @@ func New() *Disco { // NewWithCredentialsSource returns a new discovery object initialized with // the given credentials source. func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco { - return &Disco{credsSrc: credsSrc} + return &Disco{ + hostCache: make(map[svchost.Hostname]*Host), + credsSrc: credsSrc, + } } // SetCredentialsSource provides a credentials source that will be used to @@ -64,11 +76,11 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) { // CredentialsForHost returns a non-nil HostCredentials if the embedded source has // credentials available for the host, and a nil HostCredentials if it does not. -func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) { +func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) { if d.credsSrc == nil { return nil, nil } - return d.credsSrc.ForHost(host) + return d.credsSrc.ForHost(hostname) } // ForceHostServices provides a pre-defined set of services for a given @@ -81,19 +93,17 @@ func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, // discovery, yielding the same results as if the given map were published // at the host's default discovery URL, though using absolute URLs is strongly // recommended to make the configured behavior more explicit. -func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]interface{}) { - if d.hostCache == nil { - d.hostCache = map[svchost.Hostname]Host{} - } +func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) { if services == nil { services = map[string]interface{}{} } - d.hostCache[host] = Host{ + d.hostCache[hostname] = &Host{ discoURL: &url.URL{ Scheme: "https", - Host: string(host), + Host: string(hostname), Path: discoPath, }, + hostname: hostname.ForDisplay(), services: services, } } @@ -104,36 +114,40 @@ func (d *Disco) ForceHostServices(host svchost.Hostname, services map[string]int // // If a given hostname supports no Terraform services at all, a non-nil but // empty Host object is returned. When giving feedback to the end user about -// such situations, we say e.g. "the host doesn't provide a module -// registry", regardless of whether that is due to that service specifically -// being absent or due to the host not providing Terraform services at all, -// since we don't wish to expose the detail of whole-host discovery to an -// end-user. -func (d *Disco) Discover(host svchost.Hostname) Host { - if d.hostCache == nil { - d.hostCache = map[svchost.Hostname]Host{} - } - if cache, cached := d.hostCache[host]; cached { - return cache +// such situations, we say "host does not provide a service", +// regardless of whether that is due to that service specifically being absent +// or due to the host not providing Terraform services at all, since we don't +// wish to expose the detail of whole-host discovery to an end-user. +func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) { + if host, cached := d.hostCache[hostname]; cached { + return host, nil } - ret := d.discover(host) - d.hostCache[host] = ret - return ret + host, err := d.discover(hostname) + if err != nil { + return nil, err + } + d.hostCache[hostname] = host + + return host, nil } // DiscoverServiceURL is a convenience wrapper for discovery on a given // hostname and then looking up a particular service in the result. -func (d *Disco) DiscoverServiceURL(host svchost.Hostname, serviceID string) *url.URL { - return d.Discover(host).ServiceURL(serviceID) +func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) { + host, err := d.Discover(hostname) + if err != nil { + return nil, err + } + return host.ServiceURL(serviceID) } // discover implements the actual discovery process, with its result cached // by the public-facing Discover method. -func (d *Disco) discover(host svchost.Hostname) Host { +func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) { discoURL := &url.URL{ Scheme: "https", - Host: host.String(), + Host: hostname.String(), Path: discoPath, } @@ -149,7 +163,7 @@ func (d *Disco) discover(host svchost.Hostname) Host { CheckRedirect: func(req *http.Request, via []*http.Request) error { log.Printf("[DEBUG] Service discovery redirected to %s", req.URL) if len(via) > maxRedirects { - return errors.New("too many redirects") // (this error message will never actually be seen) + return errors.New("too many redirects") // this error will never actually be seen } return nil }, @@ -160,82 +174,84 @@ func (d *Disco) discover(host svchost.Hostname) Host { URL: discoURL, } - if creds, err := d.CredentialsForHost(host); err != nil { - log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err) - } else if creds != nil { - creds.PrepareRequest(req) // alters req to include credentials + creds, err := d.CredentialsForHost(hostname) + if err != nil { + log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err) + } + if creds != nil { + // Update the request to include credentials. + creds.PrepareRequest(req) } - log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL) - - ret := Host{ - discoURL: discoURL, - } + log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL) resp, err := client.Do(req) if err != nil { - log.Printf("[WARN] Failed to request discovery document: %s", err) - return ret // empty + return nil, fmt.Errorf("Failed to request discovery document: %v", err) } defer resp.Body.Close() - if resp.StatusCode != 200 { - log.Printf("[WARN] Failed to request discovery document: %s", resp.Status) - return ret // empty + host := &Host{ + // Use the discovery URL from resp.Request in + // case the client followed any redirects. + discoURL: resp.Request.URL, + hostname: hostname.ForDisplay(), } - // If the client followed any redirects, we will have a new URL to use - // as our base for relative resolution. - ret.discoURL = resp.Request.URL + // Return the host without any services. + if resp.StatusCode == 404 { + return host, nil + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status) + } contentType := resp.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { - log.Printf("[WARN] Discovery URL has malformed Content-Type %q", contentType) - return ret // empty + return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType) } if mediaType != "application/json" { - log.Printf("[DEBUG] Discovery URL returned Content-Type %q, rather than application/json", mediaType) - return ret // empty + return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType) } - // (this doesn't catch chunked encoding, because ContentLength is -1 in that case...) + // This doesn't catch chunked encoding, because ContentLength is -1 in that case. if resp.ContentLength > maxDiscoDocBytes { // Size limit here is not a contractual requirement and so we may // adjust it over time if we find a different limit is warranted. - log.Printf("[WARN] Discovery doc response is too large (got %d bytes; limit %d)", resp.ContentLength, maxDiscoDocBytes) - return ret // empty + return nil, fmt.Errorf( + "Discovery doc response is too large (got %d bytes; limit %d)", + resp.ContentLength, maxDiscoDocBytes, + ) } - // If the response is using chunked encoding then we can't predict - // its size, but we'll at least prevent reading the entire thing into - // memory. + // If the response is using chunked encoding then we can't predict its + // size, but we'll at least prevent reading the entire thing into memory. lr := io.LimitReader(resp.Body, maxDiscoDocBytes) servicesBytes, err := ioutil.ReadAll(lr) if err != nil { - log.Printf("[WARN] Error reading discovery document body: %s", err) - return ret // empty + return nil, fmt.Errorf("Error reading discovery document body: %v", err) } var services map[string]interface{} err = json.Unmarshal(servicesBytes, &services) if err != nil { - log.Printf("[WARN] Failed to decode discovery document as a JSON object: %s", err) - return ret // empty + return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err) } + host.services = services - ret.services = services - return ret + return host, nil } // Forget invalidates any cached record of the given hostname. If the host // has no cache entry then this is a no-op. -func (d *Disco) Forget(host svchost.Hostname) { - delete(d.hostCache, host) +func (d *Disco) Forget(hostname svchost.Hostname) { + delete(d.hostCache, hostname) } // ForgetAll is like Forget, but for all of the hostnames that have cache entries. func (d *Disco) ForgetAll() { - d.hostCache = nil + d.hostCache = make(map[svchost.Hostname]*Host) } diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go index c8bc16c45..95204e6f7 100644 --- a/svchost/disco/disco_test.go +++ b/svchost/disco/disco_test.go @@ -46,8 +46,15 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) - gotURL := discovered.ServiceURL("thingy.v1") + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } + + gotURL, err := discovered.ServiceURL("thingy.v1") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for thingy.v1") } @@ -81,8 +88,15 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) - gotURL := discovered.ServiceURL("wotsit.v2") + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } + + gotURL, err := discovered.ServiceURL("wotsit.v2") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for wotsit.v2") } @@ -133,9 +147,15 @@ func TestDiscover(t *testing.T) { t.Fatalf("test server hostname is invalid: %s", err) } - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } { - gotURL := discovered.ServiceURL("thingy.v1") + gotURL, err := discovered.ServiceURL("thingy.v1") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for thingy.v1") } @@ -144,7 +164,10 @@ func TestDiscover(t *testing.T) { } } { - gotURL := discovered.ServiceURL("wotsit.v2") + gotURL, err := discovered.ServiceURL("wotsit.v2") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for wotsit.v2") } @@ -168,12 +191,14 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err == nil { + t.Fatalf("expected a discovery error") + } - // result should be empty, which we can verify only by reaching into - // its internals. - if discovered.services != nil { - t.Errorf("response not empty; should be") + // Returned discovered should be nil. + if discovered != nil { + t.Errorf("discovered not nil; should be") } }) t.Run("malformed JSON", func(t *testing.T) { @@ -191,12 +216,14 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err == nil { + t.Fatalf("expected a discovery error") + } - // result should be empty, which we can verify only by reaching into - // its internals. - if discovered.services != nil { - t.Errorf("response not empty; should be") + // Returned discovered should be nil. + if discovered != nil { + t.Errorf("discovered not nil; should be") } }) t.Run("JSON with redundant charset", func(t *testing.T) { @@ -218,7 +245,10 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } if discovered.services == nil { t.Errorf("response is empty; shouldn't be") @@ -237,12 +267,14 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } - // result should be empty, which we can verify only by reaching into - // its internals. + // Returned discovered.services should be nil (empty). if discovered.services != nil { - t.Errorf("response not empty; should be") + t.Errorf("discovered.services not nil (empty); should be") } }) t.Run("redirect", func(t *testing.T) { @@ -268,9 +300,15 @@ func TestDiscover(t *testing.T) { } d := New() - discovered := d.Discover(host) + discovered, err := d.Discover(host) + if err != nil { + t.Fatalf("unexpected discovery error: %s", err) + } - gotURL := discovered.ServiceURL("thingy.v1") + gotURL, err := discovered.ServiceURL("thingy.v1") + if err != nil { + t.Fatalf("unexpected service URL error: %s", err) + } if gotURL == nil { t.Fatalf("found no URL for thingy.v1") } diff --git a/svchost/disco/host.go b/svchost/disco/host.go index faf58220a..55cc10813 100644 --- a/svchost/disco/host.go +++ b/svchost/disco/host.go @@ -1,51 +1,95 @@ package disco import ( + "fmt" "net/url" + "strings" ) +// Host represents a service discovered host. type Host struct { discoURL *url.URL + hostname string services map[string]interface{} } +// ErrServiceNotProvided is returned when the service is not provided. +type ErrServiceNotProvided struct { + hostname string + service string +} + +// Error returns a customized error message. +func (e *ErrServiceNotProvided) Error() string { + return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service) +} + +// ErrVersionNotSupported is returned when the version is not supported. +type ErrVersionNotSupported struct { + hostname string + service string + version string +} + +// Error returns a customized error message. +func (e *ErrVersionNotSupported) Error() string { + return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version) +} + // ServiceURL returns the URL associated with the given service identifier, // which should be of the form "servicename.vN". // -// A non-nil result is always an absolute URL with a scheme of either https -// or http. -// -// If the requested service is not supported by the host, this method returns -// a nil URL. -// -// If the discovery document entry for the given service is invalid (not a URL), -// it is treated as absent, also returning a nil URL. -func (h Host) ServiceURL(id string) *url.URL { - if h.services == nil { - return nil // no services supported for an empty Host +// A non-nil result is always an absolute URL with a scheme of either HTTPS +// or HTTP. +func (h *Host) ServiceURL(id string) (*url.URL, error) { + parts := strings.SplitN(id, ".", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id) + } + service, version := parts[0], parts[1] + + // No services supported for an empty Host. + if h == nil || h.services == nil { + return nil, &ErrServiceNotProvided{hostname: "", service: service} } urlStr, ok := h.services[id].(string) if !ok { - return nil + // See if we have a matching service as that would indicate + // the service is supported, but not the requested version. + for serviceID := range h.services { + if strings.HasPrefix(serviceID, service) { + return nil, &ErrVersionNotSupported{ + hostname: h.hostname, + service: service, + version: version, + } + } + } + + // No discovered services match the requested service ID. + return nil, &ErrServiceNotProvided{hostname: h.hostname, service: service} } - ret, err := url.Parse(urlStr) + u, err := url.Parse(urlStr) if err != nil { - return nil + return nil, fmt.Errorf("Failed to parse service URL: %v", err) } - if !ret.IsAbs() { - ret = h.discoURL.ResolveReference(ret) // make absolute using our discovery doc URL - } - if ret.Scheme != "https" && ret.Scheme != "http" { - return nil - } - if ret.User != nil { - // embedded username/password information is not permitted; credentials - // are handled out of band. - return nil - } - ret.Fragment = "" // fragment part is irrelevant, since we're not a browser - return h.discoURL.ResolveReference(ret) + // Make relative URLs absolute using our discovery URL. + if !u.IsAbs() { + u = h.discoURL.ResolveReference(u) + } + + if u.Scheme != "https" && u.Scheme != "http" { + return nil, fmt.Errorf("Service URL is using an unsupported scheme: %s", u.Scheme) + } + if u.User != nil { + return nil, fmt.Errorf("Embedded username/password information is not permitted") + } + + // Fragment part is irrelevant, since we're not a browser. + u.Fragment = "" + + return h.discoURL.ResolveReference(u), nil } diff --git a/svchost/disco/host_test.go b/svchost/disco/host_test.go index 8a9fe4c76..c6a1d8eaf 100644 --- a/svchost/disco/host_test.go +++ b/svchost/disco/host_test.go @@ -2,6 +2,7 @@ package disco import ( "net/url" + "strings" "testing" ) @@ -9,6 +10,7 @@ func TestHostServiceURL(t *testing.T) { baseURL, _ := url.Parse("https://example.com/disco/foo.json") host := Host{ discoURL: baseURL, + hostname: "test-server", services: map[string]interface{}{ "absolute.v1": "http://example.net/foo/bar", "absolutewithport.v1": "http://example.net:8080/foo/bar", @@ -24,22 +26,28 @@ func TestHostServiceURL(t *testing.T) { tests := []struct { ID string - Want string + want string + err string }{ - {"absolute.v1", "http://example.net/foo/bar"}, - {"absolutewithport.v1", "http://example.net:8080/foo/bar"}, - {"relative.v1", "https://example.com/disco/stu/"}, - {"rootrelative.v1", "https://example.com/baz"}, - {"protorelative.v1", "https://example.net/"}, - {"withfragment.v1", "http://example.org/"}, - {"querystring.v1", "https://example.net/baz?foo=bar"}, // most callers will disregard query string - {"nothttp.v1", ""}, - {"invalid.v1", ""}, + {"absolute.v1", "http://example.net/foo/bar", ""}, + {"absolutewithport.v1", "http://example.net:8080/foo/bar", ""}, + {"relative.v1", "https://example.com/disco/stu/", ""}, + {"rootrelative.v1", "https://example.com/baz", ""}, + {"protorelative.v1", "https://example.net/", ""}, + {"withfragment.v1", "http://example.org/", ""}, + {"querystring.v1", "https://example.net/baz?foo=bar", ""}, + {"nothttp.v1", "", "unsupported scheme"}, + {"invalid.v1", "", "Failed to parse service URL"}, } for _, test := range tests { t.Run(test.ID, func(t *testing.T) { - url := host.ServiceURL(test.ID) + url, err := host.ServiceURL(test.ID) + if (err != nil || test.err != "") && + (err == nil || !strings.Contains(err.Error(), test.err)) { + t.Fatalf("unexpected service URL error: %s", err) + } + var got string if url != nil { got = url.String() @@ -47,8 +55,8 @@ func TestHostServiceURL(t *testing.T) { got = "" } - if got != test.Want { - t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.Want) + if got != test.want { + t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want) } }) }