Remove svchost package

This commit is contained in:
Radek Simko 2019-10-11 10:48:42 +01:00
parent 32f9722d9d
commit cd21a3859d
No known key found for this signature in database
GPG Key ID: 1F1C84FE689A88D7
20 changed files with 0 additions and 2923 deletions

View File

@ -1,61 +0,0 @@
package auth
import (
"github.com/hashicorp/terraform/svchost"
)
// CachingCredentialsSource creates a new credentials source that wraps another
// and caches its results in memory, on a per-hostname basis.
//
// No means is provided for expiration of cached credentials, so a caching
// credentials source should have a limited lifetime (one Terraform operation,
// for example) to ensure that time-limited credentials don't expire before
// their cache entries do.
func CachingCredentialsSource(source CredentialsSource) CredentialsSource {
return &cachingCredentialsSource{
source: source,
cache: map[svchost.Hostname]HostCredentials{},
}
}
type cachingCredentialsSource struct {
source CredentialsSource
cache map[svchost.Hostname]HostCredentials
}
// ForHost passes the given hostname on to the wrapped credentials source and
// caches the result to return for future requests with the same hostname.
//
// Both credentials and non-credentials (nil) responses are cached.
//
// No cache entry is created if the wrapped source returns an error, to allow
// the caller to retry the failing operation.
func (s *cachingCredentialsSource) ForHost(host svchost.Hostname) (HostCredentials, error) {
if cache, cached := s.cache[host]; cached {
return cache, nil
}
result, err := s.source.ForHost(host)
if err != nil {
return result, err
}
s.cache[host] = result
return result, nil
}
func (s *cachingCredentialsSource) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
// We'll delete the cache entry even if the store fails, since that just
// means that the next read will go to the real store and get a chance to
// see which object (old or new) is actually present.
delete(s.cache, host)
return s.source.StoreForHost(host, credentials)
}
func (s *cachingCredentialsSource) ForgetForHost(host svchost.Hostname) error {
// We'll delete the cache entry even if the store fails, since that just
// means that the next read will go to the real store and get a chance to
// see if the object is still present.
delete(s.cache, host)
return s.source.ForgetForHost(host)
}

View File

@ -1,118 +0,0 @@
// Package auth contains types and functions to manage authentication
// credentials for service hosts.
package auth
import (
"fmt"
"net/http"
"github.com/zclconf/go-cty/cty"
"github.com/hashicorp/terraform/svchost"
)
// Credentials is a list of CredentialsSource objects that can be tried in
// turn until one returns credentials for a host, or one returns an error.
//
// A Credentials is itself a CredentialsSource, wrapping its members.
// In principle one CredentialsSource can be nested inside another, though
// there is no good reason to do so.
//
// The write operations on a Credentials are tried only on the first object,
// under the assumption that it is the primary store.
type Credentials []CredentialsSource
// NoCredentials is an empty CredentialsSource that always returns nil
// when asked for credentials.
var NoCredentials CredentialsSource = Credentials{}
// A CredentialsSource is an object that may be able to provide credentials
// for a given host.
//
// Credentials lookups are not guaranteed to be concurrency-safe. Callers
// using these facilities in concurrent code must use external concurrency
// primitives to prevent race conditions.
type CredentialsSource interface {
// ForHost returns a non-nil HostCredentials if the source has credentials
// available for the host, and a nil HostCredentials if it does not.
//
// If an error is returned, progress through a list of CredentialsSources
// is halted and the error is returned to the user.
ForHost(host svchost.Hostname) (HostCredentials, error)
// StoreForHost takes a HostCredentialsWritable and saves it as the
// credentials for the given host.
//
// If credentials are already stored for the given host, it will try to
// replace those credentials but may produce an error if such replacement
// is not possible.
StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error
// ForgetForHost discards any stored credentials for the given host. It
// does nothing and returns successfully if no credentials are saved
// for that host.
ForgetForHost(host svchost.Hostname) error
}
// HostCredentials represents a single set of credentials for a particular
// host.
type HostCredentials interface {
// PrepareRequest modifies the given request in-place to apply the
// receiving credentials. The usual behavior of this method is to
// add some sort of Authorization header to the request.
PrepareRequest(req *http.Request)
// Token returns the authentication token.
Token() string
}
// HostCredentialsWritable is an extension of HostCredentials for credentials
// objects that can be serialized as a JSON-compatible object value for
// storage.
type HostCredentialsWritable interface {
HostCredentials
// ToStore returns a cty.Value, always of an object type,
// representing data that can be serialized to represent this object
// in persistent storage.
//
// The resulting value may uses only cty values that can be accepted
// by the cty JSON encoder, though the caller may elect to instead store
// it in some other format that has a JSON-compatible type system.
ToStore() cty.Value
}
// ForHost iterates over the contained CredentialsSource objects and
// tries to obtain credentials for the given host from each one in turn.
//
// If any source returns either a non-nil HostCredentials or a non-nil error
// then this result is returned. Otherwise, the result is nil, nil.
func (c Credentials) ForHost(host svchost.Hostname) (HostCredentials, error) {
for _, source := range c {
creds, err := source.ForHost(host)
if creds != nil || err != nil {
return creds, err
}
}
return nil, nil
}
// StoreForHost passes the given arguments to the same operation on the
// first CredentialsSource in the receiver.
func (c Credentials) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
if len(c) == 0 {
return fmt.Errorf("no credentials store is available")
}
return c[0].StoreForHost(host, credentials)
}
// ForgetForHost passes the given arguments to the same operation on the
// first CredentialsSource in the receiver.
func (c Credentials) ForgetForHost(host svchost.Hostname) error {
if len(c) == 0 {
return fmt.Errorf("no credentials store is available")
}
return c[0].ForgetForHost(host)
}

View File

@ -1,48 +0,0 @@
package auth
import (
"github.com/zclconf/go-cty/cty"
)
// HostCredentialsFromMap converts a map of key-value pairs from a credentials
// definition provided by the user (e.g. in a config file, or via a credentials
// helper) into a HostCredentials object if possible, or returns nil if
// no credentials could be extracted from the map.
//
// This function ignores map keys it is unfamiliar with, to allow for future
// expansion of the credentials map format for new credential types.
func HostCredentialsFromMap(m map[string]interface{}) HostCredentials {
if m == nil {
return nil
}
if token, ok := m["token"].(string); ok {
return HostCredentialsToken(token)
}
return nil
}
// HostCredentialsFromObject converts a cty.Value of an object type into a
// HostCredentials object if possible, or returns nil if no credentials could
// be extracted from the map.
//
// This function ignores object attributes it is unfamiliar with, to allow for
// future expansion of the credentials object structure for new credential types.
//
// If the given value is not of an object type, this function will panic.
func HostCredentialsFromObject(obj cty.Value) HostCredentials {
if !obj.Type().HasAttribute("token") {
return nil
}
tokenV := obj.GetAttr("token")
if tokenV.IsNull() || !tokenV.IsKnown() {
return nil
}
if !cty.String.Equals(tokenV.Type()) {
// Weird, but maybe some future Terraform version accepts an object
// here for some reason, so we'll be resilient.
return nil
}
return HostCredentialsToken(tokenV.AsString())
}

View File

