commit - 16fb0dea8a6bd2b750ffd7f706401125f18d222a
commit + d5834ac9c24e2b94ddd01468976b08de9bb90ce8
blob - d452495bebba48c909a222b91c286b153ca7caed
blob + d494622841bcd47aeb9aabaab0545762f13ddacb
--- cmd/recursor/recursor.go
+++ cmd/recursor/recursor.go
package main
import (
- "os"
"fmt"
"golang.org/x/net/dns/dnsmessage"
- "sync"
+ "os"
"olowe.co/dns"
)
return true, dnsmessage.RCodeFormatError
} else if m.Questions[0].Type == dnsmessage.TypeALL {
return true, dnsmessage.RCodeNotImplemented
+ } else if m.Questions[0].Class != dnsmessage.ClassINET {
+ return true, dnsmessage.RCodeNotImplemented
}
return false, dnsmessage.RCodeSuccess
}
rmsg.RecursionDesired = true
q := qmsg.Questions[0]
- cache.RLock()
- if answers, ok := cache.m[q]; ok {
- rmsg.Answers = answers
- w.WriteMsg(rmsg)
- cache.RUnlock()
- fmt.Fprintf(os.Stderr, "cache served %s %s\n", q.Name, q.Type)
- return
- }
- cache.RUnlock()
-
resolved, err := resolveFromRoot(q)
if err != nil {
fmt.Fprintln(os.Stderr, err)
}
rmsg.Header.RCode = resolved.Header.RCode
rmsg.Answers = resolved.Answers
- cache.Lock()
- cache.m[q] = rmsg.Answers
- fmt.Fprintf(os.Stderr, "cached %s %s\n", q.Name, q.Type)
- cache.Unlock()
if len(rmsg.Answers) == 0 {
rmsg.Authorities = resolved.Authorities
w.WriteMsg(rmsg)
return
}
- rmsg.Answers = resolved.Answers
w.WriteMsg(rmsg)
}
-var cache = struct{
- m map[dnsmessage.Question][]dnsmessage.Resource
- sync.RWMutex
-}{m: make(map[dnsmessage.Question][]dnsmessage.Resource)}
-
+// main serves handler over UDP using dns.ListenAndServe with an empty
+// address (NOTE(review): presumably the package's default listen address —
+// confirm in olowe.co/dns) and prints whatever it returns to stderr.
func main() {
fmt.Fprintln(os.Stderr, dns.ListenAndServe("udp", "", handler))
}
blob - /dev/null
blob + f74cab14bc53ea8705412a15b159af915b68c931 (mode 644)
--- /dev/null
+++ cmd/recursor/cache.go
+package main
+
+import (
+ "sync"
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+// cache holds resource records keyed first by owner name and then by record
+// type. Every access to m must hold the embedded RWMutex.
+var cache = struct {
+	m map[dnsmessage.Name]map[dnsmessage.Type][]dnsmessage.Resource
+	sync.RWMutex
+}{m: make(map[dnsmessage.Name]map[dnsmessage.Type][]dnsmessage.Resource)}
+
+// lookup returns the cached records for name n and type t. The boolean
+// reports whether an entry was found and not reported as expired. Expired
+// entries are removed. A miss leaves the cache untouched.
+//
+// Keys compare names byte-for-byte, so names differing only in letter case
+// are distinct entries.
+func lookup(n dnsmessage.Name, t dnsmessage.Type) ([]dnsmessage.Resource, bool) {
+	cache.RLock()
+	// Indexing a missing inner map yields the zero value, so this is safe
+	// even when n has never been inserted.
+	rrs, ok := cache.m[n][t]
+	cache.RUnlock()
+	if !ok {
+		return nil, false
+	}
+	if expired(n, t) {
+		// NOTE(review): another goroutine may insert a fresh entry between
+		// the RUnlock above and this Lock; it would be deleted here.
+		cache.Lock()
+		delete(cache.m[n], t)
+		cache.Unlock()
+		return nil, false
+	}
+	return rrs, true
+}
+
+// insert stores rrs as the cached records for name n and type t, replacing
+// any previous entry. The per-name map is created on first use, so insert
+// does not depend on a prior lookup.
+func insert(n dnsmessage.Name, t dnsmessage.Type, rrs []dnsmessage.Resource) {
+	cache.Lock()
+	defer cache.Unlock()
+	byType := cache.m[n]
+	if byType == nil {
+		byType = make(map[dnsmessage.Type][]dnsmessage.Resource)
+		cache.m[n] = byType
+	}
+	byType[t] = rrs
+}
+
+// expired reports whether the cached entry for n and t should be discarded.
+// It currently always returns false, so entries never expire.
+// NOTE(review): presumably meant to honour record TTLs — not implemented.
+func expired(n dnsmessage.Name, t dnsmessage.Type) bool { return false }
blob - 5744b2f2b490d44a1acbcea6014b23a663aaf424
blob + 7aa4fa1052d05864c09d1a8ee9c858ab1f03d72f
--- cmd/recursor/resolve.go
+++ cmd/recursor/resolve.go
import (
"fmt"
+ "golang.org/x/net/dns/dnsmessage"
"net"
+ "os"
"strings"
- "golang.org/x/net/dns/dnsmessage"
"olowe.co/dns"
)
+// Addresses of the root nameservers A through E.
+const (
+	rootA = "198.41.0.4"
+	rootB = "199.9.14.201"
+	rootC = "192.33.4.12"
+	rootD = "199.7.91.13"
+	rootE = "192.203.230.10"
+)
+
+// roots lists the servers resolveFromRoot starts from. Only A, B and C are
+// included; rootD and rootE are declared but not used here.
+var roots = []net.IP{net.ParseIP(rootA), net.ParseIP(rootB), net.ParseIP(rootC)}
+
// appends the DNS port to the IP to be used in a dial string.
func ip2dial(ip net.IP) string {
return net.JoinHostPort(ip.String(), "domain")
return strings.Contains(ip.String(), ":")
}
+// filterRRs returns, in their original order, the records in rrs whose
+// owner name equals n and whose type equals t. The result is nil when
+// nothing matches.
+func filterRRs(rrs []dnsmessage.Resource, n dnsmessage.Name, t dnsmessage.Type) []dnsmessage.Resource {
+	var out []dnsmessage.Resource
+	for i := range rrs {
+		h := rrs[i].Header
+		if h.Name != n || h.Type != t {
+			continue
+		}
+		out = append(out, rrs[i])
+	}
+	return out
+}
+
func nextServerAddrs(resources []dnsmessage.Resource) []net.IP {
var next []net.IP
for _, r := range resources {
return next
}
-const rootA = "198.41.0.4"
-const rootB = "199.9.14.201"
-const rootC = "192.33.4.12"
-const rootD = "199.7.91.13"
-const rootE = "192.203.230.10"
-var roots []net.IP = []net.IP{net.ParseIP(rootA), net.ParseIP(rootB), net.ParseIP(rootC)}
-
+// resolveFromRoot answers q by calling resolve with the root nameserver
+// addresses in roots as the first servers to query.
func resolveFromRoot(q dnsmessage.Question) (dnsmessage.Message, error) {
return resolve(q, roots)
}
func resolve(q dnsmessage.Question, next []net.IP) (dnsmessage.Message, error) {
var rmsg dnsmessage.Message
var err error
+ if rrs, ok := lookup(q.Name, q.Type); ok {
+ fmt.Fprintln(os.Stderr, "cache served", q.Name, q.Type)
+ return dnsmessage.Message{Answers: rrs}, nil
+ }
+ fmt.Fprintln(os.Stderr, "cache miss", q.Name, q.Type)
+
for _, ip := range next {
// Aussie Broadband doesn't support IPv6 yet!
if isIPv6(ip) {
continue
}
+ fmt.Fprintf(os.Stderr, "asking %s for %s %s\n", ip, q.Name, q.Type)
rmsg, err = dns.Ask(q, ip2dial(ip))
if rmsg.Header.Authoritative {
+ fmt.Println("got auth answer")
+ insert(q.Name, q.Type, rmsg.Answers)
+ fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type)
return rmsg, err
} else if rmsg.Header.RCode == dnsmessage.RCodeSuccess && err == nil {
break
if err != nil {
return dnsmessage.Message{}, fmt.Errorf("resolve %s: %w", q.Name, err)
}
+ fmt.Println("no auth answer")
- // no authoritative answer, so start looking for hints of who to ask next
- if len(rmsg.Additionals) > 0 {
- return resolve(q, nextServerAddrs(rmsg.Additionals))
+ if len(rmsg.Authorities) > 0 {
+ if _, ok := lookup(rmsg.Authorities[0].Header.Name, rmsg.Authorities[0].Header.Type); !ok {
+ insert(rmsg.Authorities[0].Header.Name, rmsg.Authorities[0].Header.Type, rmsg.Authorities)
+ fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type)
+ }
}
+ for _, a := range rmsg.Additionals {
+ matches := filterRRs(rmsg.Additionals, a.Header.Name, a.Header.Type)
+ if _, ok := lookup(a.Header.Name, a.Header.Type); !ok {
+ insert(a.Header.Name, a.Header.Type, matches)
+ fmt.Fprintln(os.Stderr, "cached", q.Name, q.Type)
+ }
+ }
- // no hints in additionals, check authorities
+ // get the IP addresses of the nameservers we were told about, then
+ // ask the same question to them
if len(rmsg.Authorities) > 0 {
for _, a := range rmsg.Authorities {
switch b := a.Body.(type) {
return resolve(q, nextServerAddrs(rmsg.Answers))
}
return resolve(q, nextServerAddrs(rmsg.Additionals))
+ default:
+ return rmsg, fmt.Errorf("unexpected authority resource type %s", a.Header.Type)
}
}
}
blob - /dev/null
blob + aa0c6d1bcbc8c8c0cfc99c9148292378f1437e75 (mode 644)
--- /dev/null
+++ cmd/recursor/recursor_test.go
+package main
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"testing"
+
+	"golang.org/x/net/dns/dnsmessage"
+	"olowe.co/dns"
+)
+
+// tquery is the base query used by the tests: a recursion-desired IN AAAA
+// question for www.example.com. with message ID 69.
+var tquery = dnsmessage.Message{
+	Header: dnsmessage.Header{ID: 69, RecursionDesired: true},
+	Questions: []dnsmessage.Question{{
+		Name:  dnsmessage.MustNewName("www.example.com."),
+		Type:  dnsmessage.TypeAAAA,
+		Class: dnsmessage.ClassINET,
+	}},
+}
+
+// testAddr is where TestMain serves the recursor under test; quad9 is the
+// external resolver whose answers the tests treat as the reference.
+const (
+	testAddr string = "127.0.0.1:5359"
+	quad9    string = "9.9.9.9:domain"
+)
+
+// compareMsg reports how got differs from want, or returns nil if they
+// match. Headers must be identical and both messages must carry the same
+// number of answers. Only the first answer of each is compared, and it must
+// be an AAAA record. Messages without answers are rejected as unsupported.
+func compareMsg(want, got dnsmessage.Message) error {
+	if want.Header != got.Header {
+		return fmt.Errorf("mismatched headers: want %+v, got %+v", want.Header, got.Header)
+	}
+	if len(want.Answers) != len(got.Answers) {
+		return fmt.Errorf("mismatched answer count")
+	}
+	// The counts are equal here, so checking one side covers both.
+	if len(want.Answers) == 0 {
+		return fmt.Errorf("unsupported comparison of empty answer messages")
+	}
+	wantaddr, ok := want.Answers[0].Body.(*dnsmessage.AAAAResource)
+	if !ok {
+		return fmt.Errorf("unexpected resource type from external resolver")
+	}
+	gotaddr, ok := got.Answers[0].Body.(*dnsmessage.AAAAResource)
+	if !ok {
+		return fmt.Errorf("unexpected resource type from our resolver")
+	}
+	if wantaddr.AAAA != gotaddr.AAAA {
+		// AAAA is a raw [16]byte; format it as an address rather than bytes.
+		return fmt.Errorf("wanted %s got %s", net.IP(wantaddr.AAAA[:]), net.IP(gotaddr.AAAA[:]))
+	}
+	return nil
+}
+
+// TestMain starts the recursor on testAddr in a background goroutine and
+// then runs the tests. If serving returns an error, it is printed and the
+// whole test binary exits with status 1.
+// NOTE(review): tests do not wait for the listener to be ready.
+func TestMain(m *testing.M) {
+	go func() {
+		err := dns.ListenAndServe("udp", testAddr, handler)
+		if err == nil {
+			return
+		}
+		fmt.Println(err)
+		os.Exit(1)
+	}()
+	code := m.Run()
+	os.Exit(code)
+}
+
+// TestRecursor resolves tquery through the local recursor and compares the
+// answer with quad9's. It repeats the query twice more, where the answer
+// should come from the cache, and finally resolves www.example.net. twice.
+// The test is skipped when quad9 cannot be reached.
+func TestRecursor(t *testing.T) {
+	wanted, err := dns.Exchange(tquery, quad9)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "skipping %s: %v\n", t.Name(), err)
+		t.Skip("query internet DNS:", err)
+	}
+	got, err := dns.Exchange(tquery, testAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("wanted: %+v got %+v", wanted, got)
+	if err := compareMsg(wanted, got); err != nil {
+		t.Error(err)
+	}
+	// answer should come from cache
+	for i := 0; i < 2; i++ {
+		got, err = dns.Exchange(tquery, testAddr)
+		if err != nil {
+			t.Error("resolve from cache:", err)
+			continue
+		}
+		if err := compareMsg(wanted, got); err != nil {
+			t.Error("resolve from cache:", err)
+		}
+	}
+	// Copying the struct alone would share the Questions backing array, so
+	// clone the slice to avoid renaming the package-level tquery.
+	q := tquery
+	q.Questions = append([]dnsmessage.Question(nil), tquery.Questions...)
+	q.Questions[0].Name = dnsmessage.MustNewName("www.example.net.")
+	for i := 0; i < 2; i++ {
+		if _, err := dns.Exchange(q, testAddr); err != nil {
+			t.Errorf("resolve %s attempt %d: %v", q.Questions[0].Name, i+1, err)
+		}
+	}
+	t.Logf("wanted: %+v got %+v", wanted, got)
+}
+
+// TestNXDomain asks quad9 and the local recursor about a name under
+// example.com. that is expected not to exist, and checks that the response
+// headers match on two successive queries. The test is skipped when quad9
+// cannot be reached.
+func TestNXDomain(t *testing.T) {
+	// Connectivity probe only; its answer is not used.
+	if _, err := dns.Exchange(tquery, quad9); err != nil {
+		fmt.Fprintf(os.Stderr, "skipping %s: %v\n", t.Name(), err)
+		t.Skip("query internet DNS:", err)
+	}
+	// Clone Questions so renaming the question leaves tquery untouched.
+	q := tquery
+	q.Questions = append([]dnsmessage.Question(nil), tquery.Questions...)
+	q.Questions[0].Name = dnsmessage.MustNewName("nxdomain.example.com.")
+	wanted, err := dns.Exchange(q, quad9)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var got dnsmessage.Message
+	// try twice: first for fresh response, second for cached response
+	for i := 0; i < 2; i++ {
+		got, err = dns.Exchange(q, testAddr)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if wanted.Header != got.Header {
+			t.Error("mismatched headers")
+		}
+	}
+	t.Logf("wanted: %+v got %+v", wanted, got)
+}