Commit Diff


commit - 20fd186cb05aed778db0ee7d84c3e413e3ad68fe
commit + 7e5b66014d8e51608b3c909ab643e1f91f12b8db
blob - b00a41244c0c16699f6f406da8c7ef89533d306d
blob + da47d0f5ade635b9e5992f18e81ff88699f52c1c
--- dns_test.go
+++ dns_test.go
@@ -1,164 +1,65 @@
 package dns
 
 import (
-	"errors"
-	"fmt"
-	"io"
 	"math/rand"
 	"net"
 	"testing"
-	"time"
 
 	"golang.org/x/net/dns/dnsmessage"
 )
 
-type fakeDNSConn struct {
-	net.Conn
-	server fakeDNSServer
-	buf    []byte
-	tcp    bool
-}
+// testq is the question shared by the tests in this package: an A record,
+// class IN, for www.example.com.
+var testq = dnsmessage.Question{
+	Name:  dnsmessage.MustNewName("www.example.com."),
+	Type:  dnsmessage.TypeA,
+	Class: dnsmessage.ClassINET,
+}
 
-type fakeDNSPacketConn struct {
-	net.PacketConn
-	fakeDNSConn
+// resolveBadly is a deliberately broken handler. It echoes the query's
+// questions but replies with a message ID that differs from the query's,
+// leaves the Response bit cleared and reports RCodeNameError.
+// TestBadResolver expects Ask to return an error for such a reply.
+func resolveBadly(w ResponseWriter, qmsg *dnsmessage.Message) {
+	rmsg := dnsmessage.Message{
+		Header: dnsmessage.Header{
+			// Offset the ID so it cannot match the query (uint16 wraps).
+			ID:       qmsg.Header.ID + 69,
+			Response: false,
+			RCode:    dnsmessage.RCodeNameError,
+		},
+		Questions: qmsg.Questions,
+	}
+	w.WriteMsg(rmsg)
 }
 
-func (f *fakeDNSPacketConn) Close() error {
-	return nil
-}
-
-type fakeDNSServer struct {
-	resolve func(q dnsmessage.Message) (dnsmessage.Message, error)
-}
-
-func resolveWell(q dnsmessage.Message) (dnsmessage.Message, error) {
-	return dnsmessage.Message{
+// resolveWrongQuestion replies with a header that otherwise looks valid
+// (matching ID, Response and Authoritative set, RCodeSuccess) but with a
+// question section that does not match the query. A client that checks the
+// echoed question should reject the reply.
+func resolveWrongQuestion(w ResponseWriter, qmsg *dnsmessage.Message) {
+	// Name, type and class all differ from testq (www.example.com. A IN).
+	wrongq := dnsmessage.Question{Name: dnsmessage.MustNewName("blabla.example.org."), Type: dnsmessage.TypeNS, Class: dnsmessage.ClassCHAOS}
+	rmsg := dnsmessage.Message{
 		Header: dnsmessage.Header{
-			ID:       q.Header.ID,
+			ID:       qmsg.Header.ID,
 			Response: true,
 			RCode:    dnsmessage.RCodeSuccess,
+			Authoritative: true,
 		},
-		Questions: q.Questions,
-		Answers: []dnsmessage.Resource{
-			{
-				Header: dnsmessage.ResourceHeader{
-					Name:  q.Questions[0].Name,
-					Type:  q.Questions[0].Type,
-					Class: q.Questions[0].Class,
-				},
-				Body: &dnsmessage.AResource{
-					A: [4]byte{0xc0, 0x00, 0x02, 0x01},
-				},
-			},
-		},
-	}, nil
-}
-
-var errCrashed = errors.New("crashed")
-
-func resolveBadly(q dnsmessage.Message) (dnsmessage.Message, error) {
-	return dnsmessage.Message{}, errCrashed
-}
-
-func (f fakeDNSConn) Close() error {
-	return nil
-}
-
-func (f *fakeDNSConn) Write(b []byte) (int, error) {
-	time.Sleep(50 * time.Millisecond)
-	if len(f.buf) > 0 {
-		return 0, fmt.Errorf("connection buffer full, refusing overwrite")
+		Questions: []dnsmessage.Question{wrongq},
 	}
-	var qmsg dnsmessage.Message
-	if f.tcp {
-		if err := qmsg.Unpack(b[2:]); err != nil {
-			return len(b), err
-		}
-	} else {
-		if err := qmsg.Unpack(b); err != nil {
-			return len(b), err
-		}
-	}
-	rmsg, err := f.server.resolve(qmsg)
-	if err != nil {
-		return len(b), err
-	}
-	packed, err := rmsg.Pack()
-	if err != nil {
-		return len(b), err
-	}
-	if f.tcp {
-		l := len(packed)
-		buf := make([]byte, 2+len(packed))
-		buf[0] = byte(l >> 8)
-		buf[1] = byte(l)
-		copy(buf[2:], packed)
-		f.buf = buf
-		return len(b), nil
-	}
-	f.buf = packed
-	return len(b), err
+	w.WriteMsg(rmsg)
 }
 
-func (f *fakeDNSConn) Read(b []byte) (int, error) {
-	if len(f.buf) > 0 {
-		n := copy(b, f.buf)
-		f.buf = f.buf[n:]
-		return n, nil
-	}
-	return 0, io.EOF
-}
-
-func TestGoodConn(t *testing.T) {
-	qmsg, err := buildmsg("www.example.com.")
-	if err != nil {
-		t.Fatal(err)
-	}
-	var goodconn fakeDNSConn
-	goodconn.server.resolve = resolveWell
-	goodconn.tcp = true
-	_, err = exchange(qmsg, &goodconn)
-	if err != nil {
-		t.Error(err)
-	}
-}
-
-func TestShitConn(t *testing.T) {
-	qmsg, err := buildmsg("www.example.com.")
-	if err != nil {
-		t.Fatal(err)
-	}
-	var shitconn fakeDNSConn
-	shitconn.server.resolve = resolveBadly
-	shitconn.tcp = true
-	_, err = exchange(qmsg, &shitconn)
-	if !errors.Is(err, errCrashed) {
-		t.Errorf("wanted error %v, got %v", errCrashed, err)
-	}
-}
-
-func TestBadMessage(t *testing.T) {
-	q, err := buildmsg("www.example.com.")
-	if err != nil {
-		t.Fatal(err)
-	}
-	var shitconn fakeDNSPacketConn
-	shitconn.server.resolve = func(q dnsmessage.Message) (dnsmessage.Message, error) {
-		return dnsmessage.Message{
-			Header: dnsmessage.Header{
-				ID:       q.Header.ID + 69,
-				Response: false,
-				RCode:    dnsmessage.RCodeNameError,
-			},
-			Questions: q.Questions,
-		}, nil
-	}
-	r, err := exchange(q, &shitconn)
-	if !errors.Is(err, errMismatchedID) {
-		t.Log(err)
-		t.Errorf("should error on receiving mismatched message IDs; sent %d, received %d", q.Header.ID, r.Header.ID)
-	}
+// TestBadResolver checks that Ask returns an error when the server replies
+// with a malformed or mismatched answer. Each case gets its own listener and
+// Server, so switching handlers never races with a server that is still running.
+func TestBadResolver(t *testing.T) {
+	tests := []struct {
+		name    string
+		handler func(ResponseWriter, *dnsmessage.Message)
+	}{
+		{"mismatched header", resolveBadly},
+		{"wrong question", resolveWrongQuestion},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Port 0 lets the kernel choose a free port, which avoids clashes
+			// with other tests or local services.
+			conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer conn.Close()
+			srv := Server{network: "udp", addr: "127.0.0.1", Handler: tt.handler}
+			// ServePacket presumably returns once conn is closed at the end of
+			// the subtest. Its error is ignored on purpose: t.Fatal must only
+			// be called from the goroutine running the test.
+			go srv.ServePacket(conn)
+			rmsg, err := Ask(testq, conn.LocalAddr().String())
+			if err == nil {
+				t.Error("wanted error, got nil")
+			} else {
+				t.Log(err)
+			}
+			t.Log("sent:", testq, "received:", rmsg)
+		})
+	}
 }
 
 func buildmsg(s string) (dnsmessage.Message, error) {
blob - 6d63e59a3f88930b69912f9553f3149377ed7bcc
blob + a505e6c00f2e448077ec690c4b7b13318e4ec389
--- server_test.go
+++ server_test.go
@@ -1,7 +1,10 @@
 package dns
 
 import (
-	"golang.org/x/net/dns/dnsmessage"
+	"crypto/rand"
+	"io"
+	"net"
+	"time"
 	"testing"
 )
 
@@ -9,8 +12,8 @@ func TestServer(t *testing.T) {
 	go func() {
 		t.Fatal(ListenAndServe("udp", "127.0.0.1:51111", nil))
 	}()
-	q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
-	rmsg, err := Ask(q, "127.0.0.1:51111")
+	time.Sleep(time.Millisecond)
+	rmsg, err := Ask(testq, "127.0.0.1:51111")
 	if err != nil {
 		t.Errorf("exchange: %v", err)
 	}
@@ -21,8 +24,8 @@ func TestStreamServer(t *testing.T) {
 	go func() {
 		t.Fatal(ListenAndServe("tcp", "127.0.0.1:51112", nil))
 	}()
-	q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
-	rmsg, err := AskTCP(q, "127.0.0.1:51112")
+	time.Sleep(time.Millisecond)
+	rmsg, err := AskTCP(testq, "127.0.0.1:51112")
 	if err != nil {
 		t.Errorf("exchange: %v", err)
 	}
@@ -35,10 +38,52 @@ func TestEmptyServer(t *testing.T) {
 		t.Fatal(srv.ListenAndServe())
 		t.Log(srv.addr)
 	}()
-	q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
-	rmsg, err := Ask(q, "127.0.0.1:domain")
+	rmsg, err := Ask(testq, "127.0.0.1:domain")
 	if err != nil {
 		t.Errorf("exchange: %v", err)
 	}
 	t.Log("response:", rmsg)
 }
+
+// TestJunk opens a series of TCP connections, writes random bytes to each,
+// then checks that the server is still running and still answers a valid query.
+// In other words, garbage input must neither kill nor wedge the server.
+func TestJunk(t *testing.T) {
+	const addr = "127.0.0.1:5361"
+	errc := make(chan error, 1)
+	go func() {
+		// Hand the result to the test goroutine; calling t.Fatal here would be
+		// invalid because this is not the goroutine running the test.
+		errc <- ListenAndServe("tcp", addr, nil)
+	}()
+	time.Sleep(time.Millisecond)
+	for i := 0; i <= 30; i++ {
+		select {
+		case err := <-errc:
+			t.Fatalf("server exited: %v", err)
+		default:
+		}
+		conn, err := net.Dial("tcp", addr)
+		if err != nil {
+			t.Fatal(err)
+		}
+		// The server may legitimately hang up on garbage, so a write error is
+		// logged rather than treated as a test failure.
+		if _, err := io.CopyN(conn, rand.Reader, 8192); err != nil {
+			t.Log("writing junk:", err)
+		}
+		// Close now rather than defer: a defer in the loop would keep every
+		// connection open until the test returns.
+		conn.Close()
+	}
+	select {
+	case err := <-errc:
+		t.Fatalf("server exited: %v", err)
+	default:
+	}
+	if rmsg, err := AskTCP(testq, addr); err != nil {
+		t.Errorf("exchange after junk: %v", err)
+	} else {
+		t.Log("response:", rmsg)
+	}
+}
+
+// BenchmarkPacketVsStream compares one query over UDP (Ask) with the same query
+// over TCP (AskTCP). Each protocol gets its own local server listening on addr.
+func BenchmarkPacketVsStream(b *testing.B) {
+	const addr = "127.0.0.1:51113"
+	// "network" rather than "net" so the net package is not shadowed.
+	for _, network := range []string{"udp", "tcp"} {
+		errc := make(chan error, 1)
+		// network is passed as an argument so the goroutine cannot see the loop
+		// variable change (shared loop variables before Go 1.22). b.Fatal must
+		// not be called from this goroutine, so the error is sent back instead.
+		go func(network string) {
+			errc <- ListenAndServe(network, addr, nil)
+		}(network)
+		time.Sleep(time.Millisecond)
+		select {
+		case err := <-errc:
+			b.Fatalf("%s server: %v", network, err)
+		default:
+		}
+		b.Run(network, func(b *testing.B) {
+			// i < b.N: the benchmark body must run exactly b.N times.
+			for i := 0; i < b.N; i++ {
+				if network == "udp" {
+					if rmsg, err := Ask(testq, addr); err != nil {
+						b.Log(rmsg)
+						b.Fatal(err)
+					}
+				} else {
+					if rmsg, err := AskTCP(testq, addr); err != nil {
+						b.Log(rmsg)
+						b.Fatal(err)
+					}
+				}
+			}
+		})
+	}
+}