Commit Diff


commit - f4835a3962d972a64a38ccdc989c49ab0020ffb6
commit + 0fc4e74ad9ea6060d49dfd2bb0369b798d10ade3
blob - 7fba1263619ee086dd9998b684db9a7948e5949c
blob + 31f33b1c611b25f8cc024ad576501d5d46c0037f
--- dns.go
+++ dns.go
@@ -24,9 +24,20 @@ func Exchange(msg dnsmessage.Message, addr string) (dn
 		return dnsmessage.Message{}, err
 	}
 	defer conn.Close()
-	return send(msg, conn)
+	return exchange(msg, conn)
 }
 
+// ExchangeTCP performs a synchronous DNS-over-TCP exchange with addr and returns its
+// reply to msg.
+func ExchangeTCP(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) {
+	conn, err := net.Dial("tcp", addr)
+	if err != nil {
+		return dnsmessage.Message{}, err
+	}
+	defer conn.Close()
+	return exchange(msg, conn)
+}
+
 // ExchangeTLS performs a synchronous DNS-over-TLS exchange with addr and returns its
 // reply to msg.
 func ExchangeTLS(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) {
@@ -35,71 +44,95 @@ func ExchangeTLS(msg dnsmessage.Message, addr string) 
 		return dnsmessage.Message{}, err
 	}
 	defer conn.Close()
-	return send(msg, conn)
+	return exchange(msg, conn)
 }
 
