Blob


1 package apub
3 import (
4 "bytes"
5 "crypto"
6 "crypto/rand"
7 "crypto/rsa"
8 "crypto/sha256"
9 "encoding/base64"
10 "fmt"
11 "io"
12 "net/http"
13 "strings"
14 "time"
15 )
// requiredSigHeaders is the space-separated list of signed header names
// "(request-target) host date digest".
//
// NOTE(review): Sign does not reference this constant and omits "digest" for
// requests without a body; presumably it is used by verification code
// elsewhere in the package — confirm how bodiless requests are handled there.
const (
	requiredSigHeaders = "(request-target) host date digest"
)
// Sign signs req in place with key using the "rsa-sha256" HTTP Signatures
// scheme, naming pubkeyURL as the keyId so the receiver can fetch the matching
// public key.
//
// Sign sets the Date header and signs "(request-target) host date". When
// req.Body is non-nil it also reads the whole body into memory, sets the
// Digest header to "SHA-256=<base64 digest>" and signs "digest" as well; the
// body is then replaced with an in-memory copy (GetBody and ContentLength are
// updated to match) so the request can still be sent and replayed. Finally it
// sets the Signature header.
//
// It returns an error if pubkeyURL is empty, key is nil, the body cannot be
// read, or RSA signing fails.
func Sign(req *http.Request, key *rsa.PrivateKey, pubkeyURL string) error {
	if pubkeyURL == "" {
		return fmt.Errorf("no pubkey url")
	}
	if key == nil {
		return fmt.Errorf("no private key")
	}

	date := time.Now().UTC().Format(http.TimeFormat)
	req.Header.Set("Date", date)

	// Sign exactly what net/http puts on the wire: the escaped path plus query
	// string as the request target (RequestURI yields "/" for an empty path),
	// and req.Host — falling back to req.URL.Host, port included — as the Host
	// header. Signing URL.Path / URL.Hostname() made verification fail for
	// escaped paths, query strings and non-default ports.
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	toSign := []string{"(request-target)", "host", "date"}
	lines := []string{
		"(request-target): " + strings.ToLower(req.Method) + " " + req.URL.RequestURI(),
		"host: " + host,
		"date: " + date,
	}

	if req.Body != nil {
		body, err := io.ReadAll(req.Body)
		req.Body.Close()
		if err != nil {
			// Do not sign (and later send) a truncated body.
			return fmt.Errorf("reading request body: %w", err)
		}
		req.Body = io.NopCloser(bytes.NewReader(body))
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(body)), nil
		}
		req.ContentLength = int64(len(body))

		sum := sha256.Sum256(body)
		digest := "SHA-256=" + base64.StdEncoding.EncodeToString(sum[:])
		req.Header.Set("Digest", digest)
		toSign = append(toSign, "digest")
		lines = append(lines, "digest: "+digest)
	}

	// The signing string is the "name: value" lines joined by "\n" with no
	// trailing newline.
	hashed := sha256.Sum256([]byte(strings.Join(lines, "\n")))
	sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hashed[:])
	if err != nil {
		return fmt.Errorf("signing request: %w", err)
	}
	bsig := base64.StdEncoding.EncodeToString(sig)

	val := fmt.Sprintf("keyId=%q,algorithm=%q,headers=%q,signature=%q", pubkeyURL, "rsa-sha256", strings.Join(toSign, " "), bsig)
	req.Header.Set("Signature", val)
	return nil
}
// signature holds the four parameters of an HTTP Signature header as text;
// nothing is decoded or validated beyond being present.
type signature struct {
	keyID     string // value of the keyId parameter
	algorithm string // value of the algorithm parameter, e.g. "rsa-sha256"
	headers   string // signed header names, space-separated as written by Sign
	signature string // signature value as sent (Sign emits standard base64); not decoded here
}

// parseSignatureHeader parses the value of a Signature header of the form
//
//	keyId="...",algorithm="...",headers="...",signature="..."
//
// Parameters are separated by commas that are not inside double quotes, so a
// quoted value (for instance a keyId URL) may itself contain commas.
// Whitespace around each parameter name and value is ignored, and surrounding
// double quotes are stripped from values. Unknown parameter names are
// rejected, and keyId, algorithm, headers and signature must all be present
// and non-empty. On error the returned signature is the zero value.
func parseSignatureHeader(line string) (signature, error) {
	var sig signature
	for _, field := range splitSignatureParams(line) {
		name, val, ok := strings.Cut(field, "=")
		if !ok {
			return signature{}, fmt.Errorf("bad field: %s from %s", field, line)
		}
		name = strings.TrimSpace(name)
		val = strings.Trim(strings.TrimSpace(val), `"`)
		switch name {
		case "keyId":
			sig.keyID = val
		case "algorithm":
			sig.algorithm = val
		case "headers":
			sig.headers = val
		case "signature":
			sig.signature = val
		default:
			return signature{}, fmt.Errorf("bad field name %s", name)
		}
	}

	switch {
	case sig.keyID == "":
		return signature{}, fmt.Errorf("missing signature field keyId")
	case sig.algorithm == "":
		return signature{}, fmt.Errorf("missing signature field algorithm")
	case sig.headers == "":
		return signature{}, fmt.Errorf("missing signature field headers")
	case sig.signature == "":
		return signature{}, fmt.Errorf("missing signature field signature")
	}
	return sig, nil
}

// splitSignatureParams splits line on commas that are outside double-quoted
// strings. It always returns at least one element (an empty line yields [""]).
func splitSignatureParams(line string) []string {
	var parts []string
	inQuote := false
	start := 0
	// Byte-wise scan is safe: '"' and ',' are ASCII and never occur inside a
	// multi-byte UTF-8 sequence.
	for i := 0; i < len(line); i++ {
		switch line[i] {
		case '"':
			inQuote = !inQuote
		case ',':
			if !inQuote {
				parts = append(parts, line[start:i])
				start = i + 1
			}
		}
	}
	return append(parts, line[start:])
}