Commit Diff


commit - 1bcf9fa40865b1840d67029d990dcbfe61a43a23
commit + 66267fa01f36cd99f4b3ac53250a1b8cbd3e2568
blob - 0c66051917400b8253ab00299af5f82fa2c73de1
blob + 7beb094d71c911951b0ae0230ba101c8b113916b
--- dns.go
+++ dns.go
@@ -1,6 +1,10 @@
 package dns
 
 import (
+	"crypto/tls"
+	"encoding/binary"
+	"fmt"
+	"io"
 	"net"
 
 	"golang.org/x/net/dns/dnsmessage"
@@ -21,22 +25,99 @@ func Exchange(msg dnsmessage.Message, addr string) (dn
 	return send(msg, conn)
 }
 
+// ExchangeTLS performs a synchronous DNS-over-TLS exchange (RFC 7858) with
+// addr and returns its reply to msg. addr must include a port; DNS-over-TLS
+// servers conventionally listen on port 853.
+//
+// The default TLS configuration is used, so the server certificate is
+// verified against the system roots with the host part of addr as the
+// server name. No deadline is set on the connection, so an unresponsive
+// server can block the call for a long time.
+func ExchangeTLS(msg dnsmessage.Message, addr string) (dnsmessage.Message, error) {
+	conn, err := tls.Dial("tcp", addr, nil)
+	if err != nil {
+		return dnsmessage.Message{}, err
+	}
+	defer conn.Close()
+	return send(msg, conn)
+}
+
+// send writes msg to conn and returns the decoded reply.
+//
+// Connections implementing net.PacketConn (such as *net.UDPConn) carry one
+// bare message per datagram. Every other connection is treated as a stream
+// (TCP, TLS), on which each message is preceded by a 2-octet big-endian
+// length field (RFC 1035 section 4.2.2).
 func send(msg dnsmessage.Message, conn net.Conn) (dnsmessage.Message, error) {
 	packed, err := msg.Pack()
 	if err != nil {
 		return dnsmessage.Message{}, err
 	}
-	if _, err := conn.Write(packed); err != nil {
-		return dnsmessage.Message{}, err
-	}
-	buf := make([]byte, 1024)
-	n, err := conn.Read(buf)
+	var reply []byte
+	if _, ok := conn.(net.PacketConn); ok {
+		reply, err = sendPacket(conn, packed)
+	} else {
+		reply, err = sendStream(conn, packed)
+	}
 	if err != nil {
 		return dnsmessage.Message{}, err
 	}
 	var rmsg dnsmessage.Message
-	if err := rmsg.Unpack(buf[:n]); err != nil {
-		return dnsmessage.Message{}, err
+	if err := rmsg.Unpack(reply); err != nil {
+		return dnsmessage.Message{}, fmt.Errorf("parse response: %w", err)
 	}
 	return rmsg, nil
 }
+
+// sendPacket writes packed to conn as a single datagram and returns the
+// payload of the first datagram read back.
+func sendPacket(conn net.Conn, packed []byte) ([]byte, error) {
+	if _, err := conn.Write(packed); err != nil {
+		return nil, err
+	}
+	// Size the buffer for the largest possible DNS message so that replies
+	// over the classic 512 octets (e.g. with EDNS(0)) are not truncated and
+	// then rejected by the parser.
+	buf := make([]byte, 65535)
+	n, err := conn.Read(buf)
+	if err != nil && err != io.EOF {
+		return nil, err
+	}
+	if n == 0 {
+		return nil, fmt.Errorf("empty response")
+	}
+	return buf[:n], nil
+}
+
+// sendStream writes packed to conn preceded by its 2-octet big-endian
+// length, then reads one length-prefixed reply and returns it without the
+// prefix.
+func sendStream(conn net.Conn, packed []byte) ([]byte, error) {
+	if len(packed) > 0xffff {
+		return nil, fmt.Errorf("message too long: %d bytes", len(packed))
+	}
+	m := make([]byte, 2+len(packed))
+	binary.BigEndian.PutUint16(m, uint16(len(packed)))
+	copy(m[2:], packed)
+	if _, err := conn.Write(m); err != nil {
+		return nil, err
+	}
+	// A single Read on a stream may return only part of the reply (or only
+	// part of its length field), so read each piece with io.ReadFull.
+	var hdr [2]byte
+	if _, err := io.ReadFull(conn, hdr[:]); err != nil {
+		if err == io.EOF {
+			// The peer closed the connection without sending anything.
+			return nil, fmt.Errorf("empty response")
+		}
+		return nil, err
+	}
+	reply := make([]byte, binary.BigEndian.Uint16(hdr[:]))
+	if len(reply) == 0 {
+		return nil, fmt.Errorf("empty response")
+	}
+	if _, err := io.ReadFull(conn, reply); err != nil {
+		return nil, fmt.Errorf("read response: %w", err)
+	}
+	return reply, nil
+}
blob - da468ca44cc7be238bbfdf9b228b4be2b8db9e3b
blob + 9977707d9918bae43f658e942b724edd1fdeadf8
--- go.mod
+++ go.mod
@@ -3,7 +3,8 @@ module git.sr.ht/~otl/dns
 go 1.17
 
 require (
-	golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa // indirect
-	golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
-	golang.org/x/text v0.3.3 // indirect
+	golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
+	golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
 )
+
+require golang.org/x/text v0.3.3 // indirect
blob - 64a51352e01b4bfa8f3b38c539eeac8143307c91
blob + 03e3e12723bb9cf3f4eeb512bcc4498fc09782a2
--- go.sum
+++ go.sum
@@ -3,6 +3,7 @@ golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=