-func send(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) {
-	packed, err := msg.Pack()
+// exchange writes msg to conn with send and reads the reply with receive.
+// When the reply's ID differs from msg's, the reply is returned together
+// with errMismatchedID.
+func exchange(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) {
+	if err := send(msg, conn); err != nil {
+		return dnsmessage.Message{}, err
+	}
+	rmsg, err := receive(conn)
 	if err != nil {
 		return dnsmessage.Message{}, err
 	}
-	var b []byte
-	if _, ok := conn.(net.PacketConn); ok {
-		b, err = dnsPacketExchange(packed, conn)
-		if err != nil {
-			return dnsmessage.Message{}, fmt.Errorf("exchange DNS packet: %w", err)
-		}
-	} else {
-		b, err = dnsStreamExchange(packed, conn)
-		if err != nil {
-			return dnsmessage.Message{}, fmt.Errorf("exchange DNS TCP stream: %w", err)
-		}
-	}
-	var rmsg dnsmessage.Message
-	if err := rmsg.Unpack(b); err != nil {
-		return dnsmessage.Message{}, fmt.Errorf("parse response: %v", err)
-	}
 	if rmsg.Header.ID != msg.Header.ID {
 		return rmsg, errMismatchedID
 	}
 	return rmsg, nil
 }
 
-func dnsPacketExchange(b []byte, conn net.Conn) ([]byte, error) {
-	if _, err := conn.Write(b); err != nil {
-		return nil, err
-	}
-	buf := make([]byte, 512) // max UDP size per RFC?
-	n, err := conn.Read(buf)
+// send packs msg and writes it to conn. If conn also implements
+// net.PacketConn (as UDP sockets do), the message is written as a single
+// datagram; otherwise it is framed with the 2-octet length prefix that DNS
+// over TCP requires (RFC 1035, section 4.2.2).
+func send(msg dnsmessage.Message, conn net.Conn) error {
+	packed, err := msg.Pack()
 	if err != nil {
-		return nil, err
+		return fmt.Errorf("pack message: %w", err)
 	}
-	return buf[:n], nil
-}
-
-func dnsStreamExchange(b []byte, conn net.Conn) ([]byte, error) {
+	if _, ok := conn.(net.PacketConn); ok {
+		if _, err := conn.Write(packed); err != nil {
+			return fmt.Errorf("write message: %w", err)
+		}
+		return nil
+	}
 	// DNS over TCP requires you to prepend the message with a
 	// 2-octet length field.
-	l := len(b)
+	l := len(packed)
+	if l > 0xffff {
+		return fmt.Errorf("message length %d exceeds the 2-octet TCP length prefix", l)
+	}
 	m := make([]byte, 2+l)
 	m[0] = byte(l >> 8)
 	m[1] = byte(l)
-	copy(m[2:], b)
+	copy(m[2:], packed)
 	if _, err := conn.Write(m); err != nil {
-		return nil, err
+		return fmt.Errorf("write message: %w", err)
 	}
+	return nil
+}
 
-	b = make([]byte, 1280)
-	if _, err := io.ReadFull(conn, b[:2]); err != nil {
-		return nil, fmt.Errorf("read length: %w", err)
+// receive reads one DNS message from conn and unpacks it. Datagram
+// connections are read with a single 512-byte Read; stream connections are
+// expected to carry the 2-octet length prefix written by send.
+func receive(conn net.Conn) (dnsmessage.Message, error) {
+	var buf []byte
+	if _, ok := conn.(net.PacketConn); ok {
+		buf = make([]byte, 512)
+		n, err := conn.Read(buf)
+		if err != nil {
+			return dnsmessage.Message{}, err
+		}
+		buf = buf[:n]
+	} else {
+		var prefix [2]byte
+		if _, err := io.ReadFull(conn, prefix[:]); err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("read length: %w", err)
+		}
+		l := int(prefix[0])<<8 | int(prefix[1])
+		buf = make([]byte, l)
+		if _, err := io.ReadFull(conn, buf); err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("read after length: %w", err)
+		}
 	}
-	l = int(b[0])<<8 | int(b[1])
-	if l > len(b) {
-		b = make([]byte, l)
+	var msg dnsmessage.Message
+	if err := msg.Unpack(buf); err != nil {
+		return dnsmessage.Message{}, fmt.Errorf("parse message: %w", err)
 	}
-	n, err := io.ReadFull(conn, b[:l])
+	return msg, nil
+}
+
+// sendPacket packs msg and writes it to addr over conn as a single datagram.
+func sendPacket(msg dnsmessage.Message, conn net.PacketConn, addr net.Addr) error {
+	packed, err := msg.Pack()
 	if err != nil {
-		return nil, fmt.Errorf("read after length: %w", err)
+		return fmt.Errorf("pack message: %w", err)
 	}
-	return b[:n], nil
+	if _, err := conn.WriteTo(packed, addr); err != nil {
+		return fmt.Errorf("write message: %w", err)
+	}
+	return nil
 }
blob - /dev/null
blob + b00a41244c0c16699f6f406da8c7ef89533d306d (mode 644)
--- /dev/null
+++ dns_test.go
@@ -0,0 +1,231 @@
+package dns
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"math/rand"
+	"net"
+	"testing"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+// fakeDNSConn is an in-memory net.Conn that answers every query written to
+// it with server.resolve and buffers the packed reply for subsequent Reads.
+// When tcp is set, queries and replies carry the 2-octet DNS-over-TCP
+// length prefix. Only Read, Write and Close are implemented; any other
+// net.Conn method panics because the embedded Conn is nil.
+type fakeDNSConn struct {
+	net.Conn
+	server fakeDNSServer
+	buf    []byte
+	tcp    bool
+}
+
+// fakeDNSPacketConn is a fakeDNSConn that also satisfies net.PacketConn, so
+// send and receive treat it as an unframed datagram connection.
+type fakeDNSPacketConn struct {
+	net.PacketConn
+	fakeDNSConn
+}
+
+// Close is a no-op.
+func (f *fakeDNSPacketConn) Close() error {
+	return nil
+}
+
+// fakeDNSServer supplies the reply a fakeDNSConn returns for a query.
+type fakeDNSServer struct {
+	resolve func(q dnsmessage.Message) (dnsmessage.Message, error)
+}
+
+// resolveWell answers the first question in q with the A record 192.0.2.1,
+// echoing q's ID and questions.
+func resolveWell(q dnsmessage.Message) (dnsmessage.Message, error) {
+	if len(q.Questions) == 0 {
+		return dnsmessage.Message{}, errors.New("query has no questions")
+	}
+	question := q.Questions[0]
+	return dnsmessage.Message{
+		Header: dnsmessage.Header{
+			ID:       q.Header.ID,
+			Response: true,
+			RCode:    dnsmessage.RCodeSuccess,
+		},
+		Questions: q.Questions,
+		Answers: []dnsmessage.Resource{
+			{
+				Header: dnsmessage.ResourceHeader{
+					Name:  question.Name,
+					Type:  question.Type,
+					Class: question.Class,
+				},
+				Body: &dnsmessage.AResource{
+					A: [4]byte{0xc0, 0x00, 0x02, 0x01},
+				},
+			},
+		},
+	}, nil
+}
+
+var errCrashed = errors.New("crashed")
+
+// resolveBadly simulates a server failure by always returning errCrashed.
+func resolveBadly(q dnsmessage.Message) (dnsmessage.Message, error) {
+	return dnsmessage.Message{}, errCrashed
+}
+
+// Close is a no-op.
+func (f *fakeDNSConn) Close() error {
+	return nil
+}
+
+// Write unpacks the query in b, resolves it with f.server and buffers the
+// packed reply for Read. It refuses a new query while an earlier reply is
+// still unread.
+func (f *fakeDNSConn) Write(b []byte) (int, error) {
+	if len(f.buf) > 0 {
+		return 0, fmt.Errorf("connection buffer full, refusing overwrite")
+	}
+	query := b
+	if f.tcp {
+		if len(b) < 2 {
+			return len(b), fmt.Errorf("short TCP message: %d bytes", len(b))
+		}
+		query = b[2:]
+	}
+	var qmsg dnsmessage.Message
+	if err := qmsg.Unpack(query); err != nil {
+		return len(b), err
+	}
+	rmsg, err := f.server.resolve(qmsg)
+	if err != nil {
+		return len(b), err
+	}
+	packed, err := rmsg.Pack()
+	if err != nil {
+		return len(b), err
+	}
+	if f.tcp {
+		l := len(packed)
+		packed = append([]byte{byte(l >> 8), byte(l)}, packed...)
+	}
+	f.buf = packed
+	return len(b), nil
+}
+
+// Read drains the buffered reply into b and returns io.EOF once it is empty.
+func (f *fakeDNSConn) Read(b []byte) (int, error) {
+	if len(f.buf) > 0 {
+		n := copy(b, f.buf)
+		f.buf = f.buf[n:]
+		return n, nil
+	}
+	return 0, io.EOF
+}
+
+// checkAnswer fails t unless rmsg carries exactly the A record produced by
+// resolveWell.
+func checkAnswer(t *testing.T, rmsg dnsmessage.Message) {
+	t.Helper()
+	if len(rmsg.Answers) != 1 {
+		t.Fatalf("got %d answers, want 1", len(rmsg.Answers))
+	}
+	a, ok := rmsg.Answers[0].Body.(*dnsmessage.AResource)
+	if !ok {
+		t.Fatalf("answer body is %T, want *dnsmessage.AResource", rmsg.Answers[0].Body)
+	}
+	if want := [4]byte{0xc0, 0x00, 0x02, 0x01}; a.A != want {
+		t.Errorf("got address %v, want %v", a.A, want)
+	}
+}
+
+func TestGoodConn(t *testing.T) {
+	qmsg, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var goodconn fakeDNSConn
+	goodconn.server.resolve = resolveWell
+	goodconn.tcp = true
+	rmsg, err := exchange(qmsg, &goodconn)
+	if err != nil {
+		t.Fatal(err)
+	}
+	checkAnswer(t, rmsg)
+}
+
+func TestGoodPacketConn(t *testing.T) {
+	qmsg, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var goodconn fakeDNSPacketConn
+	goodconn.server.resolve = resolveWell
+	rmsg, err := exchange(qmsg, &goodconn)
+	if err != nil {
+		t.Fatal(err)
+	}
+	checkAnswer(t, rmsg)
+}
+
+func TestFailingServer(t *testing.T) {
+	qmsg, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var badconn fakeDNSConn
+	badconn.server.resolve = resolveBadly
+	badconn.tcp = true
+	_, err = exchange(qmsg, &badconn)
+	if !errors.Is(err, errCrashed) {
+		t.Errorf("wanted error %v, got %v", errCrashed, err)
+	}
+}
+
+func TestBadMessage(t *testing.T) {
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var badconn fakeDNSPacketConn
+	badconn.server.resolve = func(q dnsmessage.Message) (dnsmessage.Message, error) {
+		return dnsmessage.Message{
+			Header: dnsmessage.Header{
+				ID:       q.Header.ID + 1, // deliberately mismatched
+				Response: false,
+				RCode:    dnsmessage.RCodeNameError,
+			},
+			Questions: q.Questions,
+		}, nil
+	}
+	r, err := exchange(q, &badconn)
+	if !errors.Is(err, errMismatchedID) {
+		t.Log(err)
+		t.Errorf("should error on receiving mismatched message IDs; sent %d, received %d", q.Header.ID, r.Header.ID)
+	}
+}
+
+// buildmsg returns a recursive A/IN query for name s with a random 16-bit
+// message ID. s must be a fully-qualified name ending in '.'.
+func buildmsg(s string) (dnsmessage.Message, error) {
+	name, err := dnsmessage.NewName(s)
+	if err != nil {
+		return dnsmessage.Message{}, err
+	}
+	msg := dnsmessage.Message{
+		Header: dnsmessage.Header{
+			ID:               uint16(rand.Intn(1 << 16)),
+			RecursionDesired: true,
+		},
+		Questions: []dnsmessage.Question{
+			{Name: name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET},
+		},
+	}
+	// Pack once so an invalid query is reported here rather than at send time.
+	if _, err := msg.Pack(); err != nil {
+		return dnsmessage.Message{}, err
+	}
+	return msg, nil
+}
blob - /dev/null
blob + ec66a2254fe7e7caa2a448bceef86bb8944c3721 (mode 644)
--- /dev/null
+++ server.go
@@ -0,0 +1,163 @@
+package dns
+
+import (
+	"net"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+// Server answers DNS queries arriving on one network address by passing
+// each decoded query to its handler.
+type Server struct {
+	network string
+	addr    string
+	handler Handler
+}
+
+// response is the ResponseWriter handed to a Handler. Exactly one of pconn
+// (datagram transports, replies go to raddr) and conn (stream transports)
+// is set.
+type response struct {
+	raddr net.Addr
+	pconn net.PacketConn
+	conn  net.Conn
+}
+
+// WriteMsg packs msg and sends it to the client that issued the query.
+func (r *response) WriteMsg(msg dnsmessage.Message) error {
+	if r.pconn != nil {
+		return sendPacket(msg, r.pconn, r.raddr)
+	}
+	return send(msg, r.conn)
+}
+
+// ResponseWriter is used by a Handler to reply to a query.
+type ResponseWriter interface {
+	WriteMsg(dnsmessage.Message) error
+}
+
+// Handler responds to a DNS query, typically by calling WriteMsg on the
+// supplied ResponseWriter.
+type Handler func(ResponseWriter, *dnsmessage.Message)
+
+// ServePacket reads queries from conn and answers each one in its own
+// goroutine with srv's handler. It closes conn and returns the first read
+// error.
+func (srv *Server) ServePacket(conn net.PacketConn) error {
+	defer conn.Close()
+	for {
+		// Every datagram needs its own buffer because it is handed to a
+		// goroutine.
+		buf := make([]byte, 512)
+		n, raddr, err := conn.ReadFrom(buf)
+		if err != nil {
+			return err
+		}
+		go srv.handlePacket(conn, raddr, buf[:n])
+	}
+}
+
+// handlePacket decodes the query in b and passes it to srv.handler. A query
+// that cannot be decoded is answered with FORMERR (RFC 1035, section 4.1.1)
+// when its header is readable and it is not itself a response; otherwise
+// it is silently dropped.
+func (srv *Server) handlePacket(conn net.PacketConn, raddr net.Addr, b []byte) {
+	var msg dnsmessage.Message
+	if err := msg.Unpack(b); err == nil {
+		srv.handler(&response{raddr: raddr, pconn: conn}, &msg)
+		return
+	}
+	var p dnsmessage.Parser
+	h, err := p.Start(b)
+	if err != nil || h.Response {
+		// Nothing sensible to answer: the header itself is unreadable, or
+		// the packet is a response, and replying to responses risks a reply
+		// loop between two servers.
+		return
+	}
+	rmsg := dnsmessage.Message{
+		Header: dnsmessage.Header{
+			ID:       h.ID,
+			Response: true,
+			OpCode:   h.OpCode,
+			RCode:    dnsmessage.RCodeFormatError,
+		},
+	}
+	// Best effort: there is no caller to report a failed reply to.
+	_ = sendPacket(rmsg, conn, raddr)
+}
+
+// Serve accepts connections on l and serves each one in its own goroutine
+// with srv's handler. It closes l and returns the first Accept error.
+func (srv *Server) Serve(l net.Listener) error {
+	defer l.Close()
+	for {
+		conn, err := l.Accept()
+		if err != nil {
+			return err
+		}
+		go srv.handleConn(conn)
+	}
+}
+
+// handleConn answers the queries on conn one after another until receive
+// fails (client hung up, read error or malformed message), then closes conn.
+func (srv *Server) handleConn(conn net.Conn) {
+	defer conn.Close()
+	for {
+		msg, err := receive(conn)
+		if err != nil {
+			return
+		}
+		srv.handler(&response{conn: conn}, &msg)
+	}
+}
+
+// ListenAndServe listens on srv's network address and serves queries on it.
+// Datagram networks ("udp", "udp4", "udp6", "unixgram") are served with
+// ServePacket and every other network with Serve.
+func (srv *Server) ListenAndServe() error {
+	switch srv.network {
+	case "udp", "udp4", "udp6", "unixgram":
+		conn, err := net.ListenPacket(srv.network, srv.addr)
+		if err != nil {
+			return err
+		}
+		return srv.ServePacket(conn)
+	default:
+		l, err := net.Listen(srv.network, srv.addr)
+		if err != nil {
+			return err
+		}
+		return srv.Serve(l)
+	}
+}
+
+// ListenAndServe listens on addr over network and answers queries with
+// handler. It always returns a non-nil error.
+func ListenAndServe(network, addr string, handler Handler) error {
+	srv := &Server{network: network, addr: addr, handler: handler}
+	return srv.ListenAndServe()
+}
+
+// dumbHandler answers without resolving anything: queries asking for
+// recursion are REFUSED, all others get NOTIMP with their questions echoed.
+// Like any response, the reply copies the query's ID and RD bit (RFC 1035,
+// section 4.1.1).
+func dumbHandler(w ResponseWriter, msg *dnsmessage.Message) {
+	rmsg := dnsmessage.Message{
+		Header: dnsmessage.Header{
+			ID:               msg.Header.ID,
+			Response:         true,
+			RecursionDesired: msg.Header.RecursionDesired,
+		},
+	}
+	if msg.Header.RecursionDesired {
+		rmsg.Header.RCode = dnsmessage.RCodeRefused
+	} else {
+		rmsg.Questions = msg.Questions
+		rmsg.Header.RCode = dnsmessage.RCodeNotImplemented
+	}
+	// Best effort: a Handler has nowhere to report a failed write.
+	_ = w.WriteMsg(rmsg)
+}
blob - /dev/null
blob + 6a095d3aaf9c46d5f89acecac621ee8571079b35 (mode 644)
--- /dev/null
+++ server_test.go
@@ -0,0 +1,64 @@
+package dns
+
+import (
+	"net"
+	"testing"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+// TestServer serves dumbHandler over UDP on an ephemeral port and checks
+// that a recursive query is refused.
+func TestServer(t *testing.T) {
+	// Bind before querying so the query cannot race server start-up.
+	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(func() { pc.Close() })
+	srv := &Server{handler: dumbHandler}
+	// ServePacket returns once pc is closed during cleanup; that error is
+	// expected and ignored.
+	go func() { _ = srv.ServePacket(pc) }()
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatalf("create query: %v", err)
+	}
+	rmsg, err := Exchange(q, pc.LocalAddr().String())
+	if err != nil {
+		t.Fatalf("exchange: %v", err)
+	}
+	checkRefused(t, rmsg)
+}
+
+// TestStreamServer serves dumbHandler over TCP on an ephemeral port and
+// checks that a recursive query is refused.
+func TestStreamServer(t *testing.T) {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(func() { l.Close() })
+	srv := &Server{handler: dumbHandler}
+	// Serve returns once l is closed during cleanup; that error is expected
+	// and ignored.
+	go func() { _ = srv.Serve(l) }()
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatalf("create query: %v", err)
+	}
+	rmsg, err := ExchangeTCP(q, l.Addr().String())
+	if err != nil {
+		t.Fatalf("exchange: %v", err)
+	}
+	checkRefused(t, rmsg)
+}
+
+// checkRefused fails t unless rmsg carries the REFUSED code that
+// dumbHandler returns for recursive queries.
+func checkRefused(t *testing.T, rmsg dnsmessage.Message) {
+	t.Helper()
+	if rmsg.Header.RCode != dnsmessage.RCodeRefused {
+		t.Errorf("got RCode %v, want %v", rmsg.Header.RCode, dnsmessage.RCodeRefused)
+	}
+}