commit 7e5b66014d8e51608b3c909ab643e1f91f12b8db from: Oliver Lowe date: Thu Dec 16 12:29:11 2021 UTC More tests commit - 20fd186cb05aed778db0ee7d84c3e413e3ad68fe commit + 7e5b66014d8e51608b3c909ab643e1f91f12b8db blob - b00a41244c0c16699f6f406da8c7ef89533d306d blob + da47d0f5ade635b9e5992f18e81ff88699f52c1c --- dns_test.go +++ dns_test.go @@ -1,164 +1,65 @@ package dns import ( - "errors" - "fmt" - "io" "math/rand" "net" "testing" - "time" "golang.org/x/net/dns/dnsmessage" ) -type fakeDNSConn struct { - net.Conn - server fakeDNSServer - buf []byte - tcp bool -} +var testq = dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} -type fakeDNSPacketConn struct { - net.PacketConn - fakeDNSConn +func resolveBadly(w ResponseWriter, qmsg *dnsmessage.Message) { + rmsg := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: qmsg.Header.ID + 69, + Response: false, + RCode: dnsmessage.RCodeNameError, + }, + Questions: qmsg.Questions, + } + w.WriteMsg(rmsg) } -func (f *fakeDNSPacketConn) Close() error { - return nil -} - -type fakeDNSServer struct { - resolve func(q dnsmessage.Message) (dnsmessage.Message, error) -} - -func resolveWell(q dnsmessage.Message) (dnsmessage.Message, error) { - return dnsmessage.Message{ +func resolveWrongQuestion(w ResponseWriter, qmsg *dnsmessage.Message) { + wrongq := dnsmessage.Question{Name: dnsmessage.MustNewName("blabla.example.org."), Type: dnsmessage.TypeNS, Class: dnsmessage.ClassCHAOS} + rmsg := dnsmessage.Message{ Header: dnsmessage.Header{ - ID: q.Header.ID, + ID: qmsg.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, + Authoritative: true, }, - Questions: q.Questions, - Answers: []dnsmessage.Resource{ - { - Header: dnsmessage.ResourceHeader{ - Name: q.Questions[0].Name, - Type: q.Questions[0].Type, - Class: q.Questions[0].Class, - }, - Body: &dnsmessage.AResource{ - A: [4]byte{0xc0, 0x00, 0x02, 0x01}, - }, - }, - }, - }, nil -} - -var errCrashed = errors.New("crashed") - -func resolveBadly(q dnsmessage.Message) (dnsmessage.Message, error) { - return dnsmessage.Message{}, errCrashed -} - -func (f fakeDNSConn) Close() error { - return nil -} - -func (f *fakeDNSConn) Write(b []byte) (int, error) { - time.Sleep(50 * time.Millisecond) - if len(f.buf) > 0 { - return 0, fmt.Errorf("connection buffer full, refusing overwrite") + Questions: []dnsmessage.Question{wrongq}, } - var qmsg dnsmessage.Message - if f.tcp { - if err := qmsg.Unpack(b[2:]); err != nil { - return len(b), err - } - } else { - if err := qmsg.Unpack(b); err != nil { - return len(b), err - } - } - rmsg, err := f.server.resolve(qmsg) - if err != nil { - return len(b), err - } - packed, err := rmsg.Pack() - if err != nil { - return len(b), err - } - if f.tcp { - l := len(packed) - buf := make([]byte, 2+len(packed)) - buf[0] = byte(l >> 8) - buf[1] = byte(l) - copy(buf[2:], packed) - f.buf = buf - return len(b), nil - } - f.buf = packed - return len(b), err + w.WriteMsg(rmsg) } -func (f *fakeDNSConn) Read(b []byte) (int, error) { - if len(f.buf) > 0 { - n := copy(b, f.buf) - f.buf = f.buf[n:] - return n, nil - } - return 0, io.EOF -} - -func TestGoodConn(t *testing.T) { - qmsg, err := buildmsg("www.example.com.") +func TestBadResolver(t *testing.T) { + srv := Server{network: "udp", addr: "127.0.0.1", Handler: resolveBadly} + conn, err := net.ListenPacket("udp", "127.0.0.1:5359") if err != nil { t.Fatal(err) } - var goodconn fakeDNSConn - goodconn.server.resolve = resolveWell - goodconn.tcp = true - _, err = exchange(qmsg, &goodconn) - if err != nil { - t.Error(err) + go func() { + t.Fatal(srv.ServePacket(conn)) + }() + rmsg, err := Ask(testq, "127.0.0.1:5359") + if err == nil { + t.Error("wanted error, got nil") } -} + t.Log(err) + t.Log("sent:", testq, "received", rmsg) -func TestShitConn(t *testing.T) { - qmsg, err := buildmsg("www.example.com.") - if err != nil { - t.Fatal(err) - } - var shitconn fakeDNSConn - shitconn.server.resolve = resolveBadly - shitconn.tcp = true - _, err = exchange(qmsg, &shitconn) - if !errors.Is(err, errCrashed) { - t.Errorf("wanted error %v, got %v", errCrashed, err) - } -} - -func TestBadMessage(t *testing.T) { - q, err := buildmsg("www.example.com.") - if err != nil { - t.Fatal(err) - } - var shitconn fakeDNSPacketConn - shitconn.server.resolve = func(q dnsmessage.Message) (dnsmessage.Message, error) { - return dnsmessage.Message{ - Header: dnsmessage.Header{ - ID: q.Header.ID + 69, - Response: false, - RCode: dnsmessage.RCodeNameError, - }, - Questions: q.Questions, - }, nil - } - r, err := exchange(q, &shitconn) - if !errors.Is(err, errMismatchedID) { + srv.Handler = resolveWrongQuestion + rmsg, err = Ask(testq, "127.0.0.1:5359") + if err == nil { + t.Error("wanted error, got nil") + } else if err != nil { t.Log(err) - t.Errorf("should error on receiving mismatched message IDs; sent %d, received %d", q.Header.ID, r.Header.ID) } + t.Log("sent:", testq, "received:", rmsg) } func buildmsg(s string) (dnsmessage.Message, error) { blob - 6d63e59a3f88930b69912f9553f3149377ed7bcc blob + a505e6c00f2e448077ec690c4b7b13318e4ec389 --- server_test.go +++ server_test.go @@ -1,7 +1,10 @@ package dns import ( - "golang.org/x/net/dns/dnsmessage" + "crypto/rand" + "io" + "net" + "time" "testing" ) @@ -9,8 +12,8 @@ func TestServer(t *testing.T) { go func() { t.Fatal(ListenAndServe("udp", "127.0.0.1:51111", nil)) }() - q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} - rmsg, err := Ask(q, "127.0.0.1:51111") + time.Sleep(time.Millisecond) + rmsg, err := Ask(testq, "127.0.0.1:51111") if err != nil { t.Errorf("exchange: %v", err) } @@ -21,8 +24,8 @@ func TestStreamServer(t *testing.T) { go func() { t.Fatal(ListenAndServe("tcp", "127.0.0.1:51112", nil)) }() - q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} - rmsg, err := AskTCP(q, "127.0.0.1:51112") + time.Sleep(time.Millisecond) + rmsg, err := AskTCP(testq, "127.0.0.1:51112") if err != nil { t.Errorf("exchange: %v", err) } @@ -35,10 +38,52 @@ func TestEmptyServer(t *testing.T) { t.Fatal(srv.ListenAndServe()) t.Log(srv.addr) }() - q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} - rmsg, err := Ask(q, "127.0.0.1:domain") + rmsg, err := Ask(testq, "127.0.0.1:domain") if err != nil { t.Errorf("exchange: %v", err) } t.Log("response:", rmsg) } + +func TestJunk(t *testing.T) { + addr := "127.0.0.1:5361" + go func() { + t.Fatal(ListenAndServe("tcp", addr, nil)) + }() + time.Sleep(time.Millisecond) + for i := 0; i <= 30; i++ { + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + if _, err := io.CopyN(conn, rand.Reader, 8192); err != nil { + t.Fatal(err) + } + } +} + +func BenchmarkPacketVsStream(b *testing.B) { + addr := "127.0.0.1:51113" + var networks = []string{"udp", "tcp"} + for _, net := range networks { + go func(){ + b.Fatal(ListenAndServe(net, addr, nil)) + }() + b.Run(net, func(b *testing.B) { + for i := 0; i<= b.N; i++ { + if net == "udp" { + if rmsg, err := Ask(testq, addr); err != nil { + b.Log(rmsg) + b.Fatal(err) + } + } else { + if rmsg, err := AskTCP(testq, addr); err != nil { + b.Log(rmsg) + b.Fatal(err) + } + } + } + }) + } +}