Commit Diff


commit - 16fb0dea8a6bd2b750ffd7f706401125f18d222a
commit + d5834ac9c24e2b94ddd01468976b08de9bb90ce8
blob - d452495bebba48c909a222b91c286b153ca7caed
blob + d494622841bcd47aeb9aabaab0545762f13ddacb
--- cmd/recursor/recursor.go
+++ cmd/recursor/recursor.go
@@ -1,10 +1,9 @@
 package main
 
 import (
-	"os"
 	"fmt"
 	"golang.org/x/net/dns/dnsmessage"
-	"sync"
+	"os"
 
 	"olowe.co/dns"
 )
@@ -18,6 +17,8 @@ func shouldReject(m *dnsmessage.Message) (bool, dnsmes
 		return true, dnsmessage.RCodeFormatError
 	} else if m.Questions[0].Type == dnsmessage.TypeALL {
 		return true, dnsmessage.RCodeNotImplemented
+	} else if m.Questions[0].Class != dnsmessage.ClassINET {
+		return true, dnsmessage.RCodeNotImplemented
 	}
 	return false, dnsmessage.RCodeSuccess
 }
@@ -37,16 +38,6 @@ func handler(w dns.ResponseWriter, qmsg *dnsmessage.Me
 	rmsg.RecursionDesired = true
 
 	q := qmsg.Questions[0]
-	cache.RLock()
-	if answers, ok := cache.m[q]; ok {
-		rmsg.Answers = answers
-		w.WriteMsg(rmsg)
-		cache.RUnlock()
-		fmt.Fprintf(os.Stderr, "cache served %s %s\n", q.Name, q.Type)
-		return
-	}
-	cache.RUnlock()
-
 	resolved, err := resolveFromRoot(q)
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
@@ -56,24 +47,14 @@ func handler(w dns.ResponseWriter, qmsg *dnsmessage.Me
 	}
 	rmsg.Header.RCode = resolved.Header.RCode
 	rmsg.Answers = resolved.Answers
-	cache.Lock()
-	cache.m[q] = rmsg.Answers
-	fmt.Fprintf(os.Stderr, "cached %s %s\n", q.Name, q.Type)
-	cache.Unlock()
 	if len(rmsg.Answers) == 0 {
 		rmsg.Authorities = resolved.Authorities
 		w.WriteMsg(rmsg)
 		return
 	}
-	rmsg.Answers = resolved.Answers
 	w.WriteMsg(rmsg)
 }
 
-var cache = struct{
-	m map[dnsmessage.Question][]dnsmessage.Resource
-	sync.RWMutex
-}{m: make(map[dnsmessage.Question][]dnsmessage.Resource)}
-
 func main() {
 	fmt.Fprintln(os.Stderr, dns.ListenAndServe("udp", "", handler))
 }