@ -1,149 +0,0 @@
package auth
import (
"bytes"
"encoding/json"
"fmt"
"os/exec"
"path/filepath"
ctyjson "github.com/zclconf/go-cty/cty/json"
"github.com/hashicorp/terraform/svchost"
)
type helperProgramCredentialsSource struct {
executable string
args []string
}
// HelperProgramCredentialsSource returns a CredentialsSource that runs the
// given program with the given arguments in order to obtain credentials.
//
// The given executable path must be an absolute path; it is the caller's
// responsibility to validate and process a relative path or other input
// provided by an end-user. If the given path is not absolute, this
// function will panic.
//
// When credentials are requested, the program will be run in a child process
// with the given arguments along with two additional arguments added to the
// end of the list: the literal string "get", followed by the requested
// hostname in ASCII compatibility form (punycode form).
func HelperProgramCredentialsSource(executable string, args ...string) CredentialsSource {
if !filepath.IsAbs(executable) {
panic("NewCredentialsSourceHelperProgram requires absolute path to executable")
}
fullArgs := make([]string, len(args)+1)
fullArgs[0] = executable
copy(fullArgs[1:], args)
return &helperProgramCredentialsSource{
executable: executable,
args: fullArgs,
}
}
func (s *helperProgramCredentialsSource) ForHost(host svchost.Hostname) (HostCredentials, error) {
args := make([]string, len(s.args), len(s.args)+2)
copy(args, s.args)
args = append(args, "get")
args = append(args, string(host))
outBuf := bytes.Buffer{}
errBuf := bytes.Buffer{}
cmd := exec.Cmd{
Path: s.executable,
Args: args,
Stdin: nil,
Stdout: &outBuf,
Stderr: &errBuf,
}
err := cmd.Run()
if _, isExitErr := err.(*exec.ExitError); isExitErr {
errText := errBuf.String()
if errText == "" {
// Shouldn't happen for a well-behaved helper program
return nil, fmt.Errorf("error in %s, but it produced no error message", s.executable)
}
return nil, fmt.Errorf("error in %s: %s", s.executable, errText)
} else if err != nil {
return nil, fmt.Errorf("failed to run %s: %s", s.executable, err)
}
var m map[string]interface{}
err = json.Unmarshal(outBuf.Bytes(), &m)
if err != nil {
return nil, fmt.Errorf("malformed output from %s: %s", s.executable, err)
}
return HostCredentialsFromMap(m), nil
}
func (s *helperProgramCredentialsSource) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
args := make([]string, len(s.args), len(s.args)+2)
copy(args, s.args)
args = append(args, "store")
args = append(args, string(host))
toStore := credentials.ToStore()
toStoreRaw, err := ctyjson.Marshal(toStore, toStore.Type())
if err != nil {
return fmt.Errorf("can't serialize credentials to store: %s", err)
}
inReader := bytes.NewReader(toStoreRaw)
errBuf := bytes.Buffer{}
cmd := exec.Cmd{
Path: s.executable,
Args: args,
Stdin: inReader,
Stderr: &errBuf,
Stdout: nil,
}
err = cmd.Run()
if _, isExitErr := err.(*exec.ExitError); isExitErr {
errText := errBuf.String()
if errText == "" {
// Shouldn't happen for a well-behaved helper program
return fmt.Errorf("error in %s, but it produced no error message", s.executable)
}
return fmt.Errorf("error in %s: %s", s.executable, errText)
} else if err != nil {
return fmt.Errorf("failed to run %s: %s", s.executable, err)
}
return nil
}
func (s *helperProgramCredentialsSource) ForgetForHost(host svchost.Hostname) error {
args := make([]string, len(s.args), len(s.args)+2)
copy(args, s.args)
args = append(args, "forget")
args = append(args, string(host))
errBuf := bytes.Buffer{}
cmd := exec.Cmd{
Path: s.executable,
Args: args,
Stdin: nil,
Stderr: &errBuf,
Stdout: nil,
}
err := cmd.Run()
if _, isExitErr := err.(*exec.ExitError); isExitErr {
errText := errBuf.String()
if errText == "" {
// Shouldn't happen for a well-behaved helper program
return fmt.Errorf("error in %s, but it produced no error message", s.executable)
}
return fmt.Errorf("error in %s: %s", s.executable, errText)
} else if err != nil {
return fmt.Errorf("failed to run %s: %s", s.executable, err)
}
return nil
}

View File

@ -1,83 +0,0 @@
package auth
import (
"os"
"path/filepath"
"testing"
"github.com/hashicorp/terraform/svchost"
)
func TestHelperProgramCredentialsSource(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
program := filepath.Join(wd, "testdata/test-helper")
t.Logf("testing with helper at %s", program)
src := HelperProgramCredentialsSource(program)
t.Run("happy path", func(t *testing.T) {
creds, err := src.ForHost(svchost.Hostname("example.com"))
if err != nil {
t.Fatal(err)
}
if tokCreds, isTok := creds.(HostCredentialsToken); isTok {
if got, want := string(tokCreds), "example-token"; got != want {
t.Errorf("wrong token %q; want %q", got, want)
}
} else {
t.Errorf("wrong type of credentials %T", creds)
}
})
t.Run("no credentials", func(t *testing.T) {
creds, err := src.ForHost(svchost.Hostname("nothing.example.com"))
if err != nil {
t.Fatal(err)
}
if creds != nil {
t.Errorf("got credentials; want nil")
}
})
t.Run("unsupported credentials type", func(t *testing.T) {
creds, err := src.ForHost(svchost.Hostname("other-cred-type.example.com"))
if err != nil {
t.Fatal(err)
}
if creds != nil {
t.Errorf("got credentials; want nil")
}
})
t.Run("lookup error", func(t *testing.T) {
_, err := src.ForHost(svchost.Hostname("fail.example.com"))
if err == nil {
t.Error("completed successfully; want error")
}
})
t.Run("store happy path", func(t *testing.T) {
err := src.StoreForHost(svchost.Hostname("example.com"), HostCredentialsToken("example-token"))
if err != nil {
t.Fatal(err)
}
})
t.Run("store error", func(t *testing.T) {
err := src.StoreForHost(svchost.Hostname("fail.example.com"), HostCredentialsToken("example-token"))
if err == nil {
t.Error("completed successfully; want error")
}
})
t.Run("forget happy path", func(t *testing.T) {
err := src.ForgetForHost(svchost.Hostname("example.com"))
if err != nil {
t.Fatal(err)
}
})
t.Run("forget error", func(t *testing.T) {
err := src.ForgetForHost(svchost.Hostname("fail.example.com"))
if err == nil {
t.Error("completed successfully; want error")
}
})
}

View File

@ -1,38 +0,0 @@
package auth
import (
"fmt"
"github.com/hashicorp/terraform/svchost"
)
// StaticCredentialsSource is a credentials source that retrieves credentials
// from the provided map. It returns nil if a requested hostname is not
// present in the map.
//
// The caller should not modify the given map after passing it to this function.
func StaticCredentialsSource(creds map[svchost.Hostname]map[string]interface{}) CredentialsSource {
return staticCredentialsSource(creds)
}
type staticCredentialsSource map[svchost.Hostname]map[string]interface{}
func (s staticCredentialsSource) ForHost(host svchost.Hostname) (HostCredentials, error) {
if s == nil {
return nil, nil
}
if m, exists := s[host]; exists {
return HostCredentialsFromMap(m), nil
}
return nil, nil
}
func (s staticCredentialsSource) StoreForHost(host svchost.Hostname, credentials HostCredentialsWritable) error {
return fmt.Errorf("can't store new credentials in a static credentials source")
}
func (s staticCredentialsSource) ForgetForHost(host svchost.Hostname) error {
return fmt.Errorf("can't discard credentials from a static credentials source")
}

View File

@ -1,38 +0,0 @@
package auth
import (
"testing"
"github.com/hashicorp/terraform/svchost"
)
func TestStaticCredentialsSource(t *testing.T) {
src := StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
svchost.Hostname("example.com"): map[string]interface{}{
"token": "abc123",
},
})
t.Run("exists", func(t *testing.T) {
creds, err := src.ForHost(svchost.Hostname("example.com"))
if err != nil {
t.Fatal(err)
}
if tokCreds, isToken := creds.(HostCredentialsToken); isToken {
if got, want := string(tokCreds), "abc123"; got != want {
t.Errorf("wrong token %q; want %q", got, want)
}
} else {
t.Errorf("creds is %#v; want HostCredentialsToken", creds)
}
})
t.Run("does not exist", func(t *testing.T) {
creds, err := src.ForHost(svchost.Hostname("example.net"))
if err != nil {
t.Fatal(err)
}
if creds != nil {
t.Errorf("creds is %#v; want nil", creds)
}
})
}

View File

@ -1 +0,0 @@
main

View File

@ -1,64 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
)
// This is a simple program that implements the "helper program" protocol
// for the svchost/auth package for unit testing purposes.
func main() {
args := os.Args
if len(args) < 3 {
die("not enough arguments\n")
}
host := args[2]
switch args[1] {
case "get":
switch host {
case "example.com":
fmt.Print(`{"token":"example-token"}`)
case "other-cred-type.example.com":
fmt.Print(`{"username":"alfred"}`) // unrecognized by main program
case "fail.example.com":
die("failing because you told me to fail\n")
default:
fmt.Print("{}") // no credentials available
}
case "store":
dataSrc, err := ioutil.ReadAll(os.Stdin)
if err != nil {
die("invalid input: %s", err)
}
var data map[string]interface{}
err = json.Unmarshal(dataSrc, &data)
switch host {
case "example.com":
if data["token"] != "example-token" {
die("incorrect token value to store")
}
default:
die("can't store credentials for %s", host)
}
case "forget":
switch host {
case "example.com":
// okay!
default:
die("can't forget credentials for %s", host)
}
default:
die("unknown subcommand %q\n", args[1])
}
}
func die(f string, args ...interface{}) {
fmt.Fprintf(os.Stderr, fmt.Sprintf(f, args...))
os.Exit(1)
}

View File

@ -1,7 +0,0 @@
#!/usr/bin/env bash
set -eu
cd "$( dirname "${BASH_SOURCE[0]}" )"
[ -x main ] || go build -o main .
exec ./main "$@"

View File

