package azure import ( "bytes" "encoding/json" "fmt" "io" "net/http" "net/url" "reflect" "strconv" "strings" "time" "github.com/hashicorp/go-retryablehttp" "github.com/mitchellh/mapstructure" ) type Request struct { URI *string location *string tags *map[string]*string etag *string Command APICall client *Client } func readLocation(req interface{}) (string, bool) { var value reflect.Value if reflect.ValueOf(req).Kind() == reflect.Ptr { value = reflect.ValueOf(req).Elem() } else { value = reflect.ValueOf(req) } for i := 0; i < value.NumField(); i++ { // iterates through every struct type field tag := value.Type().Field(i).Tag // returns the tag string if tag.Get("riviera") == "location" { return value.Field(i).String(), true } } return "", false } func readTags(req interface{}) (map[string]*string, bool) { var value reflect.Value if reflect.ValueOf(req).Kind() == reflect.Ptr { value = reflect.ValueOf(req).Elem() } else { value = reflect.ValueOf(req) } for i := 0; i < value.NumField(); i++ { // iterates through every struct type field tag := value.Type().Field(i).Tag // returns the tag string if tag.Get("riviera") == "tags" { tags := value.Field(i) return tags.Interface().(map[string]*string), true } } return make(map[string]*string), false } func (request *Request) pollForAsynchronousResponse(acceptedResponse *http.Response) (*http.Response, error) { var resp *http.Response = acceptedResponse for { if resp.StatusCode != http.StatusAccepted { return resp, nil } if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" { retryTime, err := strconv.Atoi(strings.TrimSpace(retryAfter)) if err != nil { return nil, err } request.client.logger.Printf("[INFO] Polling pausing for %d seconds as per Retry-After header", retryTime) time.Sleep(time.Duration(retryTime) * time.Second) } pollLocation, err := resp.Location() if err != nil { return nil, err } request.client.logger.Printf("[INFO] Polling %q for operation completion", pollLocation.String()) req, err := retryablehttp.NewRequest("GET", pollLocation.String(), bytes.NewReader([]byte{})) if err != nil { return nil, err } err = request.client.tokenRequester.addAuthorizationToRequest(req) if err != nil { return nil, err } resp, err := request.client.httpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode == http.StatusAccepted { continue } return resp, err } } func defaultARMRequestStruct(request *Request, properties interface{}) interface{} { bodyStruct := struct { Location *string `json:"location,omitempty"` Tags *map[string]*string `json:"tags,omitempty"` Properties interface{} `json:"properties"` }{ Properties: properties, } if location, hasLocation := readLocation(request.Command); hasLocation { bodyStruct.Location = &location } if tags, hasTags := readTags(request.Command); hasTags { if len(tags) > 0 { bodyStruct.Tags = &tags } } return bodyStruct } func defaultARMRequestSerialize(body interface{}) (io.ReadSeeker, error) { jsonEncodedRequest, err := json.Marshal(body) if err != nil { return nil, err } return bytes.NewReader(jsonEncodedRequest), nil } func (request *Request) Execute() (*Response, error) { apiInfo := request.Command.APIInfo() var urlString string // Base URL should already be validated by now so Parse is safe without error handling urlObj, _ := url.Parse(request.client.BaseURL) // Determine whether to use the URLPathFunc or the URI explicitly set in the request if request.URI == nil { urlObj.Path = fmt.Sprintf("/subscriptions/%s/%s", request.client.subscriptionID, strings.TrimPrefix(apiInfo.URLPathFunc(), "/")) urlString = urlObj.String() } else { urlObj.Path = *request.URI urlString = urlObj.String() } // Encode the request body if necessary var body io.ReadSeeker if apiInfo.HasBody() { var bodyStruct interface{} if apiInfo.RequestPropertiesFunc != nil { bodyStruct = defaultARMRequestStruct(request, apiInfo.RequestPropertiesFunc()) } else { bodyStruct = defaultARMRequestStruct(request, request.Command) } serialized, err := defaultARMRequestSerialize(bodyStruct) if err != nil { return nil, err } body = serialized } else { body = bytes.NewReader([]byte{}) } // Create an HTTP request req, err := retryablehttp.NewRequest(apiInfo.Method, urlString, body) if err != nil { return nil, err } query := req.URL.Query() query.Set("api-version", apiInfo.APIVersion) req.URL.RawQuery = query.Encode() if apiInfo.HasBody() { req.Header.Add("Content-Type", "application/json") } err = request.client.tokenRequester.addAuthorizationToRequest(req) if err != nil { return nil, err } httpResponse, err := request.client.httpClient.Do(req) if err != nil { return nil, err } // This is safe to use for every request: we check for it being http.StatusAccepted httpResponse, err = request.pollForAsynchronousResponse(httpResponse) if err != nil { return nil, err } var responseObj interface{} var errorObj *Error if isSuccessCode(httpResponse.StatusCode) { responseObj = apiInfo.ResponseTypeFunc() // The response factory func returns nil as a signal that there is no body if responseObj != nil { responseMap, err := unmarshalFlattenPropertiesAndClose(httpResponse) if err != nil { return nil, err } err = mapstructure.WeakDecode(responseMap, responseObj) if err != nil { return nil, err } } } else { responseMap, err := unmarshalFlattenErrorAndClose(httpResponse) err = mapstructure.WeakDecode(responseMap, &errorObj) if err != nil { return nil, err } errorObj.StatusCode = httpResponse.StatusCode } return &Response{ HTTP: httpResponse, Parsed: responseObj, Error: errorObj, }, nil }