terraform/vendor/github.com/jen20/riviera/azure/request.go

217 lines
5.3 KiB
Go
Raw Normal View History

package azure
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"time"
"github.com/hashicorp/go-retryablehttp"
"github.com/mitchellh/mapstructure"
)
// Request describes a single API call: the command to run, an optional
// explicit URI, and the client used to send it.
type Request struct {
// URI, when non-nil, is used verbatim as the URL path by Execute instead of
// the subscription-scoped path built from the command's URLPathFunc.
// It is never serialized (json:"-").
URI *string `json:"-"`
// location, tags and etag are unexported, so encoding/json skips them
// regardless of the json tags they carry.
// NOTE(review): nothing in this file reads these three fields; confirm
// whether other files in the package still use them.
location *string `json:"location,omitempty"`
tags *map[string]*string `json:"tags,omitempty"`
etag *string `json:"etag,omitempty"`
// Command is the API call to perform. Execute obtains the HTTP method, URL
// path, API version and response factory from Command.ApiInfo(), and sends
// Command itself as the "properties" member of the request body.
Command ApiCall `json:"properties,omitempty"`
// client supplies the base URL, subscription ID, token requester, HTTP
// client and logger used when the request is executed.
client *Client
}
// readLocation looks for the first struct field on req tagged
// `riviera:"location"` and returns its string value.
//
// req may be a struct or a pointer to a struct. The tagged field may be of a
// string kind or a pointer to one. The boolean result reports whether a
// usable location was found; it is false when req is nil, a nil pointer, not
// a struct, has no tagged field, or the tagged field is a nil pointer or is
// not a string.
func readLocation(req interface{}) (string, bool) {
	value := reflect.ValueOf(req)
	if value.Kind() == reflect.Ptr {
		// Elem on a nil pointer yields the zero Value, whose Kind is Invalid
		// and is therefore rejected by the struct check below.
		value = value.Elem()
	}
	if value.Kind() != reflect.Struct {
		return "", false
	}

	structType := value.Type()
	for i := 0; i < structType.NumField(); i++ {
		if structType.Field(i).Tag.Get("riviera") != "location" {
			continue
		}

		field := value.Field(i)
		if field.Kind() == reflect.Ptr {
			if field.IsNil() {
				return "", false
			}
			field = field.Elem()
		}
		// Value.String on a non-string kind returns a "<T Value>" placeholder
		// rather than the data, so only accept real strings.
		if field.Kind() != reflect.String {
			return "", false
		}
		return field.String(), true
	}
	return "", false
}
// readTags looks for the first struct field on req tagged `riviera:"tags"`
// and returns it as a map[string]*string.
//
// req may be a struct or a pointer to a struct. When a tagged field of the
// right type is found, its value is returned as-is (it may be a nil map)
// together with true. In every other case - req is nil, a nil pointer, not a
// struct, has no tagged field, or the tagged field is unexported or of a
// different type - a fresh empty map and false are returned.
func readTags(req interface{}) (map[string]*string, bool) {
	value := reflect.ValueOf(req)
	if value.Kind() == reflect.Ptr {
		// Elem on a nil pointer yields the zero Value, whose Kind is Invalid
		// and is therefore rejected by the struct check below.
		value = value.Elem()
	}

	if value.Kind() == reflect.Struct {
		structType := value.Type()
		for i := 0; i < structType.NumField(); i++ {
			if structType.Field(i).Tag.Get("riviera") != "tags" {
				continue
			}

			field := value.Field(i)
			// Interface panics on values obtained from unexported fields.
			if !field.CanInterface() {
				break
			}
			tags, ok := field.Interface().(map[string]*string)
			if !ok {
				break
			}
			return tags, true
		}
	}
	return make(map[string]*string), false
}
// pollForAsynchronousResponse follows an asynchronous operation until it
// stops answering 202 Accepted.
//
// If acceptedResponse is not a 202 it is returned untouched. Otherwise, for
// each 202 response the method:
//   - sleeps for the number of seconds in the Retry-After header, if present;
//   - takes the poll URL from the Location header, falling back to the last
//     known poll URL when a later 202 omits it;
//   - issues an authorized GET to that URL.
//
// The first response whose status is not 202 is returned with its body still
// open. Bodies of superseded 202 responses are closed here because they are
// never handed back to the caller.
//
// An error is returned when Retry-After is not an integer, when no poll URL
// can be determined, or when building, authorizing or sending the poll
// request fails.
//
// NOTE: when a 202 carries no Retry-After header the next poll is issued
// immediately.
func (r *Request) pollForAsynchronousResponse(acceptedResponse *http.Response) (*http.Response, error) {
	resp := acceptedResponse
	var pollURL string

	for resp.StatusCode == http.StatusAccepted {
		// This 202 will be replaced by the next poll result (or an error), so
		// release its body now. Headers stay readable after Close.
		if resp.Body != nil {
			resp.Body.Close()
		}

		if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" {
			retryTime, err := strconv.Atoi(strings.TrimSpace(retryAfter))
			if err != nil {
				return nil, err
			}
			r.client.logger.Printf("[INFO] Polling pausing for %d seconds as per Retry-After header", retryTime)
			time.Sleep(time.Duration(retryTime) * time.Second)
		}

		location, err := resp.Location()
		switch {
		case err == nil:
			pollURL = location.String()
		case pollURL == "":
			// No Location on this response and none seen earlier.
			return nil, err
		}

		r.client.logger.Printf("[INFO] Polling %q for operation completion", pollURL)
		req, err := retryablehttp.NewRequest("GET", pollURL, bytes.NewReader([]byte{}))
		if err != nil {
			return nil, err
		}

		if err := r.client.tokenRequester.addAuthorizationToRequest(req); err != nil {
			return nil, err
		}

		// Assign to a fresh name and then to resp: using := on resp here would
		// shadow the loop variable and the loop would keep re-reading the
		// headers of the very first response.
		next, err := r.client.httpClient.Do(req)
		if err != nil {
			return nil, err
		}
		resp = next
	}

	return resp, nil
}
// Execute sends the request's Command to the API and returns the outcome.
//
// The URL path is *r.URI when URI is set, otherwise
// "/subscriptions/<subscriptionID>/<URLPathFunc()>" on the client's base URL;
// the command's API version is added as the "api-version" query parameter.
// When the command has a body, a JSON document with "properties" (the
// command) and optional "location" and "tags" (read from fields of the
// command tagged `riviera:"location"` / `riviera:"tags"`) is sent.
//
// A 202 Accepted reply is polled to completion before the result is parsed.
// On a success status the body is decoded into the value produced by the
// command's ResponseTypeFunc (skipped when that returns nil); otherwise the
// body is decoded into an *Error whose StatusCode is set from the response.
//
// The returned error is non-nil only for transport, encoding or decoding
// failures; an API-level failure is reported through Response.Error.
func (r *Request) Execute() (*Response, error) {
	apiInfo := r.Command.ApiInfo()

	// The base URL is expected to have been validated already, but report a
	// parse failure instead of dereferencing a nil *url.URL below.
	urlObj, err := url.Parse(r.client.BaseUrl)
	if err != nil {
		return nil, err
	}

	// Determine whether to use the URLPathFunc or the URI explicitly set in the request
	if r.URI == nil {
		urlObj.Path = fmt.Sprintf("/subscriptions/%s/%s", r.client.subscriptionID, strings.TrimPrefix(apiInfo.URLPathFunc(), "/"))
	} else {
		urlObj.Path = *r.URI
	}
	urlString := urlObj.String()

	// Encode the request body if necessary
	body := bytes.NewReader([]byte{})
	if apiInfo.HasBody() {
		bodyStruct := struct {
			Location   *string             `json:"location,omitempty"`
			Tags       *map[string]*string `json:"tags,omitempty"`
			Properties interface{}         `json:"properties"`
		}{
			Properties: r.Command,
		}

		if location, hasLocation := readLocation(r.Command); hasLocation {
			bodyStruct.Location = &location
		}
		// Only send tags when there is at least one, so an empty map is omitted.
		if tags, hasTags := readTags(r.Command); hasTags && len(tags) > 0 {
			bodyStruct.Tags = &tags
		}

		jsonEncodedRequest, err := json.Marshal(bodyStruct)
		if err != nil {
			return nil, err
		}
		body = bytes.NewReader(jsonEncodedRequest)
	}

	// Create an HTTP request
	req, err := retryablehttp.NewRequest(apiInfo.Method, urlString, body)
	if err != nil {
		return nil, err
	}

	query := req.URL.Query()
	query.Set("api-version", apiInfo.ApiVersion)
	req.URL.RawQuery = query.Encode()

	if apiInfo.HasBody() {
		req.Header.Add("Content-Type", "application/json")
	}

	if err := r.client.tokenRequester.addAuthorizationToRequest(req); err != nil {
		return nil, err
	}

	httpResponse, err := r.client.httpClient.Do(req)
	if err != nil {
		return nil, err
	}

	// This is safe to use for every request: we check for it being http.StatusAccepted
	httpResponse, err = r.pollForAsynchronousResponse(httpResponse)
	if err != nil {
		return nil, err
	}

	var responseObj interface{}
	var errorObj *Error

	if isSuccessCode(httpResponse.StatusCode) {
		responseObj = apiInfo.ResponseTypeFunc()
		// The response factory func returns nil as a signal that there is no body
		if responseObj != nil {
			responseMap, err := unmarshalFlattenPropertiesAndClose(httpResponse)
			if err != nil {
				return nil, err
			}
			if err := mapstructure.WeakDecode(responseMap, responseObj); err != nil {
				return nil, err
			}
		}
	} else {
		responseMap, err := unmarshalFlattenErrorAndClose(httpResponse)
		if err != nil {
			return nil, err
		}
		if err := mapstructure.WeakDecode(responseMap, &errorObj); err != nil {
			return nil, err
		}
		// Decoding may leave errorObj nil (e.g. when there was nothing to
		// decode); make sure the status code can still be reported.
		if errorObj == nil {
			errorObj = &Error{}
		}
		errorObj.StatusCode = httpResponse.StatusCode
	}

	return &Response{
		HTTP:   httpResponse,
		Parsed: responseObj,
		Error:  errorObj,
	}, nil
}