commit ef08e53ee37488a036de6aacf1895d843ab065a0 from: Oliver Lowe date: Sat Mar 22 06:16:56 2025 UTC sip: validate Via header field on writing requests commit - 707b0c4f5fd7c7acd55fa9e93b8f7344e3e83276 commit + ef08e53ee37488a036de6aacf1895d843ab065a0 blob - fba1dbec971587b8a93307ceddf27dc66ebba28d blob + bf9410a1356f23098df574452efc9126020f4ea7 --- internal/sip/sip.go +++ internal/sip/sip.go @@ -38,13 +38,32 @@ type Request struct { const magicViaCookie = "z9hG4bK" +const ( + TransportUDP int = iota + TransportTCP +) + +// Via represents the Via field in the header of requests. type Via struct { + // Transport indicates whether TCP or UDP should be used in + // subsequent transactions. + Transport int + // Address is a hostname or IP address to which responses + // should be sent. Address string - Branch string + // Branch uniquely identifies transactions from a particular user-agent. + Branch string } func (v Via) String() string { - return fmt.Sprintf("SIP/2.0/UDP %s;branch=%s%s", v.Address, magicViaCookie, v.Branch) + tport := "unknown" + switch v.Transport { + case TransportUDP: + tport = "UDP" + case TransportTCP: + tport = "TCP" + } + return fmt.Sprintf("SIP/2.0/%s %s;branch=%s%s", tport, v.Address, magicViaCookie, v.Branch) } func ReadRequest(r io.Reader) (*Request, error) { @@ -83,6 +102,13 @@ func WriteRequest(w io.Writer, req *Request) (n int64, return 0, fmt.Errorf("missing field %s in header", s) } } + if req.Via.Address == "" { + return 0, fmt.Errorf("empty address in via header field") + } else if req.Via.Branch == "" { + return 0, fmt.Errorf("empty branch in via header field") + } + + req.Header.Set("Via", req.Via.String()) if req.Header.Get("Max-Forwards") == "" { // TODO(otl): find section in RFC recommending 70. // section x.x.x @@ -91,7 +117,6 @@ func WriteRequest(w io.Writer, req *Request) (n int64, if req.ContentLength > 0 { req.Header.Set("Content-Length", strconv.Itoa(int(req.ContentLength))) } - req.Header.Set("Via", req.Via.String()) buf := &bytes.Buffer{} fmt.Fprintf(buf, "%s %s SIP/2.0\r\n", req.Method, req.URI) blob - ec1f4354adb5e24810959469771fb2258027319f blob + 6e2d88dd488c1be65fc6cdaeb96bfd4647f2611d --- internal/sip/sip_test.go +++ internal/sip/sip_test.go @@ -1,13 +1,31 @@ package sip import ( - "fmt" + "io" + "net/textproto" "os" "strings" "testing" ) -func TestRequest(t *testing.T) { +func TestWriteRequest(t *testing.T) { + header := make(textproto.MIMEHeader) + header.Set("Call-ID", "blabla") + header.Set("To", "test ") + header.Set("From", "Oliver ") + header.Set("CSeq", "1 "+MethodRegister) + req := &Request{ + Method: MethodRegister, + URI: "sip:test@example.com", + Header: header, + } + _, err := WriteRequest(io.Discard, req) + if err == nil { + t.Errorf("no error writing request with zero Via field") + } +} + +func TestReadRequest(t *testing.T) { f, err := os.Open("testdata/invite") if err != nil { t.Fatal(err) @@ -17,6 +35,7 @@ func TestRequest(t *testing.T) { if err != nil { t.Fatal("read request:", err) } + } func TestResponse(t *testing.T) { @@ -40,9 +59,8 @@ Content-Length: 131 if err != nil { t.Fatal("read message:", err) } - resp, err := parseResponse(msg) + _, err = parseResponse(msg) if err != nil { t.Fatalf("parse response: %v", err) } - fmt.Printf("%+v\n", resp) }