[GH-1275] Support for AWS access via IAMs AssumeRole functionality

This commit enables terraform to utilise the assume role functionality
of sts to execute commands with different privileges than the API
keys specified.

Signed-off-by: Ian Duffy <ian@ianduffy.ie>
This commit is contained in:
Ian Duffy 2016-08-27 01:46:41 +01:00 committed by James Nugent
parent 0cf43411f9
commit 767914bbdc
6 changed files with 108 additions and 16 deletions

View File

@ -11,12 +11,14 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
awsCredentials "github.com/aws/aws-sdk-go/aws/credentials" awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-multierror"
) )
func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
@ -92,7 +94,9 @@ func parseAccountIdFromArn(arn string) (string, error) {
// This function is responsible for reading credentials from the // This function is responsible for reading credentials from the
// environment in the case that they're not explicitly specified // environment in the case that they're not explicitly specified
// in the Terraform configuration. // in the Terraform configuration.
func GetCredentials(c *Config) *awsCredentials.Credentials { func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
var errs []error
// build a chain provider, lazy-evaulated by aws-sdk // build a chain provider, lazy-evaulated by aws-sdk
providers := []awsCredentials.Provider{ providers := []awsCredentials.Provider{
&awsCredentials.StaticProvider{Value: awsCredentials.Value{ &awsCredentials.StaticProvider{Value: awsCredentials.Value{
@ -137,7 +141,40 @@ func GetCredentials(c *Config) *awsCredentials.Credentials {
} }
} }
return awsCredentials.NewChainCredentials(providers) if c.RoleArn != "" {
log.Printf("[INFO] attempting to assume role %s", c.RoleArn)
creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`))
} else {
errs = append(errs, fmt.Errorf("Error loading credentials for AWS Provider: %s", err))
}
return nil, &multierror.Error{Errors: errs}
}
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
S3ForcePathStyle: aws.Bool(c.S3ForcePathStyle),
}
stsclient := sts.New(session.New(awsConfig))
providers = []awsCredentials.Provider{&stscreds.AssumeRoleProvider{
Client: stsclient,
RoleARN: c.RoleArn,
}}
}
return awsCredentials.NewChainCredentials(providers), nil
} }
func setOptionalEndpoint(cfg *aws.Config) string { func setOptionalEndpoint(cfg *aws.Config) string {

View File

@ -218,8 +218,13 @@ func TestAWSGetCredentials_shouldError(t *testing.T) {
defer resetEnv() defer resetEnv()
cfg := Config{} cfg := Config{}
c := GetCredentials(&cfg) c, err := GetCredentials(&cfg)
_, err := c.Get() if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error")
}
}
_, err = c.Get()
if awsErr, ok := err.(awserr.Error); ok { if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" { if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error") t.Fatalf("Expected NoCredentialProviders error")
@ -251,10 +256,13 @@ func TestAWSGetCredentials_shouldBeStatic(t *testing.T) {
Token: c.Token, Token: c.Token,
} }
creds := GetCredentials(&cfg) creds, err := GetCredentials(&cfg)
if creds == nil { if creds == nil {
t.Fatalf("Expected a static creds provider to be returned") t.Fatalf("Expected a static creds provider to be returned")
} }
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Error gettings creds: %s", err) t.Fatalf("Error gettings creds: %s", err)
@ -286,11 +294,13 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) {
// An empty config, no key supplied // An empty config, no key supplied
cfg := Config{} cfg := Config{}
creds := GetCredentials(&cfg) creds, err := GetCredentials(&cfg)
if creds == nil { if creds == nil {
t.Fatalf("Expected a static creds provider to be returned") t.Fatalf("Expected a static creds provider to be returned")
} }
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Error gettings creds: %s", err) t.Fatalf("Error gettings creds: %s", err)
@ -335,10 +345,13 @@ func TestAWSGetCredentials_shouldIgnoreIAM(t *testing.T) {
Token: c.Token, Token: c.Token,
} }
creds := GetCredentials(&cfg) creds, err := GetCredentials(&cfg)
if creds == nil { if creds == nil {
t.Fatalf("Expected a static creds provider to be returned") t.Fatalf("Expected a static creds provider to be returned")
} }
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Error gettings creds: %s", err) t.Fatalf("Error gettings creds: %s", err)
@ -362,7 +375,10 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t) ts := invalidAwsEnv(t)
defer ts() defer ts()
creds := GetCredentials(&Config{}) creds, err := GetCredentials(&Config{})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err == nil { if err == nil {
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint") t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
@ -380,7 +396,10 @@ func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t) ts := invalidAwsEnv(t)
defer ts() defer ts()
creds := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"}) creds, err := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Getting static credentials w/ invalid EC2 endpoint failed: %s", err) t.Fatalf("Getting static credentials w/ invalid EC2 endpoint failed: %s", err)
@ -406,10 +425,13 @@ func TestAWSGetCredentials_shouldCatchEC2RoleProvider(t *testing.T) {
ts := awsEnv(t) ts := awsEnv(t)
defer ts() defer ts()
creds := GetCredentials(&Config{}) creds, err := GetCredentials(&Config{})
if creds == nil { if creds == nil {
t.Fatalf("Expected an EC2Role creds provider to be returned") t.Fatalf("Expected an EC2Role creds provider to be returned")
} }
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Expected no error when getting creds: %s", err) t.Fatalf("Expected no error when getting creds: %s", err)
@ -452,10 +474,13 @@ func TestAWSGetCredentials_shouldBeShared(t *testing.T) {
t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err) t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err)
} }
creds := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()}) creds, err := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()})
if creds == nil { if creds == nil {
t.Fatalf("Expected a provider chain to be returned") t.Fatalf("Expected a provider chain to be returned")
} }
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Error gettings creds: %s", err) t.Fatalf("Error gettings creds: %s", err)
@ -479,10 +504,13 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) {
defer resetEnv() defer resetEnv()
cfg := Config{} cfg := Config{}
creds := GetCredentials(&cfg) creds, err := GetCredentials(&cfg)
if creds == nil { if creds == nil {
t.Fatalf("Expected a static creds provider to be returned") t.Fatalf("Expected a static creds provider to be returned")
} }
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get() v, err := creds.Get()
if err != nil { if err != nil {
t.Fatalf("Error gettings creds: %s", err) t.Fatalf("Error gettings creds: %s", err)

View File

@ -66,6 +66,7 @@ type Config struct {
Profile string Profile string
Token string Token string
Region string Region string
RoleArn string
MaxRetries int MaxRetries int
AllowedAccountIds []interface{} AllowedAccountIds []interface{}
@ -150,7 +151,10 @@ func (c *Config) Client() (interface{}, error) {
client.region = c.Region client.region = c.Region
log.Println("[INFO] Building AWS auth structure") log.Println("[INFO] Building AWS auth structure")
creds := GetCredentials(c) creds, err := GetCredentials(c)
if err != nil {
return nil, &multierror.Error{Errors: errs}
}
// Call Get to check for credential provider. If nothing found, we'll get an // Call Get to check for credential provider. If nothing found, we'll get an
// error, and we can present it nicely to the user // error, and we can present it nicely to the user
cp, err := creds.Get() cp, err := creds.Get()

View File

@ -64,6 +64,13 @@ func Provider() terraform.ResourceProvider {
InputDefault: "us-east-1", InputDefault: "us-east-1",
}, },
"role_arn": &schema.Schema{
Type: schema.TypeString,
Optional: true,
Default: "",
Description: descriptions["role_arn"],
},
"max_retries": &schema.Schema{ "max_retries": &schema.Schema{
Type: schema.TypeInt, Type: schema.TypeInt,
Optional: true, Optional: true,
@ -353,6 +360,8 @@ func init() {
"profile": "The profile for API operations. If not set, the default profile\n" + "profile": "The profile for API operations. If not set, the default profile\n" +
"created with `aws configure` will be used.", "created with `aws configure` will be used.",
"role_arn": "The role to be assumed using the supplied access_key and secret_key",
"shared_credentials_file": "The path to the shared credentials file. If not set\n" + "shared_credentials_file": "The path to the shared credentials file. If not set\n" +
"this defaults to ~/.aws/credentials.", "this defaults to ~/.aws/credentials.",
@ -404,6 +413,7 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
CredsFilename: d.Get("shared_credentials_file").(string), CredsFilename: d.Get("shared_credentials_file").(string),
Token: d.Get("token").(string), Token: d.Get("token").(string),
Region: d.Get("region").(string), Region: d.Get("region").(string),
RoleArn: d.Get("role_arn").(string),
MaxRetries: d.Get("max_retries").(int), MaxRetries: d.Get("max_retries").(int),
DynamoDBEndpoint: d.Get("dynamodb_endpoint").(string), DynamoDBEndpoint: d.Get("dynamodb_endpoint").(string),
KinesisEndpoint: d.Get("kinesis_endpoint").(string), KinesisEndpoint: d.Get("kinesis_endpoint").(string),

View File

@ -60,7 +60,7 @@ func s3Factory(conf map[string]string) (Client, error) {
kmsKeyID := conf["kms_key_id"] kmsKeyID := conf["kms_key_id"]
var errs []error var errs []error
creds := terraformAws.GetCredentials(&terraformAws.Config{ creds, err := terraformAws.GetCredentials(&terraformAws.Config{
AccessKey: conf["access_key"], AccessKey: conf["access_key"],
SecretKey: conf["secret_key"], SecretKey: conf["secret_key"],
Token: conf["token"], Token: conf["token"],
@ -69,7 +69,7 @@ func s3Factory(conf map[string]string) (Client, error) {
}) })
// Call Get to check for credential provider. If nothing found, we'll get an // Call Get to check for credential provider. If nothing found, we'll get an
// error, and we can present it nicely to the user // error, and we can present it nicely to the user
_, err := creds.Get() _, err = creds.Get()
if err != nil { if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" { if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS S3 remote. errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS S3 remote.

View File

@ -111,6 +111,19 @@ You can provide custom metadata API endpoint via `AWS_METADATA_ENDPOINT` variabl
which expects the endpoint URL including the version which expects the endpoint URL including the version
and defaults to `http://169.254.169.254:80/latest`. and defaults to `http://169.254.169.254:80/latest`.
###Assume role
If provided with a role arn, terraform will attempt to assume this role
using the supplied credentials.
Usage:
```
provider "aws" {
role_arn = "arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME"
}
```
## Argument Reference ## Argument Reference
The following arguments are supported in the `provider` block: The following arguments are supported in the `provider` block: