commit - f4835a3962d972a64a38ccdc989c49ab0020ffb6
commit + 0fc4e74ad9ea6060d49dfd2bb0369b798d10ade3
blob - 7fba1263619ee086dd9998b684db9a7948e5949c
blob + 31f33b1c611b25f8cc024ad576501d5d46c0037f
--- dns.go
+++ dns.go
return dnsmessage.Message{}, err
}
defer conn.Close()
- return send(msg, conn)
+ return exchange(msg, conn)
}
+
+// ExchangeTCP performs a synchronous DNS-over-TCP exchange with addr and
+// returns its reply to msg. A fresh connection is dialed for every call and
+// closed before returning.
+func ExchangeTCP(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) {
+	c, err := net.Dial("tcp", addr)
+	if err != nil {
+		return dnsmessage.Message{}, err
+	}
+	defer c.Close()
+	return exchange(msg, c)
+}
+
// ExchangeTLS performs a synchronous DNS-over-TLS exchange with addr and returns its
// reply to msg.
func ExchangeTLS(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) {
return dnsmessage.Message{}, err
}
defer conn.Close()
- return send(msg, conn)
+ return exchange(msg, conn)
}
-func send(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) {
- packed, err := msg.Pack()
+// exchange writes msg to conn with send, reads one reply with receive, and
+// checks that the reply's ID matches the query's. send and receive pick the
+// framing from conn's type (net.PacketConn or not). Errors from send and
+// receive are returned as-is; on an ID mismatch the reply is still returned
+// together with errMismatchedID.
+func exchange(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) {
+ if err := send(msg, conn); err != nil {
+ return dnsmessage.Message{}, err
+ }
+ rmsg, err := receive(conn)
if err != nil {
return dnsmessage.Message{}, err
}
- var b []byte
- if _, ok := conn.(net.PacketConn); ok {
- b, err = dnsPacketExchange(packed, conn)
- if err != nil {
- return dnsmessage.Message{}, fmt.Errorf("exchange DNS packet: %w", err)
- }
- } else {
- b, err = dnsStreamExchange(packed, conn)
- if err != nil {
- return dnsmessage.Message{}, fmt.Errorf("exchange DNS TCP stream: %w", err)
- }
- }
- var rmsg dnsmessage.Message
- if err := rmsg.Unpack(b); err != nil {
- return dnsmessage.Message{}, fmt.Errorf("parse response: %v", err)
- }
+ // A reply whose ID differs from the query's does not answer this query.
if rmsg.Header.ID != msg.Header.ID {
return rmsg, errMismatchedID
}
return rmsg, nil
}
-func dnsPacketExchange(b []byte, conn net.Conn) ([]byte, error) {
- if _, err := conn.Write(b); err != nil {
- return nil, err
- }
- buf := make([]byte, 512) // max UDP size per RFC?
- n, err := conn.Read(buf)
+// send packs msg and writes it to conn using the framing conn's transport
+// needs. Datagram connections (those implementing net.PacketConn) carry the
+// packed message as-is. Stream connections prefix it with a 2-octet
+// big-endian length (RFC 1035 section 4.2.2). A message too long for that
+// 16-bit field is rejected with an error rather than sent with a wrapped
+// length.
+func send(msg dnsmessage.Message, conn net.Conn) error {
+	packed, err := msg.Pack()
if err != nil {
- return nil, err
+		return err
}
- return buf[:n], nil
-}
-
-func dnsStreamExchange(b []byte, conn net.Conn) ([]byte, error) {
+	if _, ok := conn.(net.PacketConn); ok {
+		// Datagram transports need no length prefix.
+		_, err = conn.Write(packed)
+		return err
+	}
// DNS over TCP requires you to prepend the message with a
// 2-octet length field.
- l := len(b)
+	l := len(packed)
+	if l > 0xffff {
+		// The prefix is only 16 bits wide. A larger length would wrap,
+		// and the peer would misread the rest of the stream.
+		return fmt.Errorf("message of %d bytes too large for 2-octet length prefix", l)
+	}
m := make([]byte, 2+l)
m[0] = byte(l >> 8)
m[1] = byte(l)
- copy(m[2:], b)
+	copy(m[2:], packed)
if _, err := conn.Write(m); err != nil {
- return nil, err
+		return err
}
+	return nil
+}
- b = make([]byte, 1280)
- if _, err := io.ReadFull(conn, b[:2]); err != nil {
- return nil, fmt.Errorf("read length: %w", err)
+// receive reads one DNS message from conn and unpacks it.
+//
+// On datagram connections (net.PacketConn), the message is taken from a
+// single Read into a 512-byte buffer. On stream connections, a 2-octet
+// big-endian length is read first, then exactly that many bytes. Every
+// error is wrapped with %w and a description of the step that failed.
+//
+// NOTE(review): datagram replies larger than 512 bytes (e.g. with EDNS0)
+// will not fit the buffer.
+func receive(conn net.Conn) (dnsmessage.Message, error) {
+	var buf []byte
+	var n int
+	var err error
+	if _, ok := conn.(net.PacketConn); ok {
+		buf = make([]byte, 512)
+		n, err = conn.Read(buf)
+		if err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("read packet: %w", err)
+		}
+	} else {
+		buf = make([]byte, 1280)
+		if _, err := io.ReadFull(conn, buf[:2]); err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("read length: %w", err)
+		}
+		l := int(buf[0])<<8 | int(buf[1])
+		if l > len(buf) {
+			buf = make([]byte, l)
+		}
+		n, err = io.ReadFull(conn, buf[:l])
+		if err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("read after length: %w", err)
+		}
}
- l = int(b[0])<<8 | int(b[1])
- if l > len(b) {
- b = make([]byte, l)
+	var msg dnsmessage.Message
+	if err := msg.Unpack(buf[:n]); err != nil {
+		return dnsmessage.Message{}, fmt.Errorf("parse message: %w", err)
}
- n, err := io.ReadFull(conn, b[:l])
+	return msg, nil
+}
+
+// sendPacket packs msg and writes it as a single datagram to addr through
+// conn. The server uses it to reply to the client a query came from.
+func sendPacket(msg dnsmessage.Message, conn net.PacketConn, addr net.Addr) error {
+	wire, err := msg.Pack()
if err != nil {
- return nil, fmt.Errorf("read after length: %w", err)
+		return err
}
- return b[:n], nil
+	_, err = conn.WriteTo(wire, addr)
+	return err
}
blob - /dev/null
blob + b00a41244c0c16699f6f406da8c7ef89533d306d (mode 644)
--- /dev/null
+++ dns_test.go
+package dns
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+// fakeDNSConn is an in-memory stand-in for a net.Conn that talks to a
+// fakeDNSServer.
+//
+// Each Write is treated as one query, and the server's reply is buffered for
+// subsequent Reads. When tcp is set, queries and replies carry the 2-octet
+// length prefix used on stream transports. The embedded net.Conn is left nil
+// in the tests, so any method not defined here would panic if called.
+type fakeDNSConn struct {
+	net.Conn
+	server fakeDNSServer
+	tcp    bool   // use stream (length-prefixed) framing
+	buf    []byte // reply bytes not yet consumed by Read
+}
+
+// fakeDNSPacketConn is a fakeDNSConn that also satisfies net.PacketConn, so
+// code that checks for net.PacketConn takes its datagram path.
+type fakeDNSPacketConn struct {
+	net.PacketConn
+	fakeDNSConn
+}
+
+// Close is a no-op. Defining it here also settles which of the two embedded
+// Close methods applies.
+func (*fakeDNSPacketConn) Close() error {
+	return nil
+}
+
+// fakeDNSServer answers every query by calling resolve.
+type fakeDNSServer struct {
+	resolve func(q dnsmessage.Message) (dnsmessage.Message, error)
+}
+
+// resolveWell is a fake resolver that succeeds. It echoes the query's ID and
+// questions, and answers the first question with the address 192.0.2.1.
+func resolveWell(q dnsmessage.Message) (dnsmessage.Message, error) {
+	first := q.Questions[0]
+	answer := dnsmessage.Resource{
+		Header: dnsmessage.ResourceHeader{
+			Name:  first.Name,
+			Type:  first.Type,
+			Class: first.Class,
+		},
+		Body: &dnsmessage.AResource{A: [4]byte{192, 0, 2, 1}},
+	}
+	var reply dnsmessage.Message
+	reply.Header = dnsmessage.Header{
+		ID:       q.Header.ID,
+		Response: true,
+		RCode:    dnsmessage.RCodeSuccess,
+	}
+	reply.Questions = q.Questions
+	reply.Answers = []dnsmessage.Resource{answer}
+	return reply, nil
+}
+
+// errCrashed is the error every query to resolveBadly fails with.
+var errCrashed = errors.New("crashed")
+
+// resolveBadly is a fake resolver that fails every query with errCrashed.
+func resolveBadly(dnsmessage.Message) (dnsmessage.Message, error) {
+	var none dnsmessage.Message
+	return none, errCrashed
+}
+
+// Close is a no-op. It takes a pointer receiver, like Read and Write, so
+// fakeDNSConn does not mix value and pointer receivers.
+func (*fakeDNSConn) Close() error {
+	return nil
+}
+
+// Write hands one query to the fake server and buffers the server's reply
+// for the next Reads.
+//
+// It refuses to run while an earlier reply is still unread. When f.tcp is
+// set, b must start with the 2-octet length prefix, and the buffered reply
+// gets one too. Unpack, resolve and Pack errors are returned with n ==
+// len(b).
+func (f *fakeDNSConn) Write(b []byte) (int, error) {
+	// NOTE(review): presumably simulates network latency; it slows every
+	// test that writes.
+	time.Sleep(50 * time.Millisecond)
+	if len(f.buf) != 0 {
+		return 0, fmt.Errorf("connection buffer full, refusing overwrite")
+	}
+	query := b
+	if f.tcp {
+		query = b[2:] // skip the length prefix
+	}
+	var qmsg dnsmessage.Message
+	if err := qmsg.Unpack(query); err != nil {
+		return len(b), err
+	}
+	rmsg, err := f.server.resolve(qmsg)
+	if err != nil {
+		return len(b), err
+	}
+	reply, err := rmsg.Pack()
+	if err != nil {
+		return len(b), err
+	}
+	if f.tcp {
+		framed := make([]byte, 2, 2+len(reply))
+		framed[0], framed[1] = byte(len(reply)>>8), byte(len(reply))
+		reply = append(framed, reply...)
+	}
+	f.buf = reply
+	return len(b), nil
+}
+
+// Read copies as much of the buffered reply into b as fits, consuming it. It
+// reports io.EOF once nothing is left.
+func (f *fakeDNSConn) Read(b []byte) (int, error) {
+	if len(f.buf) == 0 {
+		return 0, io.EOF
+	}
+	n := copy(b, f.buf)
+	f.buf = f.buf[n:]
+	return n, nil
+}
+
+// TestGoodConn runs a length-prefixed (TCP-style) exchange against a fake
+// server that answers successfully. It checks that the reply arrives intact:
+// response bit set and a single A record for 192.0.2.1.
+func TestGoodConn(t *testing.T) {
+	qmsg, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var goodconn fakeDNSConn
+	goodconn.server.resolve = resolveWell
+	goodconn.tcp = true
+	rmsg, err := exchange(qmsg, &goodconn)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !rmsg.Header.Response {
+		t.Error("reply does not have the response bit set")
+	}
+	if len(rmsg.Answers) != 1 {
+		t.Fatalf("got %d answers, want 1", len(rmsg.Answers))
+	}
+	a, ok := rmsg.Answers[0].Body.(*dnsmessage.AResource)
+	if !ok {
+		t.Fatalf("answer body is %T, want *dnsmessage.AResource", rmsg.Answers[0].Body)
+	}
+	if want := [4]byte{192, 0, 2, 1}; a.A != want {
+		t.Errorf("got A %v, want %v", a.A, want)
+	}
+}
+
+// TestShitConn checks that an error raised by the server comes back out of
+// exchange in a form errors.Is still recognises.
+func TestShitConn(t *testing.T) {
+	qmsg, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	conn := &fakeDNSConn{
+		server: fakeDNSServer{resolve: resolveBadly},
+		tcp:    true,
+	}
+	if _, err := exchange(qmsg, conn); !errors.Is(err, errCrashed) {
+		t.Errorf("wanted error %v, got %v", errCrashed, err)
+	}
+}
+
+// TestBadMessage checks that exchange rejects a reply whose ID differs from
+// the query's. It uses a fakeDNSPacketConn, so the datagram path is exercised.
+func TestBadMessage(t *testing.T) {
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Answer with an ID offset from the query's, so it can never match.
+	mismatched := func(in dnsmessage.Message) (dnsmessage.Message, error) {
+		var out dnsmessage.Message
+		out.Header.ID = in.Header.ID + 69
+		out.Header.Response = false
+		out.Header.RCode = dnsmessage.RCodeNameError
+		out.Questions = in.Questions
+		return out, nil
+	}
+	conn := &fakeDNSPacketConn{
+		fakeDNSConn: fakeDNSConn{server: fakeDNSServer{resolve: mismatched}},
+	}
+	r, err := exchange(q, conn)
+	if errors.Is(err, errMismatchedID) {
+		return
+	}
+	t.Log(err)
+	t.Errorf("should error on receiving mismatched message IDs; sent %d, received %d", q.Header.ID, r.Header.ID)
+}
+
+// buildmsg returns a recursion-desired query for the A record of s, with a
+// random ID. The message is built with dnsmessage.Builder and then unpacked.
+// s is expected to be fully qualified, e.g. "www.example.com.".
+func buildmsg(s string) (dnsmessage.Message, error) {
+	name, err := dnsmessage.NewName(s)
+	if err != nil {
+		return dnsmessage.Message{}, err
+	}
+	var msg dnsmessage.Message
+	// The DNS header ID is a 16-bit field (RFC 1035 section 4.1.1). Draw
+	// from the whole range so high-bit IDs get exercised too.
+	header := dnsmessage.Header{ID: uint16(rand.Intn(1 << 16)), RecursionDesired: true}
+	b := dnsmessage.NewBuilder(make([]byte, 0, 512), header)
+	b.EnableCompression()
+	q := dnsmessage.Question{Name: name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
+	if err := b.StartQuestions(); err != nil {
+		return msg, err
+	}
+	if err := b.Question(q); err != nil {
+		return msg, err
+	}
+	packed, err := b.Finish()
+	if err != nil {
+		return msg, err
+	}
+	if err := msg.Unpack(packed); err != nil {
+		return msg, err
+	}
+	return msg, nil
+}
blob - /dev/null
blob + ec66a2254fe7e7caa2a448bceef86bb8944c3721 (mode 644)
--- /dev/null
+++ server.go
+package dns
+
+import (
+ "golang.org/x/net/dns/dnsmessage"
+ "net"
+)
+
+// Server answers DNS queries arriving on one network address. Its fields
+// are unexported, so outside this package a Server is created through
+// ListenAndServe.
+type Server struct {
+	network, addr string  // passed to net.Listen / net.ListenPacket
+	handler       Handler // invoked for each incoming query
+}
+
+// response is the ResponseWriter the server hands to its handler. The
+// server sets pconn and raddr for datagram queries, or conn for stream
+// queries.
+type response struct {
+	raddr net.Addr       // client address; used together with pconn
+	pconn net.PacketConn // socket a datagram query arrived on
+	conn  net.Conn       // connection a stream query arrived on
+}
+
+// WriteMsg sends msg back over the transport the query arrived on.
+func (r *response) WriteMsg(msg dnsmessage.Message) error {
+	if r.pconn == nil {
+		return send(msg, r.conn)
+	}
+	return sendPacket(msg, r.pconn, r.raddr)
+}
+
+// ResponseWriter lets a Handler send its reply to the client that asked.
+type ResponseWriter interface {
+	// WriteMsg sends msg to the client as the reply.
+	WriteMsg(msg dnsmessage.Message) error
+}
+
+// Handler answers one query. It receives a ResponseWriter bound to the
+// query's origin and the unpacked query.
+type Handler func(w ResponseWriter, msg *dnsmessage.Message)
+
+// ServePacket reads DNS queries from conn, one per datagram, and hands each
+// to srv.handler on its own goroutine.
+//
+// A datagram that cannot be unpacked is answered with FORMERR. ServePacket
+// returns the first error reported by conn.ReadFrom, and does not close
+// conn.
+func (srv *Server) ServePacket(conn net.PacketConn) error {
+	for {
+		// Allocate per datagram: the goroutine below keeps using buf
+		// after the loop moves on to the next read.
+		buf := make([]byte, 512)
+		n, raddr, err := conn.ReadFrom(buf)
+		if err != nil {
+			return err
+		}
+		go func() {
+			var msg dnsmessage.Message
+			if err := msg.Unpack(buf[:n]); err != nil {
+				// The server could not interpret the query:
+				// reply FORMERR (RFC 1035 section 4.1.1) with the
+				// response bit set. Echo the ID from the first two
+				// wire bytes, when present, so the client can match
+				// the reply.
+				var rmsg dnsmessage.Message
+				rmsg.Header.Response = true
+				rmsg.Header.RCode = dnsmessage.RCodeFormatError
+				if n >= 2 {
+					rmsg.Header.ID = uint16(buf[0])<<8 | uint16(buf[1])
+				}
+				// Best effort: there is no caller to report a failed
+				// write to.
+				_ = sendPacket(rmsg, conn, raddr)
+				return
+			}
+			srv.handler(&response{raddr: raddr, pconn: conn}, &msg)
+		}()
+	}
+}
+
+// Serve accepts stream connections on l and serves each on its own
+// goroutine. It closes l before returning and returns the first Accept
+// error.
+func (srv *Server) Serve(l net.Listener) error {
+	defer l.Close()
+	for {
+		conn, err := l.Accept()
+		if err != nil {
+			return err
+		}
+		// Read the query off the accept loop, so a slow or silent client
+		// cannot stall every other connection.
+		go srv.serveConn(conn)
+	}
+}
+
+// serveConn reads one length-prefixed query from conn and passes it to
+// srv.handler. It closes conn once the handler returns, so handlers must
+// finish writing their reply before returning. If the query cannot be read
+// or parsed, conn is closed without a reply.
+func (srv *Server) serveConn(conn net.Conn) {
+	defer conn.Close()
+	msg, err := receive(conn)
+	if err != nil {
+		return
+	}
+	srv.handler(&response{conn: conn}, &msg)
+}
+
+// ListenAndServe listens on srv.network/srv.addr and serves queries until
+// serving fails, then returns the error that stopped it. Datagram networks
+// are served with ServePacket; all others are served with Serve.
+func (srv *Server) ListenAndServe() error {
+	switch srv.network {
+	case "udp", "udp4", "udp6", "unixgram":
+		conn, err := net.ListenPacket(srv.network, srv.addr)
+		if err != nil {
+			return err
+		}
+		// ServePacket does not close conn, so release the socket here
+		// once it returns.
+		defer conn.Close()
+		return srv.ServePacket(conn)
+	default:
+		l, err := net.Listen(srv.network, srv.addr)
+		if err != nil {
+			return err
+		}
+		// Serve closes l when it returns.
+		return srv.Serve(l)
+	}
+}
+
+// ListenAndServe creates a Server for network and addr that answers queries
+// with handler, and runs it with (*Server).ListenAndServe.
+func ListenAndServe(network, addr string, handler Handler) error {
+	return (&Server{network: network, addr: addr, handler: handler}).ListenAndServe()
+}
+
+// dumbHandler is a minimal Handler used by the server tests; it never
+// resolves anything. Queries that ask for recursion get REFUSED, and
+// everything else gets NOTIMP. Both replies echo the query's ID and
+// questions and set the response (QR) bit.
+func dumbHandler(w ResponseWriter, msg *dnsmessage.Message) {
+	rmsg := dnsmessage.Message{
+		Header: dnsmessage.Header{
+			ID:       msg.Header.ID,
+			Response: true, // QR bit: this message is a reply
+			RCode:    dnsmessage.RCodeNotImplemented,
+		},
+		Questions: msg.Questions,
+	}
+	if msg.Header.RecursionDesired {
+		rmsg.Header.RCode = dnsmessage.RCodeRefused
+	}
+	// A Handler has nowhere to report errors; on a failed write the client
+	// simply gets no reply.
+	_ = w.WriteMsg(rmsg)
+}
blob - /dev/null
blob + 6a095d3aaf9c46d5f89acecac621ee8571079b35 (mode 644)
--- /dev/null
+++ server_test.go
+package dns
+
+import (
+ "testing"
+)
+
+// TestServer runs dumbHandler behind a UDP server and performs one exchange
+// with it.
+func TestServer(t *testing.T) {
+	// t.Fatal may only be called from the test goroutine, so the server
+	// goroutine sends its exit error back over a channel instead.
+	// NOTE(review): the exchange below can still race the socket being
+	// bound, and the server goroutine outlives the test. Listening in the
+	// test goroutine first would fix both.
+	srvErr := make(chan error, 1)
+	go func() {
+		srvErr <- ListenAndServe("udp", "127.0.0.1:51111", dumbHandler)
+	}()
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatalf("create query: %v", err)
+	}
+	rmsg, err := Exchange(q, "127.0.0.1:51111")
+	if err != nil {
+		select {
+		case serr := <-srvErr:
+			t.Fatalf("exchange: %v (server exited: %v)", err, serr)
+		default:
+			t.Fatalf("exchange: %v", err)
+		}
+	}
+	t.Log("response:", rmsg)
+}
+
+// TestStreamServer runs dumbHandler behind a TCP server and performs one
+// length-prefixed exchange with it.
+func TestStreamServer(t *testing.T) {
+	// t.Fatal may only be called from the test goroutine, so the server
+	// goroutine sends its exit error back over a channel instead.
+	// NOTE(review): the exchange below can still race the listener being
+	// bound, and the server goroutine outlives the test. Listening in the
+	// test goroutine first would fix both.
+	srvErr := make(chan error, 1)
+	go func() {
+		srvErr <- ListenAndServe("tcp", "127.0.0.1:51112", dumbHandler)
+	}()
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal("create query:", err)
+	}
+	rmsg, err := ExchangeTCP(q, "127.0.0.1:51112")
+	if err != nil {
+		select {
+		case serr := <-srvErr:
+			t.Fatalf("exchange: %v (server exited: %v)", err, serr)
+		default:
+			t.Fatalf("exchange: %v", err)
+		}
+	}
+	t.Log("response:", rmsg)
+}