diff --git a/state/remote/http.go b/state/remote/http.go index 4cbf6816a..86f9a4e5b 100644 --- a/state/remote/http.go +++ b/state/remote/http.go @@ -22,13 +22,53 @@ func httpFactory(conf map[string]string) (Client, error) { return nil, fmt.Errorf("missing 'address' configuration") } - url, err := url.Parse(address) + storeURL, err := url.Parse(address) if err != nil { - return nil, fmt.Errorf("failed to parse HTTP URL: %s", err) + return nil, fmt.Errorf("failed to parse address URL: %s", err) } - if url.Scheme != "http" && url.Scheme != "https" { + if storeURL.Scheme != "http" && storeURL.Scheme != "https" { return nil, fmt.Errorf("address must be HTTP or HTTPS") } + storeMethod, ok := conf["store_method"] + if !ok { + storeMethod = "POST" + } + + var lockURL *url.URL + lockAddress, ok := conf["lock_address"] + if ok { + lockURL, err := url.Parse(lockAddress) + if err != nil { + return nil, fmt.Errorf("failed to parse lockAddress URL: %s", err) + } + if lockURL.Scheme != "http" && lockURL.Scheme != "https" { + return nil, fmt.Errorf("lockAddress must be HTTP or HTTPS") + } + } else { + lockURL = nil + } + lockMethod, ok := conf["lock_method"] + if !ok { + lockMethod = "LOCK" + } + + var unlockURL *url.URL + unlockAddress, ok := conf["unlock_address"] + if ok { + unlockURL, err := url.Parse(unlockAddress) + if err != nil { + return nil, fmt.Errorf("failed to parse unlockAddress URL: %s", err) + } + if unlockURL.Scheme != "http" && unlockURL.Scheme != "https" { + return nil, fmt.Errorf("unlockAddress must be HTTP or HTTPS") + } + } else { + unlockURL = nil + } + unlockMethod, ok := conf["unlock_method"] + if !ok { + unlockMethod = "UNLOCK" + } client := &http.Client{} if skipRaw, ok := conf["skip_cert_verification"]; ok { @@ -48,58 +88,72 @@ func httpFactory(conf map[string]string) (Client, error) { } } - supportsLocking := false - if supportsLockingRaw, ok := conf["supports_locking"]; ok { - var err error - supportsLocking, err = strconv.ParseBool(supportsLockingRaw) - if err != nil { - return nil, fmt.Errorf("supports_locking must be boolean") - } + ret := &HTTPClient{ + URL: storeURL, + StoreMethod: storeMethod, + + LockURL: lockURL, + LockMethod: lockMethod, + UnlockURL: unlockURL, + UnlockMethod: unlockMethod, + + Client: client, + Username: conf["username"], + Password: conf["password"], } - ret := &HTTPClient{ - URL: url, - Client: client, - SupportsLocking: supportsLocking, - } - if username, ok := conf["username"]; ok && username != "" { - ret.Username = username - } - if password, ok := conf["password"]; ok && password != "" { - ret.Password = password - } return ret, nil } // HTTPClient is a remote client that stores data in Consul or HTTP REST. type HTTPClient struct { - URL *url.URL - Client *http.Client - Username string - Password string - SupportsLocking bool - lockID string + // Store & Retrieve + URL *url.URL + StoreMethod string + + // Locking + LockURL *url.URL + LockMethod string + UnlockURL *url.URL + UnlockMethod string + + // HTTP + Client *http.Client + Username string + Password string + + lockID string + jsonLockInfo []byte } -func (c *HTTPClient) httpPost(url string, data []byte, what string) (*http.Response, error) { - - // Generate the MD5 - hash := md5.Sum(data) - b64 := base64.StdEncoding.EncodeToString(hash[:]) - - req, err := http.NewRequest("POST", url, bytes.NewReader(data)) - if err != nil { - return nil, fmt.Errorf("Failed to make HTTP request: %s", err) +func (c *HTTPClient) httpRequest(method string, url *url.URL, data *[]byte, what string) (*http.Response, error) { + // If we have data we need a reader + var reader io.Reader = nil + if data != nil { + reader = bytes.NewReader(*data) } - // Prepare the request - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Content-MD5", b64) - req.ContentLength = int64(len(data)) + // Create the request + req, err := http.NewRequest(method, url.String(), reader) + if err != nil { + return nil, fmt.Errorf("Failed to make %s HTTP request: %s", what, err) + } + // Setup basic auth if c.Username != "" { req.SetBasicAuth(c.Username, c.Password) } + // Work with data/body + if data != nil { + req.Header.Set("Content-Type", "application/json") + req.ContentLength = int64(len(*data)) + + // Generate the MD5 + hash := md5.Sum(*data) + b64 := base64.StdEncoding.EncodeToString(hash[:]) + req.Header.Set("Content-MD5", b64) + } + // Make the request resp, err := c.Client.Do(req) if err != nil { @@ -110,20 +164,13 @@ func (c *HTTPClient) httpPost(url string, data []byte, what string) (*http.Respo } func (c *HTTPClient) Lock(info *state.LockInfo) (string, error) { - if !c.SupportsLocking { + if c.LockURL == nil { return "", nil } c.lockID = "" - url := *c.URL - path := url.Path - if len(path) == 0 || path[len(path)-1] != byte('/') { - // add a trailing / - path = fmt.Sprintf("%s/", path) - } - url.Path = fmt.Sprintf("%slock", path) - - resp, err := c.httpPost(url.String(), info.Marshal(), "lock") + jsonLockInfo := info.Marshal() + resp, err := c.httpRequest(c.LockMethod, c.LockURL, &jsonLockInfo, "lock") if err != nil { return "", err } @@ -132,6 +179,7 @@ func (c *HTTPClient) Lock(info *state.LockInfo) (string, error) { switch resp.StatusCode { case http.StatusOK: c.lockID = info.ID + c.jsonLockInfo = jsonLockInfo return info.ID, nil case http.StatusUnauthorized: return "", fmt.Errorf("HTTP remote state endpoint requires auth") @@ -155,26 +203,11 @@ func (c *HTTPClient) Lock(info *state.LockInfo) (string, error) { } func (c *HTTPClient) Unlock(id string) error { - if !c.SupportsLocking { + if c.UnlockURL == nil { return nil } - // copy the target URL - url := *c.URL - path := url.Path - if len(path) == 0 || path[len(path)-1] != byte('/') { - // add a trailing / - path = fmt.Sprintf("%s/", path) - } - url.Path = fmt.Sprintf("%sunlock", path) - - if c.SupportsLocking { - query := url.Query() - query.Set("ID", id) - url.RawQuery = query.Encode() - } - - resp, err := c.httpPost(url.String(), []byte{}, "unlock") + resp, err := c.httpRequest(c.UnlockMethod, c.UnlockURL, &c.jsonLockInfo, "unlock") if err != nil { return err } @@ -189,18 +222,7 @@ func (c *HTTPClient) Unlock(id string) error { } func (c *HTTPClient) Get() (*Payload, error) { - req, err := http.NewRequest("GET", c.URL.String(), nil) - if err != nil { - return nil, err - } - - // Prepare the request - if c.Username != "" { - req.SetBasicAuth(c.Username, c.Password) - } - - // Make the request - resp, err := c.Client.Do(req) + resp, err := c.httpRequest("GET", c.URL, nil, "get state") if err != nil { return nil, err } @@ -262,7 +284,7 @@ func (c *HTTPClient) Put(data []byte) error { // Copy the target URL base := *c.URL - if c.SupportsLocking { + if c.lockID != "" { query := base.Query() query.Set("ID", c.lockID) base.RawQuery = query.Encode() @@ -277,7 +299,11 @@ func (c *HTTPClient) Put(data []byte) error { } */ - resp, err := c.httpPost(base.String(), data, "upload state") + var method string = "POST" + if c.StoreMethod != "" { + method = c.StoreMethod + } + resp, err := c.httpRequest(method, &base, &data, "upload state") if err != nil { return err } @@ -293,20 +319,10 @@ func (c *HTTPClient) Put(data []byte) error { } func (c *HTTPClient) Delete() error { - req, err := http.NewRequest("DELETE", c.URL.String(), nil) - if err != nil { - return fmt.Errorf("Failed to make HTTP request: %s", err) - } - - // Prepare the request - if c.Username != "" { - req.SetBasicAuth(c.Username, c.Password) - } - // Make the request - resp, err := c.Client.Do(req) + resp, err := c.httpRequest("DELETE", c.URL, nil, "delete state") if err != nil { - return fmt.Errorf("Failed to delete state: %s", err) + return err } defer resp.Body.Close() diff --git a/state/remote/http_test.go b/state/remote/http_test.go index 40765b8ee..bed0c5937 100644 --- a/state/remote/http_test.go +++ b/state/remote/http_test.go @@ -30,8 +30,22 @@ func TestHTTPClient(t *testing.T) { client := &HTTPClient{URL: url, Client: cleanhttp.DefaultClient()} testClient(t, client) - a := &HTTPClient{URL: url, Client: cleanhttp.DefaultClient(), SupportsLocking: true} - b := &HTTPClient{URL: url, Client: cleanhttp.DefaultClient(), SupportsLocking: true} + a := &HTTPClient{ + URL: url, + LockURL: url, + LockMethod: "LOCK", + UnlockURL: url, + UnlockMethod: "UNLOCK", + Client: cleanhttp.DefaultClient(), + } + b := &HTTPClient{ + URL: url, + LockURL: url, + LockMethod: "LOCK", + UnlockURL: url, + UnlockMethod: "UNLOCK", + Client: cleanhttp.DefaultClient(), + } TestRemoteLocks(t, a, b) } @@ -45,25 +59,20 @@ func (h *testHTTPHandler) Handle(w http.ResponseWriter, r *http.Request) { case "GET": w.Write(h.Data) case "POST": - switch r.URL.Path { - case "/": - // state - buf := new(bytes.Buffer) - if _, err := io.Copy(buf, r.Body); err != nil { - w.WriteHeader(500) - } - - h.Data = buf.Bytes() - case "/lock": - if h.Locked { - w.WriteHeader(409) - } else { - h.Locked = true - } - case "/unlock": - h.Locked = false + buf := new(bytes.Buffer) + if _, err := io.Copy(buf, r.Body); err != nil { + w.WriteHeader(500) } + h.Data = buf.Bytes() + case "LOCK": + if h.Locked { + w.WriteHeader(409) + } else { + h.Locked = true + } + case "UNLOCK": + h.Locked = false case "DELETE": h.Data = nil w.WriteHeader(200)