Commit Diff


commit - f62764615c3e0fa8465b5bdf32fd067841e9fb91
commit + 598b53e1e47686b8c47241b874109f3b70c5f2d1
blob - d494622841bcd47aeb9aabaab0545762f13ddacb
blob + ea07625f4a5a3322930143449e92fc2d6a06f910
--- cmd/recursor/recursor.go
+++ cmd/recursor/recursor.go
@@ -8,33 +8,50 @@ import (
 	"olowe.co/dns"
 )
 
-func shouldReject(m *dnsmessage.Message) (bool, dnsmessage.RCode) {
-	if !m.Header.RecursionDesired {
-		return true, dnsmessage.RCodeRefused
-	} else if m.Header.OpCode != dns.OpCodeQUERY {
-		return true, dnsmessage.RCodeRefused
-	} else if len(m.Questions) != 1 {
-		return true, dnsmessage.RCodeFormatError
-	} else if m.Questions[0].Type == dnsmessage.TypeALL {
-		return true, dnsmessage.RCodeNotImplemented
-	} else if m.Questions[0].Class != dnsmessage.ClassINET {
-		return true, dnsmessage.RCodeNotImplemented
+// okQType reports whether t is a query type we can resolve by recursively
+// querying nameservers. The OPT pseudo-type and ALL (ANY) are not accepted.
+func okQType(t dnsmessage.Type) bool {
+	switch t {
+	case dnsmessage.TypeA, dnsmessage.TypeNS, dnsmessage.TypeCNAME, dnsmessage.TypeSOA, dnsmessage.TypePTR, dnsmessage.TypeMX, dnsmessage.TypeTXT, dnsmessage.TypeAAAA, dnsmessage.TypeSRV:
+		return true
 	}
-	return false, dnsmessage.RCodeSuccess
+	return false
 }
 
+// rejectHandler is a safeguard that keeps queries we don't want (or
+// support) from being recursively resolved. If qmsg is rejected, an
+// error reply (REFUSED, FORMERR or NOTIMP) is written to w and
+// rejectHandler returns true. Otherwise nothing is written and it
+// returns false, leaving the caller to resolve the query.
+func rejectHandler(w dns.ResponseWriter, qmsg *dnsmessage.Message) bool {
+	switch {
+	// Only standard queries from clients asking for recursion are served.
+	case !qmsg.Header.RecursionDesired, qmsg.Header.OpCode != dns.OpCodeQUERY:
+		dns.Refuse(w, qmsg)
+	// Exactly one question is required; this guards the indexing below.
+	case len(qmsg.Questions) != 1:
+		dns.FormatError(w, qmsg)
+	// Only Internet-class queries of a type okQType accepts are resolved.
+	case !okQType(qmsg.Questions[0].Type),
+		qmsg.Questions[0].Class != dnsmessage.ClassINET:
+		dns.NotImplemented(w, qmsg)
+	default:
+		// Nothing to reject.
+		return false
+	}
+	return true
+}
+
 func handler(w dns.ResponseWriter, qmsg *dnsmessage.Message) {
+	if rejected := rejectHandler(w, qmsg); rejected {
+		return
+	}
+
 	var rmsg dnsmessage.Message
 	rmsg.Header.ID = qmsg.Header.ID
 	rmsg.Header.Response = true
 	rmsg.Header.RecursionAvailable = true
 	rmsg.Questions = qmsg.Questions
-
-	if reject, rc := shouldReject(qmsg); reject {
-		rmsg.Header.RCode = rc
-		w.WriteMsg(rmsg)
-		return
-	}
 	rmsg.RecursionDesired = true
 
 	q := qmsg.Questions[0]
blob - 14d4fe846cd63d5fbe3031e0e11446a2b4242c0c
blob + 668d6c3ac27c5a8ebd08e80cfaa5aed11f0868a6
--- server.go
+++ server.go
@@ -132,21 +132,97 @@ func ListenAndServe(network, addr string, handler Hand
 	return srv.ListenAndServe()
 }
 
-// DefaultHandler responds to all DNS messages identically. Recursivew
-// queries are refused and all others are replied to with a "not
-// implemented" message. It is intended as a safe default for a Server
-// which does not set a Handler.
-func DefaultHandler(w ResponseWriter, msg *dnsmessage.Message) {
-	var rmsg dnsmessage.Message
-	rmsg.Header.ID = msg.Header.ID
-	if msg.Header.RecursionDesired {
-		rmsg.Header.RCode = dnsmessage.RCodeRefused
-		w.WriteMsg(rmsg)
-		return
+// DefaultHandler responds to all DNS messages identically: every message
+// is refused with a REFUSED reply. It is intended as a safe default for
+// a Server which does not set a Handler.
+var DefaultHandler = Refuse
+
+// FormatError replies to msg with a Format Error (FORMERR) message. The
+// reply echoes msg's ID, RecursionDesired flag and question section.
+func FormatError(w ResponseWriter, msg *dnsmessage.Message) {
+	var reply dnsmessage.Message
+	reply.Header.ID = msg.Header.ID
+	reply.Header.Response = true
+	reply.Header.RecursionDesired = msg.Header.RecursionDesired
+	reply.Header.RCode = dnsmessage.RCodeFormatError
+	// Echo the question section back, as it appeared in the query.
+	reply.Questions = msg.Questions
+	w.WriteMsg(reply)
+}
+
+// ServerFailure replies to msg with a Server Failure (SERVFAIL) message.
+// The reply echoes msg's ID, RecursionDesired flag and question section.
+func ServerFailure(w ResponseWriter, msg *dnsmessage.Message) {
+	hdr := dnsmessage.Header{
+		ID:               msg.Header.ID,
+		Response:         true,
+		RecursionDesired: msg.Header.RecursionDesired,
+		RCode:            dnsmessage.RCodeServerFailure,
+	}
+	reply := dnsmessage.Message{Header: hdr, Questions: msg.Questions}
+	w.WriteMsg(reply)
+}
+
+// NotImplemented replies to the message with a Not Implemented (NOTIMP) message.
+func NotImplemented(w ResponseWriter, msg *dnsmessage.Message) {
+	w.WriteMsg(dnsmessage.Message{
+		Header: dnsmessage.Header{
+			ID:               msg.Header.ID,
+			Response:         true,
+			RecursionDesired: msg.Header.RecursionDesired,
+			RCode:            dnsmessage.RCodeNotImplemented,
+		},
+		Questions: msg.Questions,
+	})
+}
+
+// Refuse replies to msg with a Refused (REFUSED) message. The reply
+// echoes msg's ID, RecursionDesired flag and question section.
+func Refuse(w ResponseWriter, msg *dnsmessage.Message) {
+	reply := dnsmessage.Message{Questions: msg.Questions}
+	reply.Header = dnsmessage.Header{
+		ID:               msg.Header.ID,
+		Response:         true,
+		RecursionDesired: msg.Header.RecursionDesired,
+		RCode:            dnsmessage.RCodeRefused,
+	}
+	w.WriteMsg(reply)
+}
+
+// NameError replies to msg with a Name Error (NXDOMAIN) message carrying
+// the SOA resource soa, with resource header rh, in the authority section.
+// Recursive resolvers should pass authoritative=false, authoritative servers
+// true. If the reply cannot be encoded, a SERVFAIL reply is sent instead.
+func NameError(w ResponseWriter, msg *dnsmessage.Message, rh dnsmessage.ResourceHeader, soa dnsmessage.SOAResource, authoritative bool) {
+	buf := make([]byte, 0, 512)
+	header := dnsmessage.Header{
+		ID:               msg.Header.ID,
+		Response:         true,
+		RecursionDesired: msg.Header.RecursionDesired,
+		Authoritative:    authoritative,
+		RCode:            dnsmessage.RCodeNameError,
 	}
-	rmsg.Questions = msg.Questions
-	rmsg.Header.RCode = dnsmessage.RCodeNotImplemented
-	w.WriteMsg(rmsg)
+	builder := dnsmessage.NewBuilder(buf, header)
+	builder.EnableCompression()
+	// Report a reply we cannot encode as SERVFAIL instead of panicking.
+	err := builder.StartQuestions()
+	for i := 0; err == nil && i < len(msg.Questions); i++ {
+		err = builder.Question(msg.Questions[i])
+	}
+	if err == nil {
+		err = builder.StartAuthorities()
+	}
+	if err == nil {
+		err = builder.SOAResource(rh, soa)
+	}
+	if err == nil {
+		buf, err = builder.Finish()
+	}
+	if err != nil {
+		ServerFailure(w, msg)
+		return
+	}
+	w.Write(buf)
 }
 
 // ExtractIPs extracts any IP addresses from resources. An empty slice is