@ -1,43 +0,0 @@
package auth
import (
"net/http"
"github.com/zclconf/go-cty/cty"
)
// HostCredentialsToken is a HostCredentials implementation that represents a
// single "bearer token", to be sent to the server via an Authorization header
// with the auth type set to "Bearer".
//
// To save a token as the credentials for a host, convert the token string to
// this type and use the result as a HostCredentialsWritable implementation.
type HostCredentialsToken string
// Interface implementation assertions. Compilation will fail here if
// HostCredentialsToken does not fully implement these interfaces.
var _ HostCredentials = HostCredentialsToken("")
var _ HostCredentialsWritable = HostCredentialsToken("")
// PrepareRequest alters the given HTTP request by setting its Authorization
// header to the string "Bearer " followed by the encapsulated authentication
// token.
func (tc HostCredentialsToken) PrepareRequest(req *http.Request) {
if req.Header == nil {
req.Header = http.Header{}
}
req.Header.Set("Authorization", "Bearer "+string(tc))
}
// Token returns the authentication token.
func (tc HostCredentialsToken) Token() string {
return string(tc)
}
// ToStore returns a credentials object with a single attribute "token" whose
// value is the token string.
func (tc HostCredentialsToken) ToStore() cty.Value {
return cty.ObjectVal(map[string]cty.Value{
"token": cty.StringVal(string(tc)),
})
}

View File

@ -1,31 +0,0 @@
package auth
import (
"net/http"
"testing"
"github.com/zclconf/go-cty/cty"
)
func TestHostCredentialsToken(t *testing.T) {
creds := HostCredentialsToken("foo-bar")
{
req := &http.Request{}
creds.PrepareRequest(req)
authStr := req.Header.Get("authorization")
if got, want := authStr, "Bearer foo-bar"; got != want {
t.Errorf("wrong Authorization header value %q; want %q", got, want)
}
}
{
got := creds.ToStore()
want := cty.ObjectVal(map[string]cty.Value{
"token": cty.StringVal("foo-bar"),
})
if !want.RawEquals(got) {
t.Errorf("wrong storable object value\ngot: %#v\nwant: %#v", got, want)
}
}
}

View File

@ -1,271 +0,0 @@
// Package disco handles Terraform's remote service discovery protocol.
//
// This protocol allows mapping from a service hostname, as produced by the
// svchost package, to a set of services supported by that host and the
// endpoint information for each supported service.
package disco
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"mime"
"net/http"
"net/url"
"time"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/terraform/httpclient"
"github.com/hashicorp/terraform/svchost"
"github.com/hashicorp/terraform/svchost/auth"
)
const (
// Fixed path to the discovery manifest.
discoPath = "/.well-known/terraform.json"
// Arbitrary-but-small number to prevent runaway redirect loops.
maxRedirects = 3
// Arbitrary-but-small time limit to prevent UI "hangs" during discovery.
discoTimeout = 11 * time.Second
// 1MB - to prevent abusive services from using loads of our memory.
maxDiscoDocBytes = 1 * 1024 * 1024
)
// httpTransport is overridden during tests, to skip TLS verification.
var httpTransport = cleanhttp.DefaultPooledTransport()
// Disco is the main type in this package, which allows discovery on given
// hostnames and caches the results by hostname to avoid repeated requests
// for the same information.
type Disco struct {
hostCache map[svchost.Hostname]*Host
credsSrc auth.CredentialsSource
// Transport is a custom http.RoundTripper to use.
Transport http.RoundTripper
}
// New returns a new initialized discovery object.
func New() *Disco {
return NewWithCredentialsSource(nil)
}
// NewWithCredentialsSource returns a new discovery object initialized with
// the given credentials source.
func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
return &Disco{
hostCache: make(map[svchost.Hostname]*Host),
credsSrc: credsSrc,
Transport: httpTransport,
}
}
// SetCredentialsSource provides a credentials source that will be used to
// add credentials to outgoing discovery requests, where available.
//
// If this method is never called, no outgoing discovery requests will have
// credentials.
func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
d.credsSrc = src
}
// CredentialsSource returns the credentials source associated with the receiver,
// or an empty credentials source if none is associated.
func (d *Disco) CredentialsSource() auth.CredentialsSource {
if d.credsSrc == nil {
// We'll return an empty one just to save the caller from having to
// protect against the nil case, since this interface already allows
// for the possibility of there being no credentials at all.
return auth.StaticCredentialsSource(nil)
}
return d.credsSrc
}
// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
// credentials available for the host, and a nil HostCredentials if it does not.
func (d *Disco) CredentialsForHost(hostname svchost.Hostname) (auth.HostCredentials, error) {
if d.credsSrc == nil {
return nil, nil
}
return d.credsSrc.ForHost(hostname)
}
// ForceHostServices provides a pre-defined set of services for a given
// host, which prevents the receiver from attempting network-based discovery
// for the given host. Instead, the given services map will be returned
// verbatim.
//
// When providing "forced" services, any relative URLs are resolved against
// the initial discovery URL that would have been used for network-based
// discovery, yielding the same results as if the given map were published
// at the host's default discovery URL, though using absolute URLs is strongly
// recommended to make the configured behavior more explicit.
func (d *Disco) ForceHostServices(hostname svchost.Hostname, services map[string]interface{}) {
if services == nil {
services = map[string]interface{}{}
}
d.hostCache[hostname] = &Host{
discoURL: &url.URL{
Scheme: "https",
Host: string(hostname),
Path: discoPath,
},
hostname: hostname.ForDisplay(),
services: services,
transport: d.Transport,
}
}
// Discover runs the discovery protocol against the given hostname (which must
// already have been validated and prepared with svchost.ForComparison) and
// returns an object describing the services available at that host.
//
// If a given hostname supports no Terraform services at all, a non-nil but
// empty Host object is returned. When giving feedback to the end user about
// such situations, we say "host <name> does not provide a <service> service",
// regardless of whether that is due to that service specifically being absent
// or due to the host not providing Terraform services at all, since we don't
// wish to expose the detail of whole-host discovery to an end-user.
func (d *Disco) Discover(hostname svchost.Hostname) (*Host, error) {
if host, cached := d.hostCache[hostname]; cached {
return host, nil
}
host, err := d.discover(hostname)
if err != nil {
return nil, err
}
d.hostCache[hostname] = host
return host, nil
}
// DiscoverServiceURL is a convenience wrapper for discovery on a given
// hostname and then looking up a particular service in the result.
func (d *Disco) DiscoverServiceURL(hostname svchost.Hostname, serviceID string) (*url.URL, error) {
host, err := d.Discover(hostname)
if err != nil {
return nil, err
}
return host.ServiceURL(serviceID)
}
// discover implements the actual discovery process, with its result cached
// by the public-facing Discover method.
func (d *Disco) discover(hostname svchost.Hostname) (*Host, error) {
discoURL := &url.URL{
Scheme: "https",
Host: hostname.String(),
Path: discoPath,
}
client := &http.Client{
Transport: d.Transport,
Timeout: discoTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
log.Printf("[DEBUG] Service discovery redirected to %s", req.URL)
if len(via) > maxRedirects {
return errors.New("too many redirects") // this error will never actually be seen
}
return nil
},
}
req := &http.Request{
Header: make(http.Header),
Method: "GET",
URL: discoURL,
}
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", httpclient.UserAgentString())
creds, err := d.CredentialsForHost(hostname)
if err != nil {
log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", hostname, err)
}
if creds != nil {
// Update the request to include credentials.
creds.PrepareRequest(req)
}
log.Printf("[DEBUG] Service discovery for %s at %s", hostname, discoURL)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("Failed to request discovery document: %v", err)
}
defer resp.Body.Close()
host := &Host{
// Use the discovery URL from resp.Request in
// case the client followed any redirects.
discoURL: resp.Request.URL,
hostname: hostname.ForDisplay(),
transport: d.Transport,
}
// Return the host without any services.
if resp.StatusCode == 404 {
return host, nil
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("Failed to request discovery document: %s", resp.Status)
}
contentType := resp.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
return nil, fmt.Errorf("Discovery URL has a malformed Content-Type %q", contentType)
}
if mediaType != "application/json" {
return nil, fmt.Errorf("Discovery URL returned an unsupported Content-Type %q", mediaType)
}
// This doesn't catch chunked encoding, because ContentLength is -1 in that case.
if resp.ContentLength > maxDiscoDocBytes {
// Size limit here is not a contractual requirement and so we may
// adjust it over time if we find a different limit is warranted.
return nil, fmt.Errorf(
"Discovery doc response is too large (got %d bytes; limit %d)",
resp.ContentLength, maxDiscoDocBytes,
)
}
// If the response is using chunked encoding then we can't predict its
// size, but we'll at least prevent reading the entire thing into memory.
lr := io.LimitReader(resp.Body, maxDiscoDocBytes)
servicesBytes, err := ioutil.ReadAll(lr)
if err != nil {
return nil, fmt.Errorf("Error reading discovery document body: %v", err)
}
var services map[string]interface{}
err = json.Unmarshal(servicesBytes, &services)
if err != nil {
return nil, fmt.Errorf("Failed to decode discovery document as a JSON object: %v", err)
}
host.services = services
return host, nil
}
// Forget invalidates any cached record of the given hostname. If the host
// has no cache entry then this is a no-op.
func (d *Disco) Forget(hostname svchost.Hostname) {
delete(d.hostCache, hostname)
}
// ForgetAll is like Forget, but for all of the hostnames that have cache entries.
func (d *Disco) ForgetAll() {
d.hostCache = make(map[svchost.Hostname]*Host)
}

