commit a978f917d391a3e682f3931a985c1aa85cf5bbf5 from: Oliver Lowe date: Thu Dec 9 07:57:22 2021 UTC server: make default, empty server safe We won't get nil pointer derefs when someone doesn't set a Handler function. While here add a test for it commit - 0fc4e74ad9ea6060d49dfd2bb0369b798d10ade3 commit + a978f917d391a3e682f3931a985c1aa85cf5bbf5 blob - ec66a2254fe7e7caa2a448bceef86bb8944c3721 blob + 35eef0583482deeffc445f518297a8e562f85160 --- server.go +++ server.go @@ -31,6 +31,9 @@ type ResponseWriter interface { type Handler func(ResponseWriter, *dnsmessage.Message) func (srv *Server) ServePacket(conn net.PacketConn) error { + if srv.handler == nil { + srv.handler = DefaultHandler + } for { buf := make([]byte, 512) n, raddr, err := conn.ReadFrom(buf) @@ -53,29 +56,46 @@ func (srv *Server) ServePacket(conn net.PacketConn) er func (srv *Server) Serve(l net.Listener) error { defer l.Close() + if srv.handler == nil { + srv.handler = DefaultHandler + } for { conn, err := l.Accept() if err != nil { return err } msg, _ := receive(conn) - go func() { - resp := &response{conn: conn} - srv.handler(resp, &msg) - }() + resp := &response{conn: conn} + go srv.handler(resp, &msg) } } +func ServePacket(conn net.PacketConn, handler Handler) error { + srv := &Server{handler: handler} + return srv.ServePacket(conn) +} + +func Serve(l net.Listener, handler Handler) error { + srv := &Server{handler: handler} + return srv.Serve(l) +} + func (srv *Server) ListenAndServe() error { - switch srv.network { - case "udp", "udp4", "udp6", "unixgram": - conn, err := net.ListenPacket(srv.network, srv.addr) + if srv.addr == "" { + srv.addr = ":53" + } + switch nw := srv.network; nw { + case "", "udp", "udp4", "udp6", "unixgram": + if nw == "" { + nw = "udp" + } + conn, err := net.ListenPacket(nw, srv.addr) if err != nil { return err } return srv.ServePacket(conn) default: - l, err := net.Listen(srv.network, srv.addr) + l, err := net.Listen(nw, srv.addr) if err != nil { return err } @@ -88,7 +108,7 @@ func ListenAndServe(network, addr string, handler Hand return srv.ListenAndServe() } -func dumbHandler(w ResponseWriter, msg *dnsmessage.Message) { +func DefaultHandler(w ResponseWriter, msg *dnsmessage.Message) { var rmsg dnsmessage.Message rmsg.Header.ID = msg.Header.ID if msg.Header.RecursionDesired { blob - 6a095d3aaf9c46d5f89acecac621ee8571079b35 blob + 86aa5c26cea7ffdec8ca93d56608c78d2fce2c9f --- server_test.go +++ server_test.go @@ -6,7 +6,7 @@ import ( func TestServer(t *testing.T) { go func() { - if err := ListenAndServe("udp", "127.0.0.1:51111", dumbHandler); err != nil { + if err := ListenAndServe("udp", "127.0.0.1:51111", nil); err != nil { t.Fatal(err) } }() @@ -23,7 +23,7 @@ func TestServer(t *testing.T) { func TestStreamServer(t *testing.T) { go func() { - if err := ListenAndServe("tcp", "127.0.0.1:51112", dumbHandler); err != nil { + if err := ListenAndServe("tcp", "127.0.0.1:51112", nil); err != nil { t.Fatal(err) } }() @@ -37,3 +37,20 @@ func TestStreamServer(t *testing.T) { } t.Log("response:", rmsg) } + +func TestEmptyServer(t *testing.T) { + srv := &Server{} + go func() { + t.Fatal(srv.ListenAndServe()) + t.Log(srv.addr) + }() + q, err := buildmsg("www.example.com.") + if err != nil { + t.Fatal("create query:", err) + } + rmsg, err := Exchange(q, "127.0.0.1:domain") + if err != nil { + t.Errorf("exchange: %v", err) + } + t.Log("response:", rmsg) +}