package cos import ( "crypto/hmac" "crypto/sha1" "fmt" "hash" "net/http" "net/url" "sort" "strings" "sync" "time" ) const sha1SignAlgorithm = "sha1" const privateHeaderPrefix = "x-cos-" const defaultAuthExpire = time.Hour // 需要校验的 Headers 列表 var needSignHeaders = map[string]bool{ "host": true, "range": true, "x-cos-acl": true, "x-cos-grant-read": true, "x-cos-grant-write": true, "x-cos-grant-full-control": true, "response-content-type": true, "response-content-language": true, "response-expires": true, "response-cache-control": true, "response-content-disposition": true, "response-content-encoding": true, "cache-control": true, "content-disposition": true, "content-encoding": true, "content-type": true, "content-length": true, "content-md5": true, "expect": true, "expires": true, "x-cos-content-sha1": true, "x-cos-storage-class": true, "if-modified-since": true, "origin": true, "access-control-request-method": true, "access-control-request-headers": true, "x-cos-object-type": true, } func safeURLEncode(s string) string { s = encodeURIComponent(s) s = strings.Replace(s, "!", "%21", -1) s = strings.Replace(s, "'", "%27", -1) s = strings.Replace(s, "(", "%28", -1) s = strings.Replace(s, ")", "%29", -1) s = strings.Replace(s, "*", "%2A", -1) return s } type valuesSignMap map[string][]string func (vs valuesSignMap) Add(key, value string) { key = strings.ToLower(key) vs[key] = append(vs[key], value) } func (vs valuesSignMap) Encode() string { var keys []string for k := range vs { keys = append(keys, k) } sort.Strings(keys) var pairs []string for _, k := range keys { items := vs[k] sort.Strings(items) for _, val := range items { pairs = append( pairs, fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val))) } } return strings.Join(pairs, "&") } // AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数 type AuthTime struct { SignStartTime time.Time SignEndTime time.Time KeyStartTime time.Time KeyEndTime time.Time } // NewAuthTime 生成 AuthTime 的便捷函数 // // expire: 从现在开始多久过期. func NewAuthTime(expire time.Duration) *AuthTime { signStartTime := time.Now() keyStartTime := signStartTime signEndTime := signStartTime.Add(expire) keyEndTime := signEndTime return &AuthTime{ SignStartTime: signStartTime, SignEndTime: signEndTime, KeyStartTime: keyStartTime, KeyEndTime: keyEndTime, } } // signString return q-sign-time string func (a *AuthTime) signString() string { return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix()) } // keyString return q-key-time string func (a *AuthTime) keyString() string { return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix()) } // newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串 func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string { signTime := authTime.signString() keyTime := authTime.keyString() signKey := calSignKey(secretKey, keyTime) formatHeaders := *new(string) signedHeaderList := *new([]string) formatHeaders, signedHeaderList = genFormatHeaders(req.Header) formatParameters, signedParameterList := genFormatParameters(req.URL.Query()) formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders) stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString) signature := calSignature(signKey, stringToSign) return genAuthorization( secretID, signTime, keyTime, signature, signedHeaderList, signedParameterList, ) } // AddAuthorizationHeader 给 req 增加签名信息 func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) { if secretID == "" { return } auth := newAuthorization(secretID, secretKey, req, authTime, ) if len(sessionToken) > 0 { req.Header.Set("x-cos-security-token", sessionToken) } req.Header.Set("Authorization", auth) } // calSignKey 计算 SignKey func calSignKey(secretKey, keyTime string) string { digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm) return fmt.Sprintf("%x", digest) } // calStringToSign 计算 StringToSign func calStringToSign(signAlgorithm, signTime, formatString string) string { h := sha1.New() h.Write([]byte(formatString)) return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil)) } // calSignature 计算 Signature func calSignature(signKey, stringToSign string) string { digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm) return fmt.Sprintf("%x", digest) } // genAuthorization 生成 Authorization func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string { return strings.Join([]string{ "q-sign-algorithm=" + sha1SignAlgorithm, "q-ak=" + secretID, "q-sign-time=" + signTime, "q-key-time=" + keyTime, "q-header-list=" + strings.Join(signedHeaderList, ";"), "q-url-param-list=" + strings.Join(signedParameterList, ";"), "q-signature=" + signature, }, "&") } // genFormatString 生成 FormatString func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string { formatMethod := strings.ToLower(method) formatURI := uri.Path return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI, formatParameters, formatHeaders, ) } // genFormatParameters 生成 FormatParameters 和 SignedParameterList // instead of the url.Values{} func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) { ps := valuesSignMap{} for key, values := range parameters { key = strings.ToLower(key) for _, value := range values { ps.Add(key, value) signedParameterList = append(signedParameterList, key) } } //formatParameters = strings.ToLower(ps.Encode()) formatParameters = ps.Encode() sort.Strings(signedParameterList) return } // genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) { hs := valuesSignMap{} for key, values := range headers { key = strings.ToLower(key) for _, value := range values { if isSignHeader(key) { hs.Add(key, value) signedHeaderList = append(signedHeaderList, key) } } } formatHeaders = hs.Encode() sort.Strings(signedHeaderList) return } // HMAC 签名 func calHMACDigest(key, msg, signMethod string) []byte { var hashFunc func() hash.Hash switch signMethod { case "sha1": hashFunc = sha1.New default: hashFunc = sha1.New } h := hmac.New(hashFunc, []byte(key)) h.Write([]byte(msg)) return h.Sum(nil) } func isSignHeader(key string) bool { for k, v := range needSignHeaders { if key == k && v { return true } } return strings.HasPrefix(key, privateHeaderPrefix) } // AuthorizationTransport 给请求增加 Authorization header type AuthorizationTransport struct { SecretID string SecretKey string SessionToken string rwLocker sync.RWMutex // 签名多久过期 Expire time.Duration Transport http.RoundTripper } // SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken func (t *AuthorizationTransport) SetCredential(ak, sk, token string) { t.rwLocker.Lock() defer t.rwLocker.Unlock() t.SecretID = ak t.SecretKey = sk t.SessionToken = token } // GetCredential get the ak, sk, token func (t *AuthorizationTransport) GetCredential() (string, string, string) { t.rwLocker.RLock() defer t.rwLocker.RUnlock() return t.SecretID, t.SecretKey, t.SessionToken } // RoundTrip implements the RoundTripper interface. func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = cloneRequest(req) // per RoundTrip contract if t.Expire == time.Duration(0) { t.Expire = defaultAuthExpire } ak, sk, token := t.GetCredential() // 增加 Authorization header authTime := NewAuthTime(t.Expire) AddAuthorizationHeader(ak, sk, token, req, authTime) resp, err := t.transport().RoundTrip(req) return resp, err } func (t *AuthorizationTransport) transport() http.RoundTripper { if t.Transport != nil { return t.Transport } return http.DefaultTransport }