Commit Diff


commit - 0fc4e74ad9ea6060d49dfd2bb0369b798d10ade3
commit + a978f917d391a3e682f3931a985c1aa85cf5bbf5
blob - ec66a2254fe7e7caa2a448bceef86bb8944c3721
blob + 35eef0583482deeffc445f518297a8e562f85160
--- server.go
+++ server.go
@@ -31,6 +31,9 @@ type ResponseWriter interface {
 type Handler func(ResponseWriter, *dnsmessage.Message)

 func (srv *Server) ServePacket(conn net.PacketConn) error {
+	if srv.handler == nil { // no handler configured: fall back to DefaultHandler
+		srv.handler = DefaultHandler // NOTE(review): writes the shared Server field; races if this Server is also being served concurrently elsewhere — consider a local copy
+	}
 	for {
 		buf := make([]byte, 512)
 		n, raddr, err := conn.ReadFrom(buf)
@@ -53,29 +56,46 @@ func (srv *Server) ServePacket(conn net.PacketConn) er

 func (srv *Server) Serve(l net.Listener) error {
 	defer l.Close()
+	handler := srv.handler // local copy: never write the shared Server field (data race)
+	if handler == nil {
+		handler = DefaultHandler
+	}
 	for {
 		conn, err := l.Accept()
 		if err != nil {
 			return err
 		}
 		msg, _ := receive(conn)
-		go func() {
-			resp := &response{conn: conn}
-			srv.handler(resp, &msg)
-		}()
+		go handler(&response{conn: conn}, &msg) // msg is per-iteration, so &msg is safe to share
 	}
 }

+// ServePacket serves conn with a new Server using handler (DefaultHandler if nil).
+func ServePacket(conn net.PacketConn, handler Handler) error {
+	return (&Server{handler: handler}).ServePacket(conn)
+}
+
+// Serve serves l with a new Server using handler (DefaultHandler if nil); l is closed on return.
+func Serve(l net.Listener, handler Handler) error {
+	return (&Server{handler: handler}).Serve(l)
+}
+
 func (srv *Server) ListenAndServe() error {
-	switch srv.network {
-	case "udp", "udp4", "udp6", "unixgram":
-		conn, err := net.ListenPacket(srv.network, srv.addr)
+	if srv.addr == "" {
+		srv.addr = ":53" // zero value: listen on the standard DNS port
+	}
+	switch nw := srv.network; nw {
+	case "", "udp", "udp4", "udp6", "unixgram":
+		if nw == "" {
+			nw = "udp" // zero value: serve over UDP
+		}
+		conn, err := net.ListenPacket(nw, srv.addr)
 		if err != nil {
 			return err
 		}
 		return srv.ServePacket(conn)
 	default:
-		l, err := net.Listen(srv.network, srv.addr)
+		l, err := net.Listen(nw, srv.addr)
 		if err != nil {
 			return err
 		}
@@ -88,7 +108,8 @@ func ListenAndServe(network, addr string, handler Hand
 	return srv.ListenAndServe()
 }

-func dumbHandler(w ResponseWriter, msg *dnsmessage.Message) {
+// DefaultHandler is the Handler a Server uses when none is configured.
+func DefaultHandler(w ResponseWriter, msg *dnsmessage.Message) {
 	var rmsg dnsmessage.Message
 	rmsg.Header.ID = msg.Header.ID
 	if msg.Header.RecursionDesired {
blob - 6a095d3aaf9c46d5f89acecac621ee8571079b35
blob + 86aa5c26cea7ffdec8ca93d56608c78d2fce2c9f
--- server_test.go
+++ server_test.go
@@ -6,7 +6,7 @@ import (

 func TestServer(t *testing.T) {
 	go func() {
-		if err := ListenAndServe("udp", "127.0.0.1:51111", dumbHandler); err != nil {
-			t.Fatal(err)
+		if err := ListenAndServe("udp", "127.0.0.1:51111", nil); err != nil {
+			t.Error(err) // t.Fatal must only be called from the test goroutine
 		}
 	}()
@@ -23,7 +23,7 @@ func TestServer(t *testing.T) {

 func TestStreamServer(t *testing.T) {
 	go func() {
-		if err := ListenAndServe("tcp", "127.0.0.1:51112", dumbHandler); err != nil {
-			t.Fatal(err)
+		if err := ListenAndServe("tcp", "127.0.0.1:51112", nil); err != nil {
+			t.Error(err) // t.Fatal must only be called from the test goroutine
 		}
 	}()
@@ -37,3 +37,29 @@ func TestStreamServer(t *testing.T) {
 	}
 	t.Log("response:", rmsg)
 }
+
+func TestEmptyServer(t *testing.T) {
+	// A zero Server listens on UDP ":53"; on most systems binding port 53
+	// requires elevated privileges. Listen failures are passed back over a
+	// channel because t.Fatal must only be called from the test goroutine.
+	srv := &Server{}
+	errc := make(chan error, 1)
+	go func() {
+		errc <- srv.ListenAndServe()
+	}()
+	q, err := buildmsg("www.example.com.")
+	if err != nil {
+		t.Fatal("create query:", err)
+	}
+	rmsg, err := Exchange(q, "127.0.0.1:domain")
+	select {
+	case lerr := <-errc:
+		// ListenAndServe returned, so srv is not serving; any reply came from elsewhere.
+		t.Fatalf("listen and serve on %s: %v", srv.addr, lerr)
+	default:
+	}
+	if err != nil {
+		t.Errorf("exchange: %v", err)
+	}
+	t.Log("response:", rmsg)
+}