View File

@ -1,357 +0,0 @@
package disco
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strconv"
"testing"
"github.com/hashicorp/terraform/svchost"
"github.com/hashicorp/terraform/svchost/auth"
)
func TestMain(m *testing.M) {
// During all tests we override the HTTP transport we use for discovery
// so it'll tolerate the locally-generated TLS certificates we use
// for test URLs.
httpTransport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
os.Exit(m.Run())
}
func TestDiscover(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"thingy.v1": "http://example.com/foo",
"wotsit.v2": "http://example.net/bar"
}
`)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
})
t.Run("chunked encoding", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"thingy.v1": "http://example.com/foo",
"wotsit.v2": "http://example.net/bar"
}
`)
w.Header().Add("Content-Type", "application/json")
// We're going to force chunked encoding here -- and thus prevent
// the server from predicting the length -- so we can make sure
// our client is tolerant of servers using this encoding.
w.Write(resp[:5])
w.(http.Flusher).Flush()
w.Write(resp[5:])
w.(http.Flusher).Flush()
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2")
}
if got, want := gotURL.String(), "http://example.net/bar"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
})
t.Run("with credentials", func(t *testing.T) {
var authHeaderText string
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{}`)
authHeaderText = r.Header.Get("Authorization")
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
host: map[string]interface{}{
"token": "abc123",
},
}))
d.Discover(host)
if got, want := authHeaderText, "Bearer abc123"; got != want {
t.Fatalf("wrong Authorization header\ngot: %s\nwant: %s", got, want)
}
})
t.Run("forced services override", func(t *testing.T) {
forced := map[string]interface{}{
"thingy.v1": "http://example.net/foo",
"wotsit.v2": "/foo",
}
d := New()
d.ForceHostServices(svchost.Hostname("example.com"), forced)
givenHost := "example.com"
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
{
gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.net/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
}
{
gotURL, err := discovered.ServiceURL("wotsit.v2")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for wotsit.v2")
}
if got, want := gotURL.String(), "https://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
}
})
t.Run("not JSON", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
w.Header().Add("Content-Type", "application/octet-stream")
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err == nil {
t.Fatalf("expected a discovery error")
}
// Returned discovered should be nil.
if discovered != nil {
t.Errorf("discovered not nil; should be")
}
})
t.Run("malformed JSON", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "htt`) // truncated, for example...
w.Header().Add("Content-Type", "application/json")
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err == nil {
t.Fatalf("expected a discovery error")
}
// Returned discovered should be nil.
if discovered != nil {
t.Errorf("discovered not nil; should be")
}
})
t.Run("JSON with redundant charset", func(t *testing.T) {
// The JSON RFC defines no parameters for the application/json
// MIME type, but some servers have a weird tendency to just add
// "charset" to everything, so we'll make sure we ignore it successfully.
// (JSON uses content sniffing for encoding detection, not media type params.)
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
w.Header().Add("Content-Type", "application/json; charset=latin-1")
w.Write(resp)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
if discovered.services == nil {
t.Errorf("response is empty; shouldn't be")
}
})
t.Run("no discovery doc", func(t *testing.T) {
portStr, close := testServer(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
})
defer close()
givenHost := "localhost" + portStr
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
// Returned discovered.services should be nil (empty).
if discovered.services != nil {
t.Errorf("discovered.services not nil (empty); should be")
}
})
t.Run("redirect", func(t *testing.T) {
// For this test, we have two servers and one redirects to the other
portStr1, close1 := testServer(func(w http.ResponseWriter, r *http.Request) {
// This server is the one that returns a real response.
resp := []byte(`{"thingy.v1": "http://example.com/foo"}`)
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
portStr2, close2 := testServer(func(w http.ResponseWriter, r *http.Request) {
// This server is the one that redirects.
http.Redirect(w, r, "https://127.0.0.1"+portStr1+"/.well-known/terraform.json", 302)
})
defer close1()
defer close2()
givenHost := "localhost" + portStr2
host, err := svchost.ForComparison(givenHost)
if err != nil {
t.Fatalf("test server hostname is invalid: %s", err)
}
d := New()
discovered, err := d.Discover(host)
if err != nil {
t.Fatalf("unexpected discovery error: %s", err)
}
gotURL, err := discovered.ServiceURL("thingy.v1")
if err != nil {
t.Fatalf("unexpected service URL error: %s", err)
}
if gotURL == nil {
t.Fatalf("found no URL for thingy.v1")
}
if got, want := gotURL.String(), "http://example.com/foo"; got != want {
t.Fatalf("wrong result %q; want %q", got, want)
}
// The base URL for the host object should be the URL we redirected to,
// rather than the we redirected _from_.
gotBaseURL := discovered.discoURL.String()
wantBaseURL := "https://127.0.0.1" + portStr1 + "/.well-known/terraform.json"
if gotBaseURL != wantBaseURL {
t.Errorf("incorrect base url %s; want %s", gotBaseURL, wantBaseURL)
}
})
}
func testServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
server := httptest.NewTLSServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Test server always returns 404 if the URL isn't what we expect
if r.URL.Path != "/.well-known/terraform.json" {
w.WriteHeader(404)
w.Write([]byte("not found"))
return
}
// If the URL is correct then the given hander decides the response
h(w, r)
},
))
serverURL, _ := url.Parse(server.URL)
portStr = serverURL.Port()
if portStr != "" {
portStr = ":" + portStr
}
close = func() {
server.Close()
}
return portStr, close
}

View File

