commit d5834ac9c24e2b94ddd01468976b08de9bb90ce8 from: Oliver Lowe date: Sat Dec 18 03:14:22 2021 UTC cmd/recursor: rewrite cache and recursor algorithm Now we only ever send queries to nameservers, not just things we happen to find in the additional section of a reply. This change also does caching more nicely; cache any records that we get from nameservers. Authoritative answers get priority. Fixes: https://todo.sr.ht/~otl/dns/4 commit - 16fb0dea8a6bd2b750ffd7f706401125f18d222a commit + d5834ac9c24e2b94ddd01468976b08de9bb90ce8 blob - d452495bebba48c909a222b91c286b153ca7caed blob + d494622841bcd47aeb9aabaab0545762f13ddacb --- cmd/recursor/recursor.go +++ cmd/recursor/recursor.go @@ -1,10 +1,9 @@ package main import ( - "os" "fmt" "golang.org/x/net/dns/dnsmessage" - "sync" + "os" "olowe.co/dns" ) @@ -18,6 +17,8 @@ func shouldReject(m *dnsmessage.Message) (bool, dnsmes return true, dnsmessage.RCodeFormatError } else if m.Questions[0].Type == dnsmessage.TypeALL { return true, dnsmessage.RCodeNotImplemented + } else if m.Questions[0].Class != dnsmessage.ClassINET { + return true, dnsmessage.RCodeNotImplemented } return false, dnsmessage.RCodeSuccess } @@ -37,16 +38,6 @@ func handler(w dns.ResponseWriter, qmsg *dnsmessage.Me rmsg.RecursionDesired = true q := qmsg.Questions[0] - cache.RLock() - if answers, ok := cache.m[q]; ok { - rmsg.Answers = answers - w.WriteMsg(rmsg) - cache.RUnlock() - fmt.Fprintf(os.Stderr, "cache served %s %s\n", q.Name, q.Type) - return - } - cache.RUnlock() - resolved, err := resolveFromRoot(q) if err != nil { fmt.Fprintln(os.Stderr, err) @@ -56,24 +47,14 @@ func handler(w dns.ResponseWriter, qmsg *dnsmessage.Me } rmsg.Header.RCode = resolved.Header.RCode rmsg.Answers = resolved.Answers - cache.Lock() - cache.m[q] = rmsg.Answers - fmt.Fprintf(os.Stderr, "cached %s %s\n", q.Name, q.Type) - cache.Unlock() if len(rmsg.Answers) == 0 { rmsg.Authorities = resolved.Authorities w.WriteMsg(rmsg) return } - rmsg.Answers = resolved.Answers w.WriteMsg(rmsg) } -var cache = struct{ - m map[dnsmessage.Question][]dnsmessage.Resource - sync.RWMutex -}{m: make(map[dnsmessage.Question][]dnsmessage.Resource)} - func main() { fmt.Fprintln(os.Stderr, dns.ListenAndServe("udp", "", handler)) } blob - /dev/null blob + f74cab14bc53ea8705412a15b159af915b68c931 (mode 644) --- /dev/null +++ cmd/recursor/cache.go @@ -0,0 +1,42 @@ +package main + +import ( + "sync" + "golang.org/x/net/dns/dnsmessage" +) + +var cache = struct { + m map[dnsmessage.Name]map[dnsmessage.Type][]dnsmessage.Resource + sync.RWMutex +}{m: make(map[dnsmessage.Name]map[dnsmessage.Type][]dnsmessage.Resource)} + +func lookup(n dnsmessage.Name, t dnsmessage.Type) ([]dnsmessage.Resource, bool) { + cache.RLock() + if rr, ok := cache.m[n][t]; ok { + cache.RUnlock() + if expired(n, t) { + cache.Lock() + delete(cache.m[n], t) + cache.Unlock() + return nil, false + } + return rr, true + } + l := len(cache.m[n]) + cache.RUnlock() + if l < 1 { + cache.Lock() + cache.m[n] = make(map[dnsmessage.Type][]dnsmessage.Resource) + cache.Unlock() + } + return nil, false +} + +func insert(n dnsmessage.Name, t dnsmessage.Type, rrs []dnsmessage.Resource) { + cache.Lock() + cache.m[n][t] = rrs + cache.Unlock() + return +} + +func expired(n dnsmessage.Name, t dnsmessage.Type) bool { return false } blob - 5744b2f2b490d44a1acbcea6014b23a663aaf424 blob + 7aa4fa1052d05864c09d1a8ee9c858ab1f03d72f --- cmd/recursor/resolve.go +++ cmd/recursor/resolve.go @@ -2,13 +2,22 @@ package main import ( "fmt" + "golang.org/x/net/dns/dnsmessage" "net" + "os" "strings" - "golang.org/x/net/dns/dnsmessage" "olowe.co/dns" ) +const rootA = "198.41.0.4" +const rootB = "199.9.14.201" +const rootC = "192.33.4.12" +const rootD = "199.7.91.13" +const rootE = "192.203.230.10" + +var roots []net.IP = []net.IP{net.ParseIP(rootA), net.ParseIP(rootB), net.ParseIP(rootC)} + // appends the DNS port to the IP to be used in a dial string. func ip2dial(ip net.IP) string { return net.JoinHostPort(ip.String(), "domain") @@ -18,6 +27,16 @@ func isIPv6(ip net.IP) bool { return strings.Contains(ip.String(), ":") } +func filterRRs(rrs []dnsmessage.Resource, n dnsmessage.Name, t dnsmessage.Type) []dnsmessage.Resource { + var matches []dnsmessage.Resource + for _, r := range rrs { + if (r.Header.Name == n && r.Header.Type == t) { + matches = append(matches, r) + } + } + return matches +} + func nextServerAddrs(resources []dnsmessage.Resource) []net.IP { var next []net.IP for _, r := range resources { @@ -31,13 +50,6 @@ func nextServerAddrs(resources []dnsmessage.Resource) return next } -const rootA = "198.41.0.4" -const rootB = "199.9.14.201" -const rootC = "192.33.4.12" -const rootD = "199.7.91.13" -const rootE = "192.203.230.10" -var roots []net.IP = []net.IP{net.ParseIP(rootA), net.ParseIP(rootB), net.ParseIP(rootC)} - func resolveFromRoot(q dnsmessage.Question) (dnsmessage.Message, error) { return resolve(q, roots) } @@ -45,13 +57,23 @@ func resolveFromRoot(q dnsmessage.Question) (dnsmessag func resolve(q dnsmessage.Question, next []net.IP) (dnsmessage.Message, error) { var rmsg dnsmessage.Message var err error + if rrs, ok := lookup(q.Name, q.Type); ok { + fmt.Fprintln(os.Stderr, "cache served", q.Name, q.Type) + return dnsmessage.Message{Answers: rrs}, nil + } + fmt.Fprintln(os.Stderr, "cache miss", q.Name, q.Type) + for _, ip := range next { // Aussie Broadband doesn't support IPv6 yet! if isIPv6(ip) { continue } + fmt.Fprintf(os.Stderr, "asking %s for %s %s\n", ip, q.Name, q.Type) rmsg, err = dns.Ask(q, ip2dial(ip)) if rmsg.Header.Authoritative { + fmt.Println("got auth answer") + insert(q.Name, q.Type, rmsg.Answers) + fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type) return rmsg, err } else if rmsg.Header.RCode == dnsmessage.RCodeSuccess && err == nil { break @@ -60,13 +82,24 @@ func resolve(q dnsmessage.Question, next []net.IP) (dn if err != nil { return dnsmessage.Message{}, fmt.Errorf("resolve %s: %w", q.Name, err) } + fmt.Println("no auth answer") - // no authoritative answer, so start looking for hints of who to ask next - if len(rmsg.Additionals) > 0 { - return resolve(q, nextServerAddrs(rmsg.Additionals)) + if len(rmsg.Authorities) > 0 { + if _, ok := lookup(rmsg.Authorities[0].Header.Name, rmsg.Authorities[0].Header.Type); !ok { + insert(rmsg.Authorities[0].Header.Name, rmsg.Authorities[0].Header.Type, rmsg.Authorities) + fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type) + } } + for _, a := range rmsg.Additionals { + matches := filterRRs(rmsg.Additionals, a.Header.Name, a.Header.Type) + if _, ok := lookup(a.Header.Name, a.Header.Type); !ok { + insert(a.Header.Name, a.Header.Type, matches) + fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type) + } + } - // no hints in additionals, check authorities + // get the IP addresses of the nameservers we were told about, then + // ask the same question to them if len(rmsg.Authorities) > 0 { for _, a := range rmsg.Authorities { switch b := a.Body.(type) { @@ -80,6 +113,8 @@ func resolve(q dnsmessage.Question, next []net.IP) (dn return resolve(q, nextServerAddrs(rmsg.Answers)) } return resolve(q, nextServerAddrs(rmsg.Additionals)) + default: + return rmsg, fmt.Errorf("unexpected authority resource type %s", a.Header.Type) } } } blob - /dev/null blob + aa0c6d1bcbc8c8c0cfc99c9148292378f1437e75 (mode 644) --- /dev/null +++ cmd/recursor/recursor_test.go @@ -0,0 +1,119 @@ +package main + +import ( + "fmt" + "golang.org/x/net/dns/dnsmessage" + "os" + "testing" + + "olowe.co/dns" +) + +var tquery dnsmessage.Message = dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: 69, + RecursionDesired: true, + }, + Questions: []dnsmessage.Question{ + dnsmessage.Question{ + Name: dnsmessage.MustNewName("www.example.com."), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, +} + +const testAddr string = "127.0.0.1:5359" +const quad9 string = "9.9.9.9:domain" + +func compareMsg(want, got dnsmessage.Message) error { + if want.Header != got.Header { + fmt.Errorf("mismatched headers") + } + if len(want.Answers) != len(got.Answers) { + return fmt.Errorf("mismatched answer count") + } + if (len(want.Answers) == 0 || len(got.Answers) == 0) { + return fmt.Errorf("unsupported comparison of empty answer messages") + } + wantaddr, ok := want.Answers[0].Body.(*dnsmessage.AAAAResource) + if !ok { + return fmt.Errorf("unexpected resource type from external resolver") + } + gotaddr, ok := got.Answers[0].Body.(*dnsmessage.AAAAResource) + if !ok { + return fmt.Errorf("unexpected resource type from our resolver") + } + if wantaddr.AAAA != gotaddr.AAAA { + return fmt.Errorf("wanted %s got %s", wantaddr.AAAA, gotaddr.AAAA) + } + return nil +} + +func TestMain(m *testing.M) { + go func() { + if err := dns.ListenAndServe("udp", testAddr, handler); err != nil { + fmt.Println(err) + os.Exit(1) + } + }() + os.Exit(m.Run()) +} + +func TestRecursor(t *testing.T) { + wanted, err := dns.Exchange(tquery, quad9) + if err != nil { + fmt.Fprintf(os.Stderr, "skipping %s: %v\n", t.Name(), err) + t.Skip("query internet DNS:", err) + } + got, err := dns.Exchange(tquery, testAddr) + if err != nil { + t.Fatal(err) + } + t.Logf("wanted: %+v got %+v", wanted, got) + if err := compareMsg(wanted, got); err != nil { + t.Error(err) + } + // answer should come from cache + for i := 0; i <= 1; i++ { + got, err = dns.Exchange(tquery, testAddr) + if err = compareMsg(wanted, got); err != nil { + t.Error("resolve from cache:", err) + } + } + q := tquery + q.Questions[0].Name = dnsmessage.MustNewName("www.example.net.") + for i := 0; i <= 1; i++ { + if _, err = dns.Exchange(q, testAddr); err != nil { + t.Error("resolve from cache:", err) + } + } + t.Logf("wanted: %+v got %+v", wanted, got) +} + +func TestNXDomain(t *testing.T) { + var wanted, got dnsmessage.Message + var err error + wanted, err = dns.Exchange(tquery, quad9) + if err != nil { + fmt.Fprintf(os.Stderr, "skipping %s: %v\n", t.Name(), err) + t.Skip("query internet DNS:", err) + } + q := tquery + q.Questions[0].Name = dnsmessage.MustNewName("nxdomain.example.com.") + wanted, err = dns.Exchange(q, quad9) + if err != nil { + t.Fatal(err) + } + // try twice: first for fresh response, second for cached response + for i := 0; i <= 1; i++ { + got, err = dns.Exchange(q, testAddr) + if err != nil { + t.Fatal(err) + } + if wanted.Header != got.Header { + t.Error("mismatched headers") + } + } + t.Logf("wanted: %+v got %+v", wanted, got) +}