commit 0fc4e74ad9ea6060d49dfd2bb0369b798d10ade3 from: Oliver Lowe date: Thu Dec 9 07:09:36 2021 UTC dump commit - f4835a3962d972a64a38ccdc989c49ab0020ffb6 commit + 0fc4e74ad9ea6060d49dfd2bb0369b798d10ade3 blob - 7fba1263619ee086dd9998b684db9a7948e5949c blob + 31f33b1c611b25f8cc024ad576501d5d46c0037f --- dns.go +++ dns.go @@ -24,9 +24,18 @@ func Exchange(msg dnsmessage.Message, addr string) (dn return dnsmessage.Message{}, err } defer conn.Close() - return send(msg, conn) + return exchange(msg, conn) } +func ExchangeTCP(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return dnsmessage.Message{}, err + } + defer conn.Close() + return exchange(msg, conn) +} + // ExchangeTLS performs a synchronous DNS-over-TLS exchange with addr and returns its // reply to msg. func ExchangeTLS(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) { @@ -35,71 +44,86 @@ func ExchangeTLS(msg dnsmessage.Message, addr string) return dnsmessage.Message{}, err } defer conn.Close() - return send(msg, conn) + return exchange(msg, conn) } -func send(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) { - packed, err := msg.Pack() +func exchange(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) { + if err := send(msg, conn); err != nil { + return dnsmessage.Message{}, err + } + rmsg, err := receive(conn) if err != nil { return dnsmessage.Message{}, err } - var b []byte - if _, ok := conn.(net.PacketConn); ok { - b, err = dnsPacketExchange(packed, conn) - if err != nil { - return dnsmessage.Message{}, fmt.Errorf("exchange DNS packet: %w", err) - } - } else { - b, err = dnsStreamExchange(packed, conn) - if err != nil { - return dnsmessage.Message{}, fmt.Errorf("exchange DNS TCP stream: %w", err) - } - } - var rmsg dnsmessage.Message - if err := rmsg.Unpack(b); err != nil { - return dnsmessage.Message{}, fmt.Errorf("parse response: %v", err) - } if rmsg.Header.ID != msg.Header.ID { return rmsg, errMismatchedID } return rmsg, nil } -func dnsPacketExchange(b []byte, conn net.Conn) ([]byte, error) { - if _, err := conn.Write(b); err != nil { - return nil, err - } - buf := make([]byte, 512) // max UDP size per RFC? - n, err := conn.Read(buf) +func send(msg dnsmessage.Message, conn net.Conn) error { + packed, err := msg.Pack() if err != nil { - return nil, err + return err } - return buf[:n], nil -} - -func dnsStreamExchange(b []byte, conn net.Conn) ([]byte, error) { + if _, ok := conn.(net.PacketConn); ok { + if _, err := conn.Write(packed); err != nil { + return err + } + return nil + } // DNS over TCP requires you to prepend the message with a // 2-octet length field. - l := len(b) + l := len(packed) m := make([]byte, 2+l) m[0] = byte(l >> 8) m[1] = byte(l) - copy(m[2:], b) + copy(m[2:], packed) if _, err := conn.Write(m); err != nil { - return nil, err + return err } + return nil +} - b = make([]byte, 1280) - if _, err := io.ReadFull(conn, b[:2]); err != nil { - return nil, fmt.Errorf("read length: %w", err) +func receive(conn net.Conn) (dnsmessage.Message, error) { + var buf []byte + var n int + var err error + if _, ok := conn.(net.PacketConn); ok { + buf = make([]byte, 512) + n, err = conn.Read(buf) + if err != nil { + return dnsmessage.Message{}, err + } + } else { + buf = make([]byte, 1280) + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return dnsmessage.Message{}, fmt.Errorf("read length: %w", err) + } + l := int(buf[0])<<8 | int(buf[1]) + if l > len(buf) { + buf = make([]byte, l) + } + n, err = io.ReadFull(conn, buf[:l]) + if err != nil { + return dnsmessage.Message{}, fmt.Errorf("read after length: %w", err) + } } - l = int(b[0])<<8 | int(b[1]) - if l > len(b) { - b = make([]byte, l) + var msg dnsmessage.Message + if err := msg.Unpack(buf[:n]); err != nil { + return dnsmessage.Message{}, err } - n, err := io.ReadFull(conn, b[:l]) + return msg, nil +} + +func sendPacket(msg dnsmessage.Message, conn net.PacketConn, addr net.Addr) error { + packed, err := msg.Pack() if err != nil { - return nil, fmt.Errorf("read after length: %w", err) + return err } - return b[:n], nil + _, err = conn.WriteTo(packed, addr) + if err != nil { + return err + } + return nil } blob - /dev/null blob + b00a41244c0c16699f6f406da8c7ef89533d306d (mode 644) --- /dev/null +++ dns_test.go @@ -0,0 +1,189 @@ +package dns + +import ( + "errors" + "fmt" + "io" + "math/rand" + "net" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +type fakeDNSConn struct { + net.Conn + server fakeDNSServer + buf []byte + tcp bool +} + +type fakeDNSPacketConn struct { + net.PacketConn + fakeDNSConn +} + +func (f *fakeDNSPacketConn) Close() error { + return nil +} + +type fakeDNSServer struct { + resolve func(q dnsmessage.Message) (dnsmessage.Message, error) +} + +func resolveWell(q dnsmessage.Message) (dnsmessage.Message, error) { + return dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: q.Questions, + Answers: []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: q.Questions[0].Type, + Class: q.Questions[0].Class, + }, + Body: &dnsmessage.AResource{ + A: [4]byte{0xc0, 0x00, 0x02, 0x01}, + }, + }, + }, + }, nil +} + +var errCrashed = errors.New("crashed") + +func resolveBadly(q dnsmessage.Message) (dnsmessage.Message, error) { + return dnsmessage.Message{}, errCrashed +} + +func (f fakeDNSConn) Close() error { + return nil +} + +func (f *fakeDNSConn) Write(b []byte) (int, error) { + time.Sleep(50 * time.Millisecond) + if len(f.buf) > 0 { + return 0, fmt.Errorf("connection buffer full, refusing overwrite") + } + var qmsg dnsmessage.Message + if f.tcp { + if err := qmsg.Unpack(b[2:]); err != nil { + return len(b), err + } + } else { + if err := qmsg.Unpack(b); err != nil { + return len(b), err + } + } + rmsg, err := f.server.resolve(qmsg) + if err != nil { + return len(b), err + } + packed, err := rmsg.Pack() + if err != nil { + return len(b), err + } + if f.tcp { + l := len(packed) + buf := make([]byte, 2+len(packed)) + buf[0] = byte(l >> 8) + buf[1] = byte(l) + copy(buf[2:], packed) + f.buf = buf + return len(b), nil + } + f.buf = packed + return len(b), err +} + +func (f *fakeDNSConn) Read(b []byte) (int, error) { + if len(f.buf) > 0 { + n := copy(b, f.buf) + f.buf = f.buf[n:] + return n, nil + } + return 0, io.EOF +} + +func TestGoodConn(t *testing.T) { + qmsg, err := buildmsg("www.example.com.") + if err != nil { + t.Fatal(err) + } + var goodconn fakeDNSConn + goodconn.server.resolve = resolveWell + goodconn.tcp = true + _, err = exchange(qmsg, &goodconn) + if err != nil { + t.Error(err) + } +} + +func TestShitConn(t *testing.T) { + qmsg, err := buildmsg("www.example.com.") + if err != nil { + t.Fatal(err) + } + var shitconn fakeDNSConn + shitconn.server.resolve = resolveBadly + shitconn.tcp = true + _, err = exchange(qmsg, &shitconn) + if !errors.Is(err, errCrashed) { + t.Errorf("wanted error %v, got %v", errCrashed, err) + } +} + +func TestBadMessage(t *testing.T) { + q, err := buildmsg("www.example.com.") + if err != nil { + t.Fatal(err) + } + var shitconn fakeDNSPacketConn + shitconn.server.resolve = func(q dnsmessage.Message) (dnsmessage.Message, error) { + return dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID + 69, + Response: false, + RCode: dnsmessage.RCodeNameError, + }, + Questions: q.Questions, + }, nil + } + r, err := exchange(q, &shitconn) + if !errors.Is(err, errMismatchedID) { + t.Log(err) + t.Errorf("should error on receiving mismatched message IDs; sent %d, received %d", q.Header.ID, r.Header.ID) + } +} + +func buildmsg(s string) (dnsmessage.Message, error) { + name, err := dnsmessage.NewName(s) + if err != nil { + return dnsmessage.Message{}, err + } + var msg dnsmessage.Message + header := dnsmessage.Header{ID: uint16(rand.Intn(8192)), RecursionDesired: true} + buf := make([]byte, 2, 512+2) + b := dnsmessage.NewBuilder(buf, header) + b.EnableCompression() + q := dnsmessage.Question{Name: name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} + if err := b.StartQuestions(); err != nil { + return msg, err + } + if err := b.Question(q); err != nil { + return msg, err + } + packed, err := b.Finish() + if err != nil { + return msg, err + } + if err := msg.Unpack(packed[2:]); err != nil { + return msg, err + } + return msg, nil +} blob - /dev/null blob + ec66a2254fe7e7caa2a448bceef86bb8944c3721 (mode 644) --- /dev/null +++ server.go @@ -0,0 +1,102 @@ +package dns + +import ( + "golang.org/x/net/dns/dnsmessage" + "net" +) + +type Server struct { + network string + addr string + handler Handler +} + +type response struct { + raddr net.Addr + pconn net.PacketConn + conn net.Conn +} + +func (r *response) WriteMsg(msg dnsmessage.Message) error { + if r.pconn != nil { + return sendPacket(msg, r.pconn, r.raddr) + } + return send(msg, r.conn) +} + +type ResponseWriter interface { + WriteMsg(dnsmessage.Message) error +} + +type Handler func(ResponseWriter, *dnsmessage.Message) + +func (srv *Server) ServePacket(conn net.PacketConn) error { + for { + buf := make([]byte, 512) + n, raddr, err := conn.ReadFrom(buf) + if err != nil { + return err + } + go func() { + var msg dnsmessage.Message + if err := msg.Unpack(buf[:n]); err != nil { + msg.Header.RCode = dnsmessage.RCodeRefused + sendPacket(msg, conn, raddr) + return + } + resp := &response{raddr: raddr, pconn: conn} + srv.handler(resp, &msg) + }() + } + return nil +} + +func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + for { + conn, err := l.Accept() + if err != nil { + return err + } + msg, _ := receive(conn) + go func() { + resp := &response{conn: conn} + srv.handler(resp, &msg) + }() + } +} + +func (srv *Server) ListenAndServe() error { + switch srv.network { + case "udp", "udp4", "udp6", "unixgram": + conn, err := net.ListenPacket(srv.network, srv.addr) + if err != nil { + return err + } + return srv.ServePacket(conn) + default: + l, err := net.Listen(srv.network, srv.addr) + if err != nil { + return err + } + return srv.Serve(l) + } +} + +func ListenAndServe(network, addr string, handler Handler) error { + srv := &Server{network: network, addr: addr, handler: handler} + return srv.ListenAndServe() +} + +func dumbHandler(w ResponseWriter, msg *dnsmessage.Message) { + var rmsg dnsmessage.Message + rmsg.Header.ID = msg.Header.ID + if msg.Header.RecursionDesired { + rmsg.Header.RCode = dnsmessage.RCodeRefused + w.WriteMsg(rmsg) + return + } + rmsg.Questions = msg.Questions + rmsg.Header.RCode = dnsmessage.RCodeNotImplemented + w.WriteMsg(rmsg) +} blob - /dev/null blob + 6a095d3aaf9c46d5f89acecac621ee8571079b35 (mode 644) --- /dev/null +++ server_test.go @@ -0,0 +1,39 @@ +package dns + +import ( + "testing" +) + +func TestServer(t *testing.T) { + go func() { + if err := ListenAndServe("udp", "127.0.0.1:51111", dumbHandler); err != nil { + t.Fatal(err) + } + }() + q, err := buildmsg("www.example.com.") + if err != nil { + t.Fatalf("create query: %v", err) + } + rmsg, err := Exchange(q, "127.0.0.1:51111") + if err != nil { + t.Errorf("exchange: %v", err) + } + t.Log("response:", rmsg) +} + +func TestStreamServer(t *testing.T) { + go func() { + if err := ListenAndServe("tcp", "127.0.0.1:51112", dumbHandler); err != nil { + t.Fatal(err) + } + }() + q, err := buildmsg("www.example.com.") + if err != nil { + t.Fatal("create query:", err) + } + rmsg, err := ExchangeTCP(q, "127.0.0.1:51112") + if err != nil { + t.Errorf("exchange: %v", err) + } + t.Log("response:", rmsg) +}