@ -1,414 +0,0 @@
package disco
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-version"
"github.com/hashicorp/terraform/httpclient"
)
const versionServiceID = "versions.v1"
// Host represents a service discovered host.
type Host struct {
discoURL *url.URL
hostname string
services map[string]interface{}
transport http.RoundTripper
}
// Constraints represents the version constraints of a service.
type Constraints struct {
Service string `json:"service"`
Product string `json:"product"`
Minimum string `json:"minimum"`
Maximum string `json:"maximum"`
Excluding []string `json:"excluding"`
}
// ErrServiceNotProvided is returned when the service is not provided.
type ErrServiceNotProvided struct {
hostname string
service string
}
// Error returns a customized error message.
func (e *ErrServiceNotProvided) Error() string {
if e.hostname == "" {
return fmt.Sprintf("host does not provide a %s service", e.service)
}
return fmt.Sprintf("host %s does not provide a %s service", e.hostname, e.service)
}
// ErrVersionNotSupported is returned when the version is not supported.
type ErrVersionNotSupported struct {
hostname string
service string
version string
}
// Error returns a customized error message.
func (e *ErrVersionNotSupported) Error() string {
if e.hostname == "" {
return fmt.Sprintf("host does not support %s version %s", e.service, e.version)
}
return fmt.Sprintf("host %s does not support %s version %s", e.hostname, e.service, e.version)
}
// ErrNoVersionConstraints is returned when checkpoint was disabled
// or the endpoint to query for version constraints was unavailable.
type ErrNoVersionConstraints struct {
disabled bool
}
// Error returns a customized error message.
func (e *ErrNoVersionConstraints) Error() string {
if e.disabled {
return "checkpoint disabled"
}
return "unable to contact versions service"
}
// ServiceURL returns the URL associated with the given service identifier,
// which should be of the form "servicename.vN".
//
// A non-nil result is always an absolute URL with a scheme of either HTTPS
// or HTTP.
func (h *Host) ServiceURL(id string) (*url.URL, error) {
svc, ver, err := parseServiceID(id)
if err != nil {
return nil, err
}
// No services supported for an empty Host.
if h == nil || h.services == nil {
return nil, &ErrServiceNotProvided{service: svc}
}
urlStr, ok := h.services[id].(string)
if !ok {
// See if we have a matching service as that would indicate
// the service is supported, but not the requested version.
for serviceID := range h.services {
if strings.HasPrefix(serviceID, svc+".") {
return nil, &ErrVersionNotSupported{
hostname: h.hostname,
service: svc,
version: ver.Original(),
}
}
}
// No discovered services match the requested service.
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
}
u, err := h.parseURL(urlStr)
if err != nil {
return nil, fmt.Errorf("Failed to parse service URL: %v", err)
}
return u, nil
}
// ServiceOAuthClient returns the OAuth client configuration associated with the
// given service identifier, which should be of the form "servicename.vN".
//
// This is an alternative to ServiceURL for unusual services that require
// a full OAuth2 client definition rather than just a URL. Use this only
// for services whose specification calls for this sort of definition.
func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) {
svc, ver, err := parseServiceID(id)
if err != nil {
return nil, err
}
// No services supported for an empty Host.
if h == nil || h.services == nil {
return nil, &ErrServiceNotProvided{service: svc}
}
if _, ok := h.services[id]; !ok {
// See if we have a matching service as that would indicate
// the service is supported, but not the requested version.
for serviceID := range h.services {
if strings.HasPrefix(serviceID, svc+".") {
return nil, &ErrVersionNotSupported{
hostname: h.hostname,
service: svc,
version: ver.Original(),
}
}
}
// No discovered services match the requested service.
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
}
var raw map[string]interface{}
switch v := h.services[id].(type) {
case map[string]interface{}:
raw = v // Great!
case []map[string]interface{}:
// An absolutely infuriating legacy HCL ambiguity.
raw = v[0]
default:
// Debug message because raw Go types don't belong in our UI.
log.Printf("[DEBUG] The definition for %s has Go type %T", id, h.services[id])
return nil, fmt.Errorf("Service %s must be declared with an object value in the service discovery document", id)
}
var grantTypes OAuthGrantTypeSet
if rawGTs, ok := raw["grant_types"]; ok {
if gts, ok := rawGTs.([]interface{}); ok {
var kws []string
for _, gtI := range gts {
gt, ok := gtI.(string)
if !ok {
// We'll ignore this so that we can potentially introduce
// other types into this array later if we need to.
continue
}
kws = append(kws, gt)
}
grantTypes = NewOAuthGrantTypeSet(kws...)
} else {
return nil, fmt.Errorf("Service %s is defined with invalid grant_types property: must be an array of grant type strings", id)
}
} else {
grantTypes = NewOAuthGrantTypeSet("authz_code")
}
ret := &OAuthClient{
SupportedGrantTypes: grantTypes,
}
if clientIDStr, ok := raw["client"].(string); ok {
ret.ID = clientIDStr
} else {
return nil, fmt.Errorf("Service %s definition is missing required property \"client\"", id)
}
if urlStr, ok := raw["authz"].(string); ok {
u, err := h.parseURL(urlStr)
if err != nil {
return nil, fmt.Errorf("Failed to parse authorization URL: %v", err)
}
ret.AuthorizationURL = u
} else {
if grantTypes.RequiresAuthorizationEndpoint() {
return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id)
}
}
if urlStr, ok := raw["token"].(string); ok {
u, err := h.parseURL(urlStr)
if err != nil {
return nil, fmt.Errorf("Failed to parse token URL: %v", err)
}
ret.TokenURL = u
} else {
if grantTypes.RequiresTokenEndpoint() {
return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id)
}
}
if portsRaw, ok := raw["ports"].([]interface{}); ok {
if len(portsRaw) != 2 {
return nil, fmt.Errorf("Invalid \"ports\" definition for service %s: must be a two-element array", id)
}
invalidPortsErr := fmt.Errorf("Invalid \"ports\" definition for service %s: both ports must be whole numbers between 1024 and 65535", id)
ports := make([]uint16, 2)
for i := range ports {
switch v := portsRaw[i].(type) {
case float64:
// JSON unmarshaling always produces float64. HCL 2 might, if
// an invalid fractional number were given.
if float64(uint16(v)) != v || v < 1024 {
return nil, invalidPortsErr
}
ports[i] = uint16(v)
case int:
// Legacy HCL produces int. HCL 2 will too, if the given number
// is a whole number.
if v < 1024 || v > 65535 {
return nil, invalidPortsErr
}
ports[i] = uint16(v)
default:
// Debug message because raw Go types don't belong in our UI.
log.Printf("[DEBUG] Port value %d has Go type %T", i, portsRaw[i])
return nil, invalidPortsErr
}
}
if ports[1] < ports[0] {
return nil, fmt.Errorf("Invalid \"ports\" definition for service %s: minimum port cannot be greater than maximum port", id)
}
ret.MinPort = ports[0]
ret.MaxPort = ports[1]
} else {
// Default is to accept any port in the range, for a client that is
// able to call back to any localhost port.
ret.MinPort = 1024
ret.MaxPort = 65535
}
return ret, nil
}
func (h *Host) parseURL(urlStr string) (*url.URL, error) {
u, err := url.Parse(urlStr)
if err != nil {
return nil, err
}
// Make relative URLs absolute using our discovery URL.
if !u.IsAbs() {
u = h.discoURL.ResolveReference(u)
}
if u.Scheme != "https" && u.Scheme != "http" {
return nil, fmt.Errorf("unsupported scheme %s", u.Scheme)
}
if u.User != nil {
return nil, fmt.Errorf("embedded username/password information is not permitted")
}
// Fragment part is irrelevant, since we're not a browser.
u.Fragment = ""
return u, nil
}
// VersionConstraints returns the contraints for a given service identifier
// (which should be of the form "servicename.vN") and product.
//
// When an exact (service and version) match is found, the constraints for
// that service are returned.
//
// When the requested version is not provided but the service is, we will
// search for all alternative versions. If mutliple alternative versions
// are found, the contrains of the latest available version are returned.
//
// When a service is not provided at all an error will be returned instead.
//
// When checkpoint is disabled or when a 404 is returned after making the
// HTTP call, an ErrNoVersionConstraints error will be returned.
func (h *Host) VersionConstraints(id, product string) (*Constraints, error) {
svc, _, err := parseServiceID(id)
if err != nil {
return nil, err
}
// Return early if checkpoint is disabled.
if disabled := os.Getenv("CHECKPOINT_DISABLE"); disabled != "" {
return nil, &ErrNoVersionConstraints{disabled: true}
}
// No services supported for an empty Host.
if h == nil || h.services == nil {
return nil, &ErrServiceNotProvided{service: svc}
}
// Try to get the service URL for the version service and
// return early if the service isn't provided by the host.
u, err := h.ServiceURL(versionServiceID)
if err != nil {
return nil, err
}
// Check if we have an exact (service and version) match.
if _, ok := h.services[id].(string); !ok {
// If we don't have an exact match, we search for all matching
// services and then use the service ID of the latest version.
var services []string
for serviceID := range h.services {
if strings.HasPrefix(serviceID, svc+".") {
services = append(services, serviceID)
}
}
if len(services) == 0 {
// No discovered services match the requested service.
return nil, &ErrServiceNotProvided{hostname: h.hostname, service: svc}
}
// Set id to the latest service ID we found.
var latest *version.Version
for _, serviceID := range services {
if _, ver, err := parseServiceID(serviceID); err == nil {
if latest == nil || latest.LessThan(ver) {
id = serviceID
latest = ver
}
}
}
}
// Set a default timeout of 1 sec for the versions request (in milliseconds)
timeout := 1000
if v, err := strconv.Atoi(os.Getenv("CHECKPOINT_TIMEOUT")); err == nil {
timeout = v
}
client := &http.Client{
Transport: h.transport,
Timeout: time.Duration(timeout) * time.Millisecond,
}
// Prepare the service URL by setting the service and product.
v := u.Query()
v.Set("product", product)
u.Path += id
u.RawQuery = v.Encode()
// Create a new request.
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("Failed to create version constraints request: %v", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", httpclient.UserAgentString())
log.Printf("[DEBUG] Retrieve version constraints for service %s and product %s", id, product)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("Failed to request version constraints: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode == 404 {
return nil, &ErrNoVersionConstraints{disabled: false}
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("Failed to request version constraints: %s", resp.Status)
}
// Parse the constraints from the response body.
result := &Constraints{}
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
return nil, fmt.Errorf("Error parsing version constraints: %v", err)
}
return result, nil
}
func parseServiceID(id string) (string, *version.Version, error) {
parts := strings.SplitN(id, ".", 2)
if len(parts) != 2 {
return "", nil, fmt.Errorf("Invalid service ID format (i.e. service.vN): %s", id)
}
version, err := version.NewVersion(parts[1])
if err != nil {
return "", nil, fmt.Errorf("Invalid service version: %v", err)
}
return parts[0], version, nil
}

View File