blob - /dev/null
blob + f74cab14bc53ea8705412a15b159af915b68c931 (mode 644)
--- /dev/null
+++ cmd/recursor/cache.go
@@ -0,0 +1,53 @@
+package main
+
+import (
+	"sync"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+// cache maps an owner name to the records cached under that name, keyed
+// by record type. The embedded RWMutex guards m; access it only through
+// lookup and insert.
+var cache = struct {
+	m map[dnsmessage.Name]map[dnsmessage.Type][]dnsmessage.Resource
+	sync.RWMutex
+}{m: make(map[dnsmessage.Name]map[dnsmessage.Type][]dnsmessage.Resource)}
+
+// lookup returns the records cached for name n and type t, and whether an
+// unexpired entry was found. An expired entry is evicted and reported as a
+// miss. A lookup that finds no entry never modifies the cache.
+func lookup(n dnsmessage.Name, t dnsmessage.Type) ([]dnsmessage.Resource, bool) {
+	cache.RLock()
+	rrs, ok := cache.m[n][t] // indexing a missing (nil) inner map is safe
+	cache.RUnlock()
+	if !ok {
+		return nil, false
+	}
+	if expired(n, t) {
+		cache.Lock()
+		delete(cache.m[n], t)
+		cache.Unlock()
+		return nil, false
+	}
+	return rrs, true
+}
+
+// insert stores rrs as the records for name n and type t, replacing any
+// previous entry. The per-name map is created under the write lock when
+// missing, so insert does not depend on a prior lookup and never discards
+// records other goroutines have stored under the same name.
+func insert(n dnsmessage.Name, t dnsmessage.Type, rrs []dnsmessage.Resource) {
+	cache.Lock()
+	defer cache.Unlock()
+	byType := cache.m[n]
+	if byType == nil {
+		byType = make(map[dnsmessage.Type][]dnsmessage.Resource)
+		cache.m[n] = byType
+	}
+	byType[t] = rrs
+}
+
+// expired reports whether the entry for name n and type t is stale.
+// It currently always returns false, so cached entries never expire.
+func expired(n dnsmessage.Name, t dnsmessage.Type) bool { return false }
blob - 5744b2f2b490d44a1acbcea6014b23a663aaf424
blob + 7aa4fa1052d05864c09d1a8ee9c858ab1f03d72f
--- cmd/recursor/resolve.go
+++ cmd/recursor/resolve.go
@@ -2,13 +2,22 @@ package main
 
 import (
 	"fmt"
+	"golang.org/x/net/dns/dnsmessage"
 	"net"
+	"os"
 	"strings"
-	"golang.org/x/net/dns/dnsmessage"
 
 	"olowe.co/dns"
 )
 
+const rootA = "198.41.0.4"   // IPv4 addresses of root name servers A through E.
+const rootB = "199.9.14.201" // NOTE(review): B-root was renumbered in 2023; check against the IANA root hints.
+const rootC = "192.33.4.12"
+const rootD = "199.7.91.13"
+const rootE = "192.203.230.10"
+// roots are the servers resolveFromRoot starts from. Only A, B and C are listed.
+var roots []net.IP = []net.IP{net.ParseIP(rootA), net.ParseIP(rootB), net.ParseIP(rootC)}
+
 // appends the DNS port to the IP to be used in a dial string.
 func ip2dial(ip net.IP) string {
 	return net.JoinHostPort(ip.String(), "domain")
@@ -18,6 +27,16 @@ func isIPv6(ip net.IP) bool {
 	return strings.Contains(ip.String(), ":")
 }
 
+// filterRRs returns, in their original order, the records in rrs owned by n with type t.
+func filterRRs(rrs []dnsmessage.Resource, n dnsmessage.Name, t dnsmessage.Type) (matches []dnsmessage.Resource) {
+	for _, rr := range rrs {
+		if rr.Header.Type == t && rr.Header.Name == n {
+			matches = append(matches, rr)
+		}
+	}
+	return matches
+}
+
 func nextServerAddrs(resources []dnsmessage.Resource) []net.IP {
 	var next []net.IP
 	for _, r := range resources {
@@ -31,13 +50,6 @@ func nextServerAddrs(resources []dnsmessage.Resource) 
 	return next
 }
 
-const rootA = "198.41.0.4"
-const rootB = "199.9.14.201"
-const rootC = "192.33.4.12"
-const rootD = "199.7.91.13"
-const rootE = "192.203.230.10"
-var roots []net.IP = []net.IP{net.ParseIP(rootA), net.ParseIP(rootB), net.ParseIP(rootC)}
-
 func resolveFromRoot(q dnsmessage.Question) (dnsmessage.Message, error) {
 	return resolve(q, roots)
 }
@@ -45,13 +57,23 @@ func resolveFromRoot(q dnsmessage.Question) (dnsmessag
 func resolve(q dnsmessage.Question, next []net.IP) (dnsmessage.Message, error) {
 	var rmsg dnsmessage.Message
 	var err error
+	if rrs, ok := lookup(q.Name, q.Type); ok {
+		fmt.Fprintln(os.Stderr, "cache served", q.Name, q.Type)
+		return dnsmessage.Message{Answers: rrs}, nil
+	}
+	fmt.Fprintln(os.Stderr, "cache miss", q.Name, q.Type)
+
 	for _, ip := range next {
 		// Aussie Broadband doesn't support IPv6 yet!
 		if isIPv6(ip) {
 			continue
 		}
+		fmt.Fprintf(os.Stderr, "asking %s for %s %s\n", ip, q.Name, q.Type)
 		rmsg, err = dns.Ask(q, ip2dial(ip))
 		if rmsg.Header.Authoritative {
+			fmt.Println("got auth answer")
+			insert(q.Name, q.Type, rmsg.Answers)
+			fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type)
 			return rmsg, err
 		} else if rmsg.Header.RCode == dnsmessage.RCodeSuccess && err == nil {
 			break
@@ -60,13 +82,24 @@ func resolve(q dnsmessage.Question, next []net.IP) (dn
 	if err != nil {
 		return dnsmessage.Message{}, fmt.Errorf("resolve %s: %w", q.Name, err)
 	}
+	fmt.Println("no auth answer")
 
-	// no authoritative answer, so start looking for hints of who to ask next
-	if len(rmsg.Additionals) > 0 {
-		return resolve(q, nextServerAddrs(rmsg.Additionals))
+	if len(rmsg.Authorities) > 0 {
+		if _, ok := lookup(rmsg.Authorities[0].Header.Name, rmsg.Authorities[0].Header.Type); !ok {
+			insert(rmsg.Authorities[0].Header.Name, rmsg.Authorities[0].Header.Type, rmsg.Authorities)
+			fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type)
+		}
 	}
+	for _, a := range rmsg.Additionals {
+		matches := filterRRs(rmsg.Additionals, a.Header.Name, a.Header.Type)
+		if _, ok := lookup(a.Header.Name, a.Header.Type); !ok {
+			insert(a.Header.Name, a.Header.Type, matches)
+			fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type)
+		}
+	}
 
-	// no hints in additionals, check authorities
+	// get the IP addresses of the nameservers we were told about, then
+	// ask the same question to them
 	if len(rmsg.Authorities) > 0 {
 		for _, a := range rmsg.Authorities {
 			switch b := a.Body.(type) {
@@ -80,6 +113,8 @@ func resolve(q dnsmessage.Question, next []net.IP) (dn
 					return resolve(q, nextServerAddrs(rmsg.Answers))
 				}
 				return resolve(q, nextServerAddrs(rmsg.Additionals))
+			default:
+				return rmsg, fmt.Errorf("unexpected authority resource type %s", a.Header.Type)
 			}
 		}
 	}
blob - /dev/null
blob + aa0c6d1bcbc8c8c0cfc99c9148292378f1437e75 (mode 644)
--- /dev/null
+++ cmd/recursor/recursor_test.go
@@ -0,0 +1,153 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"testing"
+
+	"golang.org/x/net/dns/dnsmessage"
+
+	"olowe.co/dns"
+)
+
+// tquery is the baseline query sent by the tests. Tests that need a
+// different name must derive a copy with withName rather than modify
+// tquery, because a plain struct copy would share its Questions slice.
+var tquery = dnsmessage.Message{
+	Header: dnsmessage.Header{
+		ID:               69,
+		RecursionDesired: true,
+	},
+	Questions: []dnsmessage.Question{
+		{
+			Name:  dnsmessage.MustNewName("www.example.com."),
+			Type:  dnsmessage.TypeAAAA,
+			Class: dnsmessage.ClassINET,
+		},
+	},
+}
+
+// testAddr is the address TestMain serves the handler under test on.
+const testAddr string = "127.0.0.1:5359"
+
+// quad9 is the external resolver whose answers serve as the reference.
+const quad9 string = "9.9.9.9:domain"
+
+// withName returns a copy of m whose first question asks about name.
+// The Questions slice is copied so that m itself is left unmodified.
+func withName(m dnsmessage.Message, name string) dnsmessage.Message {
+	m.Questions = append([]dnsmessage.Question(nil), m.Questions...)
+	m.Questions[0].Name = dnsmessage.MustNewName(name)
+	return m
+}
+
+// compareMsg returns an error describing the first difference found
+// between want and got: their headers, their answer counts, or the
+// address in their first answers, which must both be AAAA records.
+// Messages without answers are not supported and always yield an error.
+func compareMsg(want, got dnsmessage.Message) error {
+	if want.Header != got.Header {
+		return fmt.Errorf("mismatched headers: want %+v, got %+v", want.Header, got.Header)
+	}
+	if len(want.Answers) != len(got.Answers) {
+		return fmt.Errorf("mismatched answer count")
+	}
+	if len(want.Answers) == 0 || len(got.Answers) == 0 {
+		return fmt.Errorf("unsupported comparison of empty answer messages")
+	}
+	wantaddr, ok := want.Answers[0].Body.(*dnsmessage.AAAAResource)
+	if !ok {
+		return fmt.Errorf("unexpected resource type from external resolver")
+	}
+	gotaddr, ok := got.Answers[0].Body.(*dnsmessage.AAAAResource)
+	if !ok {
+		return fmt.Errorf("unexpected resource type from our resolver")
+	}
+	if wantaddr.AAAA != gotaddr.AAAA {
+		return fmt.Errorf("wanted %s got %s", wantaddr.AAAA, gotaddr.AAAA)
+	}
+	return nil
+}
+
+// TestMain serves handler on testAddr in a background goroutine, then
+// runs the tests and exits with their status.
+// NOTE(review): m.Run does not wait for the listener to come up, so the
+// first query a test sends may race the server's start-up.
+func TestMain(m *testing.M) {
+	go func() {
+		if err := dns.ListenAndServe("udp", testAddr, handler); err != nil {
+			fmt.Println(err)
+			os.Exit(1)
+		}
+	}()
+	os.Exit(m.Run())
+}
+
+// TestRecursor checks that the recursor's answer for tquery matches
+// quad9's, on the first query and on repeated (cacheable) queries, and
+// that a second name can be queried repeatedly without error.
+// It is skipped when quad9 cannot be reached.
+func TestRecursor(t *testing.T) {
+	wanted, err := dns.Exchange(tquery, quad9)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "skipping %s: %v\n", t.Name(), err)
+		t.Skip("query internet DNS:", err)
+	}
+	got, err := dns.Exchange(tquery, testAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("wanted: %+v got %+v", wanted, got)
+	if err := compareMsg(wanted, got); err != nil {
+		t.Error(err)
+	}
+	// answer should come from cache
+	for i := 0; i < 2; i++ {
+		got, err = dns.Exchange(tquery, testAddr)
+		if err != nil {
+			t.Fatal("resolve from cache:", err)
+		}
+		if err := compareMsg(wanted, got); err != nil {
+			t.Error("resolve from cache:", err)
+		}
+	}
+	// The first query for this name is resolved; the second may be
+	// served from the cache.
+	q := withName(tquery, "www.example.net.")
+	for i := 0; i < 2; i++ {
+		if _, err := dns.Exchange(q, testAddr); err != nil {
+			t.Errorf("resolve %s (attempt %d): %v", q.Questions[0].Name, i+1, err)
+		}
+	}
+	t.Logf("wanted: %+v got %+v", wanted, got)
+}
+
+// TestNXDomain checks that, for a name that does not exist, the recursor
+// returns the same header as quad9 both on the first query and on a
+// repeated query that may be served from the cache.
+// It is skipped when quad9 cannot be reached.
+func TestNXDomain(t *testing.T) {
+	// Probe quad9 with a known-good query first, so that having no
+	// internet access skips the test instead of failing it.
+	if _, err := dns.Exchange(tquery, quad9); err != nil {
+		fmt.Fprintf(os.Stderr, "skipping %s: %v\n", t.Name(), err)
+		t.Skip("query internet DNS:", err)
+	}
+	q := withName(tquery, "nxdomain.example.com.")
+	wanted, err := dns.Exchange(q, quad9)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var got dnsmessage.Message
+	// try twice: first for fresh response, second for cached response
+	for i := 0; i < 2; i++ {
+		got, err = dns.Exchange(q, testAddr)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if wanted.Header != got.Header {
+			t.Error("mismatched headers")
+		}
+	}
+	t.Logf("wanted: %+v got %+v", wanted, got)
+}