commit ded9109ce8770cd39b01c1fb5a8afdc1bd1fc335 from: Oliver Lowe date: Thu Dec 16 03:22:12 2021 UTC New Ask functions These let us specify just a question without needing to worry about the entire DNS message structure. commit - 0be319b95e9008dcc63694c16025f23e90e74300 commit + ded9109ce8770cd39b01c1fb5a8afdc1bd1fc335 blob - d278a355c3d7d6eb09a0cb152f43504fc9f50f13 blob + 4b1c81b5517c84971733bfca0c64fb78887e973e --- cmd/recursor/recursor.go +++ cmd/recursor/recursor.go @@ -6,8 +6,6 @@ import ( "net" "strings" "sync" - "time" - "math/rand" "golang.org/x/net/dns/dnsmessage" "olowe.co/dns" ) @@ -28,10 +26,6 @@ func ip2dial(ip net.IP) string { return net.JoinHostPort(ip.String(), "domain") } -func newID() uint16 { - return uint16(rand.Intn(65535)) -} - func nextServerAddrs(resources []dnsmessage.Resource) []net.IP { var next []net.IP for _, r := range resources { @@ -46,10 +40,6 @@ func nextServerAddrs(resources []dnsmessage.Resource) } func resolve(q dnsmessage.Question, next []net.IP) (dnsmessage.Message, error) { - qmsg := dnsmessage.Message{ - Header: dnsmessage.Header{ID: newID()}, - Questions: []dnsmessage.Question{q}, - } var rmsg dnsmessage.Message var err error for _, ip := range next { @@ -58,7 +48,7 @@ func resolve(q dnsmessage.Question, next []net.IP) (dn continue } fmt.Fprintf(os.Stderr, "asking %s about %s\n", ip, q.Name) - rmsg, err = dns.Exchange(qmsg, ip2dial(ip)) + rmsg, err = dns.Ask(q, ip2dial(ip)) if rmsg.Header.Authoritative { return rmsg, err } else if rmsg.Header.RCode == dnsmessage.RCodeSuccess && err == nil { @@ -134,7 +124,7 @@ func handler(w dns.ResponseWriter, qmsg *dnsmessage.Me fmt.Fprintf(os.Stderr, "cache served %s %s\n", q.Name, q.Type) return } - cache.RUnlock() + cache.RUnlock() resolved, err := resolveFromRoot(q) if err != nil { @@ -164,6 +154,5 @@ var cache = struct{ }{m: make(map[dnsmessage.Question][]dnsmessage.Resource)} func main() { - rand.Seed(time.Now().UnixNano()) fmt.Fprintln(os.Stderr, dns.ListenAndServe("udp", "", 
handler))
 }
blob - 40080dee87ebf41f36afab6100cf11604559119f
blob + 42e892fbcf995c094bec224548ebda83d2efd8a8
--- dns.go
+++ dns.go
@@ -71,7 +71,10 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math/rand"
 	"net"
+	"sync"
+	"time"
 
 	"golang.org/x/net/dns/dnsmessage"
 )
@@ -82,6 +85,50 @@ const MaxMsgSize int = 65535 // max size of a message
 
 var errMismatchedID = errors.New("mismatched message id")
 
+// idMu serializes access to randomsrc: a *rand.Rand made by rand.New is
+// not safe for concurrent use, and the exported Ask functions may be
+// called from several goroutines at once.
+var (
+	idMu      sync.Mutex
+	randomsrc = rand.New(rand.NewSource(time.Now().UnixNano()))
+)
+
+// newID returns a pseudo-random DNS message ID drawn uniformly from the
+// full 16-bit range; rand.Intn(n) yields values in [0, n).
+// NOTE(review): math/rand seeded from the clock is predictable; consider
+// crypto/rand if IDs must resist off-path response spoofing.
+func newID() uint16 {
+	idMu.Lock()
+	defer idMu.Unlock()
+	return uint16(randomsrc.Intn(65536))
+}
+
+// newQuestionMsg returns a query message carrying only q, under a fresh ID.
+func newQuestionMsg(q dnsmessage.Question) dnsmessage.Message {
+	return dnsmessage.Message{
+		Header:    dnsmessage.Header{ID: newID()},
+		Questions: []dnsmessage.Question{q},
+	}
+}
+
+// Ask sends a message with q to addr and returns its response.
+// The exchange is unencrypted using UDP.
+func Ask(q dnsmessage.Question, addr string) (dnsmessage.Message, error) {
+	return Exchange(newQuestionMsg(q), addr)
+}
+
+// AskTCP sends a message with q to addr and returns its response.
+// The exchange is unencrypted using TCP.
+func AskTCP(q dnsmessage.Question, addr string) (dnsmessage.Message, error) {
+	return ExchangeTCP(newQuestionMsg(q), addr)
+}
+
+// AskTLS sends a message with q to addr and returns its response.
+// The exchange is encrypted using DNS over TLS.
+func AskTLS(q dnsmessage.Question, addr string) (dnsmessage.Message, error) {
+	return ExchangeTLS(newQuestionMsg(q), addr)
+}
+
 // Exchange performs a synchronous, unencrypted UDP DNS exchange with addr and returns its
 // reply to msg.
func Exchange(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) { blob - 40e6d3d9fe3e29f4aa1a11c04f9868ba925b4a7f blob + 6d63e59a3f88930b69912f9553f3149377ed7bcc --- server_test.go +++ server_test.go @@ -1,6 +1,7 @@ package dns import ( + "golang.org/x/net/dns/dnsmessage" "testing" ) @@ -8,12 +9,9 @@ func TestServer(t *testing.T) { go func() { t.Fatal(ListenAndServe("udp", "127.0.0.1:51111", nil)) }() - q, err := buildmsg("www.example.com.") + q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} + rmsg, err := Ask(q, "127.0.0.1:51111") if err != nil { - t.Fatalf("create query: %v", err) - } - rmsg, err := Exchange(q, "127.0.0.1:51111") - if err != nil { t.Errorf("exchange: %v", err) } t.Log("response:", rmsg) @@ -23,12 +21,9 @@ func TestStreamServer(t *testing.T) { go func() { t.Fatal(ListenAndServe("tcp", "127.0.0.1:51112", nil)) }() - q, err := buildmsg("www.example.com.") + q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} + rmsg, err := AskTCP(q, "127.0.0.1:51112") if err != nil { - t.Fatal("create query:", err) - } - rmsg, err := ExchangeTCP(q, "127.0.0.1:51112") - if err != nil { t.Errorf("exchange: %v", err) } t.Log("response:", rmsg) @@ -40,12 +35,9 @@ func TestEmptyServer(t *testing.T) { t.Fatal(srv.ListenAndServe()) t.Log(srv.addr) }() - q, err := buildmsg("www.example.com.") + q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET} + rmsg, err := Ask(q, "127.0.0.1:domain") if err != nil { - t.Fatal("create query:", err) - } - rmsg, err := Exchange(q, "127.0.0.1:domain") - if err != nil { t.Errorf("exchange: %v", err) } t.Log("response:", rmsg)