Commit Diff


commit - 707b0c4f5fd7c7acd55fa9e93b8f7344e3e83276
commit + ef08e53ee37488a036de6aacf1895d843ab065a0
blob - fba1dbec971587b8a93307ceddf27dc66ebba28d
blob + bf9410a1356f23098df574452efc9126020f4ea7
--- internal/sip/sip.go
+++ internal/sip/sip.go
@@ -38,13 +38,39 @@ type Request struct {
 
 const magicViaCookie = "z9hG4bK"
 
+// Transport protocols which may be set in Via.Transport.
+// TransportUDP is the zero value.
+const (
+	TransportUDP int = iota
+	TransportTCP
+)
+
+// Via represents the Via field in the header of requests.
 type Via struct {
+	// Transport indicates whether TCP or UDP should be used in
+	// subsequent transactions: TransportUDP or TransportTCP.
+	Transport int
+	// Address is a hostname or IP address to which responses
+	// should be sent.
 	Address string
-	Branch  string
+	// Branch uniquely identifies transactions from a particular user-agent.
+	// String prepends the magic cookie "z9hG4bK", so Branch should omit it.
+	Branch string
 }
 
+// String returns v formatted as the value of a Via header field.
+// For example, Via{Address: "example.com", Branch: "1234"} is formatted as
+// "SIP/2.0/UDP example.com;branch=z9hG4bK1234". Transports other than
+// TransportUDP and TransportTCP are written as "unknown".
 func (v Via) String() string {
-	return fmt.Sprintf("SIP/2.0/UDP %s;branch=%s%s", v.Address, magicViaCookie, v.Branch)
+	tport := "unknown"
+	switch v.Transport {
+	case TransportUDP:
+		tport = "UDP"
+	case TransportTCP:
+		tport = "TCP"
+	}
+	return fmt.Sprintf("SIP/2.0/%s %s;branch=%s%s", tport, v.Address, magicViaCookie, v.Branch)
 }
 
 func ReadRequest(r io.Reader) (*Request, error) {
@@ -83,6 +109,15 @@ func WriteRequest(w io.Writer, req *Request) (n int64,
 			return 0, fmt.Errorf("missing field %s in header", s)
 		}
 	}
+	if req.Via.Address == "" {
+		return 0, fmt.Errorf("empty address in via header field")
+	}
+	if req.Via.Branch == "" {
+		return 0, fmt.Errorf("empty branch in via header field")
+	}
+
+	// Set replaces any Via field the caller put in req.Header.
+	req.Header.Set("Via", req.Via.String())
 	if req.Header.Get("Max-Forwards") == "" {
 		// TODO(otl): find section in RFC recommending 70.
 		// section x.x.x
@@ -91,7 +126,6 @@ func WriteRequest(w io.Writer, req *Request) (n int64,
 	if req.ContentLength > 0 {
 		req.Header.Set("Content-Length", strconv.Itoa(int(req.ContentLength)))
 	}
-	req.Header.Set("Via", req.Via.String())
 
 	buf := &bytes.Buffer{}
 	fmt.Fprintf(buf, "%s %s SIP/2.0\r\n", req.Method, req.URI)
blob - ec1f4354adb5e24810959469771fb2258027319f
blob + 6e2d88dd488c1be65fc6cdaeb96bfd4647f2611d
--- internal/sip/sip_test.go
+++ internal/sip/sip_test.go
@@ -1,13 +1,51 @@
 package sip
 
 import (
-	"fmt"
+	"io"
+	"net/textproto"
 	"os"
 	"strings"
 	"testing"
 )
 
-func TestRequest(t *testing.T) {
+// TestWriteRequest checks that WriteRequest rejects requests whose Via
+// field is missing an address or a branch.
+func TestWriteRequest(t *testing.T) {
+	tests := []struct {
+		name string
+		via  Via
+	}{
+		{"zero via", Via{}},
+		{"no address", Via{Branch: "blabla"}},
+		{"no branch", Via{Address: "example.com"}},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Populate the other header fields so that any error
+			// should come from the invalid Via field.
+			header := make(textproto.MIMEHeader)
+			header.Set("Call-ID", "blabla")
+			header.Set("To", "test <sip:test@example.com>")
+			header.Set("From", "Oliver <sip:o@olowe.co>")
+			header.Set("CSeq", "1 "+MethodRegister)
+			req := &Request{
+				Method: MethodRegister,
+				URI:    "sip:test@example.com",
+				Header: header,
+				Via:    tt.via,
+			}
+			_, err := WriteRequest(io.Discard, req)
+			if err == nil {
+				t.Fatalf("no error writing request with Via %+v", tt.via)
+			}
+			if !strings.Contains(err.Error(), "via header field") {
+				t.Errorf("unexpected error for invalid Via: %v", err)
+			}
+		})
+	}
+}
+
+func TestReadRequest(t *testing.T) {
 	f, err := os.Open("testdata/invite")
 	if err != nil {
 		t.Fatal(err)
@@ -40,9 +78,8 @@ Content-Length: 131
 	if err != nil {
 		t.Fatal("read message:", err)
 	}
-	resp, err := parseResponse(msg)
+	_, err = parseResponse(msg)
 	if err != nil {
 		t.Fatalf("parse response: %v", err)
 	}
-	fmt.Printf("%+v\n", resp)
 }