svchost/disco: Allow oauth client services to specify grant types

Previously we just assumed support for the authorization code grant type,
but now we'll allow the host to declare which grant types it supports
to allow for more flexibility in host login implementations. We may extend
the set of supported grant types in future.
This commit is contained in:
Martin Atkins 2019-08-07 16:30:56 -07:00
parent 5590efcd33
commit 31a9790080
3 changed files with 391 additions and 3 deletions

View File

@ -166,7 +166,30 @@ func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) {
return nil, fmt.Errorf("Service %s must be declared with an object value in the service discovery document", id)
}
ret := &OAuthClient{}
var grantTypes OAuthGrantTypeSet
if rawGTs, ok := raw["grant_types"]; ok {
if gts, ok := rawGTs.([]interface{}); ok {
var kws []string
for _, gtI := range gts {
gt, ok := gtI.(string)
if !ok {
// We'll ignore this so that we can potentially introduce
// other types into this array later if we need to.
continue
}
kws = append(kws, gt)
}
grantTypes = NewOAuthGrantTypeSet(kws...)
} else {
return nil, fmt.Errorf("Service %s is defined with invalid grant_types property: must be an array of grant type strings", id)
}
} else {
grantTypes = NewOAuthGrantTypeSet("authz_code")
}
ret := &OAuthClient{
SupportedGrantTypes: grantTypes,
}
if clientIDStr, ok := raw["client"].(string); ok {
ret.ID = clientIDStr
} else {
@ -179,7 +202,9 @@ func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) {
}
ret.AuthorizationURL = u
} else {
return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id)
if grantTypes.RequiresAuthorizationEndpoint() {
return nil, fmt.Errorf("Service %s definition is missing required property \"authz\"", id)
}
}
if urlStr, ok := raw["token"].(string); ok {
u, err := h.parseURL(urlStr)
@ -188,7 +213,9 @@ func (h *Host) ServiceOAuthClient(id string) (*OAuthClient, error) {
}
ret.TokenURL = u
} else {
return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id)
if grantTypes.RequiresTokenEndpoint() {
return nil, fmt.Errorf("Service %s definition is missing required property \"token\"", id)
}
}
if portsRaw, ok := raw["ports"].([]interface{}); ok {
if len(portsRaw) != 2 {

View File

@ -11,6 +11,8 @@ import (
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestHostServiceURL(t *testing.T) {
@ -69,6 +71,242 @@ func TestHostServiceURL(t *testing.T) {
}
}
func TestHostServiceOAuthClient(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json")
host := Host{
discoURL: baseURL,
hostname: "test-server",
services: map[string]interface{}{
"explicitgranttype.v1": map[string]interface{}{
"client": "explicitgranttype",
"authz": "./authz",
"token": "./token",
"grant_types": []interface{}{"authz_code", "password", "tbd"},
},
"customports.v1": map[string]interface{}{
"client": "customports",
"authz": "./authz",
"token": "./token",
"ports": []interface{}{1025, 1026},
},
"invalidports.v1": map[string]interface{}{
"client": "invalidports",
"authz": "./authz",
"token": "./token",
"ports": []interface{}{1, 65535},
},
"missingauthz.v1": map[string]interface{}{
"client": "missingauthz",
"token": "./token",
},
"missingtoken.v1": map[string]interface{}{
"client": "missingtoken",
"authz": "./authz",
},
"passwordmissingauthz.v1": map[string]interface{}{
"client": "passwordmissingauthz",
"token": "./token",
"grant_types": []interface{}{"password"},
},
"absolute.v1": map[string]interface{}{
"client": "absolute",
"authz": "http://example.net/foo/authz",
"token": "http://example.net/foo/token",
},
"absolutewithport.v1": map[string]interface{}{
"client": "absolutewithport",
"authz": "http://example.net:8000/foo/authz",
"token": "http://example.net:8000/foo/token",
},
"relative.v1": map[string]interface{}{
"client": "relative",
"authz": "./authz",
"token": "./token",
},
"rootrelative.v1": map[string]interface{}{
"client": "rootrelative",
"authz": "/authz",
"token": "/token",
},
"protorelative.v1": map[string]interface{}{
"client": "protorelative",
"authz": "//example.net/authz",
"token": "//example.net/token",
},
"nothttp.v1": map[string]interface{}{
"client": "nothttp",
"authz": "ftp://127.0.0.1/pub/authz",
"token": "ftp://127.0.0.1/pub/token",
},
"invalidauthz.v1": map[string]interface{}{
"client": "invalidauthz",
"authz": "***not A URL at all!:/<@@@@>***",
"token": "/foo",
},
"invalidtoken.v1": map[string]interface{}{
"client": "invalidauthz",
"authz": "/foo",
"token": "***not A URL at all!:/<@@@@>***",
},
},
}
mustURL := func(t *testing.T, s string) *url.URL {
t.Helper()
u, err := url.Parse(s)
if err != nil {
t.Fatalf("invalid wanted URL %s in test case: %s", s, err)
}
return u
}
tests := []struct {
ID string
want *OAuthClient
err string
}{
{
"explicitgranttype.v1",
&OAuthClient{
ID: "explicitgranttype",
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code", "password", "tbd"),
},
"",
},
{
"customports.v1",
&OAuthClient{
ID: "customports",
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1025,
MaxPort: 1026,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"invalidports.v1",
nil,
`Invalid "ports" definition for service invalidports.v1: both ports must be whole numbers between 1024 and 65535`,
},
{
"missingauthz.v1",
nil,
`Service missingauthz.v1 definition is missing required property "authz"`,
},
{
"missingtoken.v1",
nil,
`Service missingtoken.v1 definition is missing required property "token"`,
},
{
"passwordmissingauthz.v1",
&OAuthClient{
ID: "passwordmissingauthz",
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("password"),
},
"",
},
{
"absolute.v1",
&OAuthClient{
ID: "absolute",
AuthorizationURL: mustURL(t, "http://example.net/foo/authz"),
TokenURL: mustURL(t, "http://example.net/foo/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"absolutewithport.v1",
&OAuthClient{
ID: "absolutewithport",
AuthorizationURL: mustURL(t, "http://example.net:8000/foo/authz"),
TokenURL: mustURL(t, "http://example.net:8000/foo/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"relative.v1",
&OAuthClient{
ID: "relative",
AuthorizationURL: mustURL(t, "https://example.com/disco/authz"),
TokenURL: mustURL(t, "https://example.com/disco/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"rootrelative.v1",
&OAuthClient{
ID: "rootrelative",
AuthorizationURL: mustURL(t, "https://example.com/authz"),
TokenURL: mustURL(t, "https://example.com/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"protorelative.v1",
&OAuthClient{
ID: "protorelative",
AuthorizationURL: mustURL(t, "https://example.net/authz"),
TokenURL: mustURL(t, "https://example.net/token"),
MinPort: 1024,
MaxPort: 65535,
SupportedGrantTypes: NewOAuthGrantTypeSet("authz_code"),
},
"",
},
{
"nothttp.v1",
nil,
"Failed to parse authorization URL: unsupported scheme ftp",
},
{
"invalidauthz.v1",
nil,
"Failed to parse authorization URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon",
},
{
"invalidtoken.v1",
nil,
"Failed to parse token URL: parse ***not A URL at all!:/<@@@@>***: first path segment in URL cannot contain colon",
},
}
for _, test := range tests {
t.Run(test.ID, func(t *testing.T) {
got, err := host.ServiceOAuthClient(test.ID)
if (err != nil || test.err != "") &&
(err == nil || !strings.Contains(err.Error(), test.err)) {
t.Fatalf("unexpected service URL error: %s", err)
}
if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("wrong result\n%s", diff)
}
})
}
}
func TestVersionConstrains(t *testing.T) {
baseURL, _ := url.Parse("https://example.com/disco/foo.json")

View File

@ -1,7 +1,9 @@
package disco
import (
"fmt"
"net/url"
"strings"
"golang.org/x/oauth2"
)
@ -16,10 +18,16 @@ type OAuthClient struct {
// Authorization URL is the URL of the authorization endpoint that must
// be used for this OAuth client, as defined in the OAuth2 specifications.
//
// Not all grant types use the authorization endpoint, so it may be omitted
// if none of the grant types in SupportedGrantTypes require it.
AuthorizationURL *url.URL
// Token URL is the URL of the token endpoint that must be used for this
// OAuth client, as defined in the OAuth2 specifications.
//
// Not all grant types use the token endpoint, so it may be omitted
// if none of the grant types in SupportedGrantTypes require it.
TokenURL *url.URL
// MinPort and MaxPort define a range of TCP ports on localhost that this
@ -32,6 +40,12 @@ type OAuthClient struct {
// to respect the common convention (enforced on some operating systems)
// that lower port numbers are reserved for "privileged" services.
MinPort, MaxPort uint16
// SupportedGrantTypes is a set of the grant types that the client may
// choose from. This includes an entry for each distinct type advertised
// by the server, even if a particular keyword is not supported by the
// current version of Terraform.
SupportedGrantTypes OAuthGrantTypeSet
}
// Endpoint returns an oauth2.Endpoint value ready to be used with the oauth2
@ -47,3 +61,112 @@ func (c *OAuthClient) Endpoint() oauth2.Endpoint {
AuthStyle: oauth2.AuthStyleInParams,
}
}
// OAuthGrantType is an enumeration of grant type strings that a host can
// advertise support for.
//
// Values of this type don't necessarily match with a known constant of the
// type, because they may represent grant type keywords defined in a later
// version of Terraform which this version doesn't yet know about.
type OAuthGrantType string
const (
// OAuthAuthzCodeGrant represents an authorization code grant, as
// defined in IETF RFC 6749 section 4.1.
OAuthAuthzCodeGrant = OAuthGrantType("authz_code")
// OAuthOwnerPasswordGrant represents a resource owner password
// credentials grant, as defined in IETF RFC 6749 section 4.3.
OAuthOwnerPasswordGrant = OAuthGrantType("password")
)
// UsesAuthorizationEndpoint returns true if the receiving grant type makes
// use of the authorization endpoint from the client configuration, and thus
// if the authorization endpoint ought to be required.
func (t OAuthGrantType) UsesAuthorizationEndpoint() bool {
switch t {
case OAuthAuthzCodeGrant:
return true
case OAuthOwnerPasswordGrant:
return false
default:
// We'll default to false so that we don't impose any requirements
// on any grant type keywords that might be defined for future
// versions of Terraform.
return false
}
}
// UsesTokenEndpoint returns true if the receiving grant type makes
// use of the token endpoint from the client configuration, and thus
// if the authorization endpoint ought to be required.
func (t OAuthGrantType) UsesTokenEndpoint() bool {
switch t {
case OAuthAuthzCodeGrant:
return true
case OAuthOwnerPasswordGrant:
return true
default:
// We'll default to false so that we don't impose any requirements
// on any grant type keywords that might be defined for future
// versions of Terraform.
return false
}
}
// OAuthGrantTypeSet represents a set of OAuthGrantType values.
type OAuthGrantTypeSet map[OAuthGrantType]struct{}
// NewOAuthGrantTypeSet constructs a new grant type set from the given list
// of grant type keyword strings. Any duplicates in the list are ignored.
func NewOAuthGrantTypeSet(keywords ...string) OAuthGrantTypeSet {
ret := make(OAuthGrantTypeSet, len(keywords))
for _, kw := range keywords {
ret[OAuthGrantType(kw)] = struct{}{}
}
return ret
}
// Has returns true if the given grant type is in the receiving set.
func (s OAuthGrantTypeSet) Has(t OAuthGrantType) bool {
_, ok := s[t]
return ok
}
// RequiresAuthorizationEndpoint returns true if any of the grant types in
// the set are known to require an authorization endpoint.
func (s OAuthGrantTypeSet) RequiresAuthorizationEndpoint() bool {
for t := range s {
if t.UsesAuthorizationEndpoint() {
return true
}
}
return false
}
// RequiresTokenEndpoint returns true if any of the grant types in
// the set are known to require a token endpoint.
func (s OAuthGrantTypeSet) RequiresTokenEndpoint() bool {
for t := range s {
if t.UsesTokenEndpoint() {
return true
}
}
return false
}
// GoString implements fmt.GoStringer.
func (s OAuthGrantTypeSet) GoString() string {
var buf strings.Builder
i := 0
buf.WriteString("disco.NewOAuthGrantTypeSet(")
for t := range s {
if i > 0 {
buf.WriteString(", ")
}
fmt.Fprintf(&buf, "%q", string(t))
i++
}
buf.WriteString(")")
return buf.String()
}