@ -1,528 +0,0 @@
package disco
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"reflect"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestHostServiceURL(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
host := Host{
discoURL: baseURL,
hostname: "test-server",
services: map[string]interface{}{
"absolute.v1": "http://example.net/foo/bar",
"absolutewithport.v1": "http://example.net:8080/foo/bar",
"relative.v1": "./stu/",
"rootrelative.v1": "/baz",
"protorelative.v1": "//example.net/",
"withfragment.v1": "http://example.org/#foo",
"querystring.v1": "https://example.net/baz?foo=bar",
"nothttp.v1": "ftp://127.0.0.1/pub/",
"invalid.v1": "***not A URL at all!:/<@@@@>***",
},
}
tests := []struct {
ID string
want string
err string
}{
{"absolute.v1", "http://example.net/foo/bar", ""},
{"absolutewithport.v1", "http://example.net:8080/foo/bar", ""},
{"relative.v1", "https://example.com/disco/stu/", ""},
{"rootrelative.v1", "https://example.com/baz", ""},
{"protorelative.v1", "https://example.net/", ""},
{"withfragment.v1", "http://example.org/", ""},
{"querystring.v1", "https://example.net/baz?foo=bar", ""},
{"nothttp.v1", "<nil>", "unsupported scheme"},
{"invalid.v1", "<nil>", "Failed to parse service URL"},
}
for _, test := range tests {
t.Run(test.ID, func(t *testing.T) {
url, err := host.ServiceURL(test.ID)
if (err != nil || test.err != "") &&
(err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("unexpected service URL error: %s", err)
}
var got string
if url != nil {
got = url.String()
} else {
got = "<nil>"
}
if got != test.want {
t.Errorf("wrong result\ngot: %s\nwant: %s", got, test.want)
}
})
}
}
func TestHostServiceOAuthClient(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
host := Host{
discoURL: baseURL,
hostname: "test-server",
services: map[string]interface{}{
"explicitgranttype.v1": map[string]interface{}{
"client": "explicitgranttype",
"authz": "./authz",
"token": "./token",
"grant_types": []interface{}{"authz_code", "password", "tbd"},
},
"customports.v1": map[string]interface{}{
"client": "customports",
"authz": "./authz",
"token": "./token",
"ports": []interface{}{1025, 1026},
},
"invalidports.v1": map[string]interface{}{
"client": "invalidports",
"authz": "./authz",
"token": "./token",
"ports": []interface{}{1, 65535},
},
"missingauthz.v1": map[string]interface{}{
"client": "missingauthz",
"token": "./token",
},
"missingtoken.v1": map[string]interface{}{
"client": "missingtoken",
"authz": "./authz",
},
"passwordmissingauthz.v1": map[string]interface{}{
"client": "passwordmissingauthz",
"token": "./token",
"grant_types": []interface{}{"password"},
},
"absolute.v1": map[string]interface{}{
"client": "absolute",
"authz": "http://example.net/foo/authz",
"token": "http://example.net/foo/token",
},
"absolutewithport.v1": map[string]interface{}{
"client": "absolutewithport",
"authz": "http://example.net:8000/foo/authz",
"token": "http://example.net:8000/foo/token",
},
"relative.v1": map[string]interface{}{
"client": "relative",
"authz": "./authz",
"token": "./token",
},
"rootrelative.v1": map[string]interface{}{
"client": "rootrelative",
"authz": "/authz",
"token": "/token",
},
"protorelative.v1": map[string]interface{}{
"client": "protorelative",
"authz": "//example.net/authz",
"token": "//example.net/token",
},
"nothttp.v1": map[string]interface{}{
"client": "nothttp",
"authz": "ftp://127.0.0.1/pub/authz",
"token": "ftp://127.0.0.1/pub/token",
},
"invalidauthz.v1": map[string]interface{}{
"client": "invalidauthz",
"authz": "***not A URL at all!:/<@@@@>***",
"token": "/foo",
},
"invalidtoken.v1": map[string]interface{}{
"client": "invalidauthz",
"authz": "/foo",
"token": "***not A URL at all!:/<@@@@>***",
},
},
}
mustURL := func(t *testing.T, s string) *url.URL {
t.Helper()
u, err := url.Parse(s)
if err != nil {
t.Fatalf("invalid wanted URL %s in test case: %s", s, err)
}
return u
}
tests := []struct {
ID string
want *OAuthClient
err string
}{
{
"explicitgranttype.v1",
&OAuthClient{
ID: "explicitgranttype",
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code", "password", "tbd"),
},
"",
},
{
"customports.v1",
&OAuthClient{
ID: "customports",
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1025,
MaxPort: 1026,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"invalidports.v1",
nil,
`Invalid "ports" definition for service invalidports.v1: both ports must be whole numbers between 1024 and 65535`,
},
{
"missingauthz.v1",
nil,
`Service missingauthz.v1 definition is missing required property "authz"`,
},
{
"missingtoken.v1",
nil,
`Service missingtoken.v1 definition is missing required property "token"`,
},
{
"passwordmissingauthz.v1",
&OAuthClient{
ID: "passwordmissingauthz",
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("password"),
},
"",
},
{
"absolute.v1",
&OAuthClient{
ID: "absolute",
AuthorizationURL: mustURL(t, "http://example.net/foo/authz"),
TokenURL: mustURL(t, "http://example.net/foo/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"absolutewithport.v1",
&OAuthClient{
ID: "absolutewithport",
AuthorizationURL: mustURL(t, "http://example.net:8000/foo/authz"),
TokenURL: mustURL(t, "http://example.net:8000/foo/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"relative.v1",
&OAuthClient{
ID: "relative",
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"rootrelative.v1",
&OAuthClient{
ID: "rootrelative",
AuthorizationURL: mustURL(t, "https://example.com/authz"),
TokenURL: mustURL(t, "https://example.com/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"protorelative.v1",
&OAuthClient{
ID: "protorelative",
AuthorizationURL: mustURL(t, "https://example.net/authz"),
TokenURL: mustURL(t, "https://example.net/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"nothttp.v1",
nil,
"Failed to parse authorization URL: unsupported scheme ftp",
},
{
"invalidauthz.v1",
nil,
"Failed to parse authorization URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon",
},
{
"invalidtoken.v1",
nil,
"Failed to parse token URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon",
},
}
for _, test := range tests {
t.Run(test.ID, func(t *testing.T) {
got, err := host.ServiceOAuthClient(test.ID)
if (err != nil || test.err != "") &&
(err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("unexpected service URL error: %s", err)
}
if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("wrong result\n%s", diff)
}
})
}
}
func TestVersionConstrains(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
t.Run("exact service version is provided", func(t *testing.T) {
portStr, close := testVersionsServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"service": "%s",
"product": "%s",
"minimum": "0.11.8",
"maximum": "0.12.0"
}`)
// Add the requested service and product to the response.
service := path.Base(r.URL.Path)
product := r.URL.Query().Get("product")
resp = []byte(fmt.Sprintf(string(resp), service, product))
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"thingy.v1": "/api/v1/",
"thingy.v2": "/api/v2/",
"versions.v1": "https://localhost" + portStr + "/v1/versions/",
},
}
expected := &Constraints{
Service: "thingy.v1",
Product: "terraform",
Minimum: "0.11.8",
Maximum: "0.12.0",
}
actual, err := host.VersionConstraints("thingy.v1", "terraform")
if err != nil {
t.Fatalf("unexpected version constraints error: %s", err)
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("expected %#v, got: %#v", expected, actual)
}
})
t.Run("service provided with different versions", func(t *testing.T) {
portStr, close := testVersionsServer(func(w http.ResponseWriter, r *http.Request) {
resp := []byte(`
{
"service": "%s",
"product": "%s",
"minimum": "0.11.8",
"maximum": "0.12.0"
}`)
// Add the requested service and product to the response.
service := path.Base(r.URL.Path)
product := r.URL.Query().Get("product")
resp = []byte(fmt.Sprintf(string(resp), service, product))
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Length", strconv.Itoa(len(resp)))
w.Write(resp)
})
defer close()
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"thingy.v2": "/api/v2/",
"thingy.v3": "/api/v3/",
"versions.v1": "https://localhost" + portStr + "/v1/versions/",
},
}
expected := &Constraints{
Service: "thingy.v3",
Product: "terraform",
Minimum: "0.11.8",
Maximum: "0.12.0",
}
actual, err := host.VersionConstraints("thingy.v1", "terraform")
if err != nil {
t.Fatalf("unexpected version constraints error: %s", err)
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("expected %#v, got: %#v", expected, actual)
}
})
t.Run("service not provided", func(t *testing.T) {
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"versions.v1": "https://localhost/v1/versions/",
},
}
_, err := host.VersionConstraints("thingy.v1", "terraform")
if _, ok := err.(*ErrServiceNotProvided); !ok {
t.Fatalf("expected service not provided error, got: %v", err)
}
})
t.Run("versions service returns a 404", func(t *testing.T) {
portStr, close := testVersionsServer(nil)
defer close()
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"thingy.v1": "/api/v1/",
"versions.v1": "https://localhost" + portStr + "/v1/non-existent/",
},
}
_, err := host.VersionConstraints("thingy.v1", "terraform")
if _, ok := err.(*ErrNoVersionConstraints); !ok {
t.Fatalf("expected service not provided error, got: %v", err)
}
})
t.Run("checkpoint is disabled", func(t *testing.T) {
if err := os.Setenv("CHECKPOINT_DISABLE", "1"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer os.Unsetenv("CHECKPOINT_DISABLE")
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"thingy.v1": "/api/v1/",
"versions.v1": "https://localhost/v1/versions/",
},
}
_, err := host.VersionConstraints("thingy.v1", "terraform")
if _, ok := err.(*ErrNoVersionConstraints); !ok {
t.Fatalf("expected service not provided error, got: %v", err)
}
})
t.Run("versions service not discovered", func(t *testing.T) {
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"thingy.v1": "/api/v1/",
},
}
_, err := host.VersionConstraints("thingy.v1", "terraform")
if _, ok := err.(*ErrServiceNotProvided); !ok {
t.Fatalf("expected service not provided error, got: %v", err)
}
})
t.Run("versions service version not discovered", func(t *testing.T) {
host := Host{
discoURL: baseURL,
hostname: "test-server",
transport: httpTransport,
services: map[string]interface{}{
"thingy.v1": "/api/v1/",
"versions.v2": "https://localhost/v2/versions/",
},
}
_, err := host.VersionConstraints("thingy.v1", "terraform")
if _, ok := err.(*ErrVersionNotSupported); !ok {
t.Fatalf("expected service not provided error, got: %v", err)
}
})
}
func testVersionsServer(h func(w http.ResponseWriter, r *http.Request)) (portStr string, close func()) {
server := httptest.NewTLSServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Test server always returns 404 if the URL isn't what we expect
if !strings.HasPrefix(r.URL.Path, "/v1/versions/") {
w.WriteHeader(404)
w.Write([]byte("not found"))
return
}
// If the URL is correct then the given hander decides the response
h(w, r)
},
))
serverURL, _ := url.Parse(server.URL)
portStr = serverURL.Port()
if portStr != "" {
portStr = ":" + portStr
}
close = func() {
server.Close()
}
return portStr, close
}

View File

@ -1,178 +0,0 @@
package disco
import (
"fmt"
"net/url"
"strings"
"golang.org/x/oauth2"
)
// OAuthClient represents an OAuth client configuration, which is used for
// unusual services that require an entire OAuth client configuration as part
// of their service discovery, rather than just a URL.
type OAuthClient struct {
// ID is the identifier for the client, to be used as "client_id" in
// OAuth requests.
ID string
// Authorization URL is the URL of the authorization endpoint that must
// be used for this OAuth client, as defined in the OAuth2 specifications.
//
// Not all grant types use the authorization endpoint, so it may be omitted
// if none of the grant types in SupportedGrantTypes require it.
AuthorizationURL *url.URL
// Token URL is the URL of the token endpoint that must be used for this
// OAuth client, as defined in the OAuth2 specifications.
//
// Not all grant types use the token endpoint, so it may be omitted
// if none of the grant types in SupportedGrantTypes require it.
TokenURL *url.URL
// MinPort and MaxPort define a range of TCP ports on localhost that this
// client is able to use as redirect_uri in an authorization request.
// Terraform will select a port from this range for the temporary HTTP
// server it creates to receive the authorization response, giving
// a URL like http://localhost:NNN/ where NNN is the selected port number.
//
// Terraform will reject any port numbers in this range less than 1024,
// to respect the common convention (enforced on some operating systems)
// that lower port numbers are reserved for "privileged" services.
MinPort, MaxPort uint16
// SupportedGrantTypes is a set of the grant types that the client may
// choose from. This includes an entry for each distinct type advertised
// by the server, even if a particular keyword is not supported by the
// current version of Terraform.
SupportedGrantTypes OAuthGrantTypeSet
}
// Endpoint returns an oauth2.Endpoint value ready to be used with the oauth2
// library, representing the URLs from the receiver.
func (c *OAuthClient) Endpoint() oauth2.Endpoint {
ep := oauth2.Endpoint{
// We don't actually auth because we're not a server-based OAuth client,
// so this instead just means that we include client_id as an argument
// in our requests.
AuthStyle: oauth2.AuthStyleInParams,
}
if c.AuthorizationURL != nil {
ep.AuthURL = c.AuthorizationURL.String()
}
if c.TokenURL != nil {
ep.TokenURL = c.TokenURL.String()
}
return ep
}
// OAuthGrantType is an enumeration of grant type strings that a host can
// advertise support for.
//
// Values of this type don't necessarily match with a known constant of the
// type, because they may represent grant type keywords defined in a later
// version of Terraform which this version doesn't yet know about.
type OAuthGrantType string
const (
// OAuthAuthzCodeGrant represents an authorization code grant, as
// defined in IETF RFC 6749 section 4.1.
OAuthAuthzCodeGrant = OAuthGrantType("authz_code")
// OAuthOwnerPasswordGrant represents a resource owner password
// credentials grant, as defined in IETF RFC 6749 section 4.3.
OAuthOwnerPasswordGrant = OAuthGrantType("password")
)
// UsesAuthorizationEndpoint returns true if the receiving grant type makes
// use of the authorization endpoint from the client configuration, and thus
// if the authorization endpoint ought to be required.
func (t OAuthGrantType) UsesAuthorizationEndpoint() bool {
switch t {
case OAuthAuthzCodeGrant:
return true
case OAuthOwnerPasswordGrant:
return false
default:
// We'll default to false so that we don't impose any requirements
// on any grant type keywords that might be defined for future
// versions of Terraform.
return false
}
}
// UsesTokenEndpoint returns true if the receiving grant type makes
// use of the token endpoint from the client configuration, and thus
// if the authorization endpoint ought to be required.
func (t OAuthGrantType) UsesTokenEndpoint() bool {
switch t {
case OAuthAuthzCodeGrant:
return true
case OAuthOwnerPasswordGrant:
return true
default:
// We'll default to false so that we don't impose any requirements
// on any grant type keywords that might be defined for future
// versions of Terraform.
return false
}
}
// OAuthGrantTypeSet represents a set of OAuthGrantType values.
type OAuthGrantTypeSet map[OAuthGrantType]struct{}
// NewOAuthGrantTypeSet constructs a new grant type set from the given list
// of grant type keyword strings. Any duplicates in the list are ignored.
func NewOAuthGrantTypeSet(keywords ...string) OAuthGrantTypeSet {
ret := make(OAuthGrantTypeSet, len(keywords))
for _, kw := range keywords {
ret[OAuthGrantType(kw)] = struct{}{}
}
return ret
}
// Has returns true if the given grant type is in the receiving set.
func (s OAuthGrantTypeSet) Has(t OAuthGrantType) bool {
_, ok := s[t]
return ok
}
// RequiresAuthorizationEndpoint returns true if any of the grant types in
// the set are known to require an authorization endpoint.
func (s OAuthGrantTypeSet) RequiresAuthorizationEndpoint() bool {
for t := range s {
if t.UsesAuthorizationEndpoint() {
return true
}
}
return false
}
// RequiresTokenEndpoint returns true if any of the grant types in
// the set are known to require a token endpoint.
func (s OAuthGrantTypeSet) RequiresTokenEndpoint() bool {
for t := range s {
if t.UsesTokenEndpoint() {
return true
}
}
return false
}
// GoString implements fmt.GoStringer.
func (s OAuthGrantTypeSet) GoString() string {
var buf strings.Builder
i := 0
buf.WriteString("disco.NewOAuthGrantTypeSet(")
for t := range s {
if i > 0 {
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%q", string(t))
i++
}
buf.WriteString(")")
return buf.String()
}

View File

@ -1,69 +0,0 @@
package svchost
import (
"strings"
)
// A labelIter allows iterating over domain name labels.
//
// This type is copied from golang.org/x/net/idna, where it is used
// to segment hostnames into their separate labels for analysis. We use
// it for the same purpose here, in ForComparison.
type labelIter struct {
orig string
slice []string
curStart int
curEnd int
i int
}
func (l *labelIter) reset() {
l.curStart = 0
l.curEnd = 0
l.i = 0
}
func (l *labelIter) done() bool {
return l.curStart >= len(l.orig)
}
func (l *labelIter) result() string {
if l.slice != nil {
return strings.Join(l.slice, ".")
}
return l.orig
}
func (l *labelIter) label() string {
if l.slice != nil {
return l.slice[l.i]
}
p := strings.IndexByte(l.orig[l.curStart:], '.')
l.curEnd = l.curStart + p
if p == -1 {
l.curEnd = len(l.orig)
}
return l.orig[l.curStart:l.curEnd]
}
// next sets the value to the next label. It skips the last label if it is empty.
func (l *labelIter) next() {
l.i++
if l.slice != nil {
if l.i >= len(l.slice) || l.i == len(l.slice)-1 && l.slice[l.i] == "" {
l.curStart = len(l.orig)
}
} else {
l.curStart = l.curEnd + 1
if l.curStart == len(l.orig)-1 && l.orig[l.curStart] == '.' {
l.curStart = len(l.orig)
}
}
}
func (l *labelIter) set(s string) {
if l.slice == nil {
l.slice = strings.Split(l.orig, ".")
}
l.slice[l.i] = s
}

View File

@ -1,207 +0,0 @@
// Package svchost deals with the representations of the so-called "friendly
// hostnames" that we use to represent systems that provide Terraform-native
// remote services, such as module registry, remote operations, etc.
//
// Friendly hostnames are specified such that, as much as possible, they
// are consistent with how web browsers think of hostnames, so that users
// can bring their intuitions about how hostnames behave when they access
// a Terraform Enterprise instance's web UI (or indeed any other website)
// and have this behave in a similar way.
package svchost
import (
"errors"
"fmt"
"strconv"
"strings"
"golang.org/x/net/idna"
)
// Hostname is specialized name for string that indicates that the string
// has been converted to (or was already in) the storage and comparison form.
//
// Hostname values are not suitable for display in the user-interface. Use
// the ForDisplay method to obtain a form suitable for display in the UI.
//
// Unlike user-supplied hostnames, strings of type Hostname (assuming they
// were constructed by a function within this package) can be compared for
// equality using the standard Go == operator.
type Hostname string
// acePrefix is the ASCII Compatible Encoding prefix, used to indicate that
// a domain name label is in "punycode" form.
const acePrefix = "xn--"
// displayProfile is a very liberal idna profile that we use to do
// normalization for display without imposing validation rules.
var displayProfile = idna.New(
idna.MapForLookup(),
idna.Transitional(true),
)
// ForDisplay takes a user-specified hostname and returns a normalized form of
// it suitable for display in the UI.
//
// If the input is so invalid that no normalization can be performed then
// this will return the input, assuming that the caller still wants to
// display _something_. This function is, however, more tolerant than the
// other functions in this package and will make a best effort to prepare
// _any_ given hostname for display.
//
// For validation, use either IsValid (for explicit validation) or
// ForComparison (which implicitly validates, returning an error if invalid).
func ForDisplay(given string) string {
var portPortion string
if colonPos := strings.Index(given, ":"); colonPos != -1 {
given, portPortion = given[:colonPos], given[colonPos:]
}
portPortion, _ = normalizePortPortion(portPortion)
ascii, err := displayProfile.ToASCII(given)
if err != nil {
return given + portPortion
}
display, err := displayProfile.ToUnicode(ascii)
if err != nil {
return given + portPortion
}
return display + portPortion
}
// IsValid returns true if the given user-specified hostname is a valid
// service hostname.
//
// Validity is determined by complying with the RFC 5891 requirements for
// names that are valid for domain lookup (section 5), with the additional
// requirement that user-supplied forms must not _already_ contain
// Punycode segments.
func IsValid(given string) bool {
_, err := ForComparison(given)
return err == nil
}
// ForComparison takes a user-specified hostname and returns a normalized
// form of it suitable for storage and comparison. The result is not suitable
// for display to end-users because it uses Punycode to represent non-ASCII
// characters, and this form is unreadable for non-ASCII-speaking humans.
//
// The result is typed as Hostname -- a specialized name for string -- so that
// other APIs can make it clear within the type system whether they expect a
// user-specified or display-form hostname or a value already normalized for
// comparison.
//
// The returned Hostname is not valid if the returned error is non-nil.
func ForComparison(given string) (Hostname, error) {
var portPortion string
if colonPos := strings.Index(given, ":"); colonPos != -1 {
given, portPortion = given[:colonPos], given[colonPos:]
}
var err error
portPortion, err = normalizePortPortion(portPortion)
if err != nil {
return Hostname(""), err
}
if given == "" {
return Hostname(""), fmt.Errorf("empty string is not a valid hostname")
}
// First we'll apply our additional constraint that Punycode must not
// be given directly by the user. This is not an IDN specification
// requirement, but we prohibit it to force users to use human-readable
// hostname forms within Terraform configuration.
labels := labelIter{orig: given}
for ; !labels.done(); labels.next() {
label := labels.label()
if label == "" {
return Hostname(""), fmt.Errorf(
"hostname contains empty label (two consecutive periods)",
)
}
if strings.HasPrefix(label, acePrefix) {
return Hostname(""), fmt.Errorf(
"hostname label %q specified in punycode format; service hostnames must be given in unicode",
label,
)
}
}
result, err := idna.Lookup.ToASCII(given)
if err != nil {
return Hostname(""), err
}
return Hostname(result + portPortion), nil
}
// ForDisplay returns a version of the receiver that is appropriate for display
// in the UI. This includes converting any punycode labels to their
// corresponding Unicode characters.
//
// A round-trip through ForComparison and this ForDisplay method does not
// guarantee the same result as calling this package's top-level ForDisplay
// function, since a round-trip through the Hostname type implies stricter
// handling than we do when doing basic display-only processing.
func (h Hostname) ForDisplay() string {
given := string(h)
var portPortion string
if colonPos := strings.Index(given, ":"); colonPos != -1 {
given, portPortion = given[:colonPos], given[colonPos:]
}
// We don't normalize the port portion here because we assume it's
// already been normalized on the way in.
result, err := idna.Lookup.ToUnicode(given)
if err != nil {
// Should never happen, since type Hostname indicates that a string
// passed through our validation rules.
panic(fmt.Errorf("ForDisplay called on invalid Hostname: %s", err))
}
return result + portPortion
}
func (h Hostname) String() string {
return string(h)
}
func (h Hostname) GoString() string {
return fmt.Sprintf("svchost.Hostname(%q)", string(h))
}
// normalizePortPortion attempts to normalize the "port portion" of a hostname,
// which begins with the first colon in the hostname and should be followed
// by a string of decimal digits.
//
// If the port portion is valid, a normalized version of it is returned along
// with a nil error.
//
// If the port portion is invalid, the input string is returned verbatim along
// with a non-nil error.
//
// An empty string is a valid port portion representing the absence of a port.
// If non-empty, the first character must be a colon.
func normalizePortPortion(s string) (string, error) {
if s == "" {
return s, nil
}
if s[0] != ':' {
// should never happen, since caller tends to guarantee the presence
// of a colon due to how it's extracted from the string.
return s, errors.New("port portion is missing its initial colon")
}
numStr := s[1:]
num, err := strconv.Atoi(numStr)
if err != nil {
return s, errors.New("port portion contains non-digit characters")
}
if num == 443 {
return "", nil // ":443" is the default
}
if num > 65535 {
return s, errors.New("port number is greater than 65535")
}
return fmt.Sprintf(":%d", num), nil
}

View File

@ -1,218 +0,0 @@
package svchost
import "testing"
func TestForDisplay(t *testing.T) {
tests := []struct {
Input string
Want string
}{
{
"",
"",
},
{
"example.com",
"example.com",
},
{
"invalid",
"invalid",
},
{
"localhost",
"localhost",
},
{
"localhost:1211",
"localhost:1211",
},
{
"HashiCorp.com",
"hashicorp.com",
},
{
"Испытание.com",
"испытание.com",
},
{
"münchen.de", // this is a precomposed u with diaeresis
"münchen.de", // this is a precomposed u with diaeresis
},
{
"münchen.de", // this is a separate u and combining diaeresis
"münchen.de", // this is a precomposed u with diaeresis
},
{
"example.com:443",
"example.com",
},
{
"example.com:81",
"example.com:81",
},
{
"example.com:boo",
"example.com:boo", // invalid, but tolerated for display purposes
},
{
"example.com:boo:boo",
"example.com:boo:boo", // invalid, but tolerated for display purposes
},
{
"example.com:081",
"example.com:81",
},
}
for _, test := range tests {
t.Run(test.Input, func(t *testing.T) {
got := ForDisplay(test.Input)
if got != test.Want {
t.Errorf("wrong result\ninput: %s\ngot: %s\nwant: %s", test.Input, got, test.Want)
}
})
}
}
func TestForComparison(t *testing.T) {
tests := []struct {
Input string
Want string
Err bool
}{
{
"",
"",
true,
},
{
"example.com",
"example.com",
false,
},
{
"example.com:443",
"example.com",
false,
},
{
"example.com:81",
"example.com:81",
false,
},
{
"example.com:081",
"example.com:81",
false,
},
{
"invalid",
"invalid",
false, // the "invalid" TLD is, confusingly, a valid hostname syntactically
},
{
"localhost", // supported for local testing only
"localhost",
false,
},
{
"localhost:1211", // supported for local testing only
"localhost:1211",
false,
},
{
"HashiCorp.com",
"hashicorp.com",
false,
},
{
"1example.com",
"1example.com",
false,
},
{
"Испытание.com",
"xn--80akhbyknj4f.com",
false,
},
{
"münchen.de", // this is a precomposed u with diaeresis
"xn--mnchen-3ya.de",
false,
},
{
"münchen.de", // this is a separate u and combining diaeresis
"xn--mnchen-3ya.de",
false,
},
{
"blah..blah",
"",
true,
},
{
"example.com:boo",
"",
true,
},
{
"example.com:80:boo",
"",
true,
},
}
for _, test := range tests {
t.Run(test.Input, func(t *testing.T) {
got, err := ForComparison(test.Input)
if (err != nil) != test.Err {
if test.Err {
t.Error("unexpected success; want error")
} else {
t.Errorf("unexpected error; want success\nerror: %s", err)
}
}
if string(got) != test.Want {
t.Errorf("wrong result\ninput: %s\ngot: %s\nwant: %s", test.Input, got, test.Want)
}
})
}
}
func TestHostnameForDisplay(t *testing.T) {
tests := []struct {
Input string
Want string
}{
{
"example.com",
"example.com",
},
{
"example.com:81",
"example.com:81",
},
{
"xn--80akhbyknj4f.com",
"испытание.com",
},
{
"xn--80akhbyknj4f.com:8080",
"испытание.com:8080",
},
{
"xn--mnchen-3ya.de",
"münchen.de", // this is a precomposed u with diaeresis
},
}
for _, test := range tests {
t.Run(test.Input, func(t *testing.T) {
got := Hostname(test.Input).ForDisplay()
if got != test.Want {
t.Errorf("wrong result\ninput: %s\ngot: %s\nwant: %s", test.Input, got, test.Want)
}
})
}
}