Commit Diff


commit - b3a0993e1f1aff1ad84b2e0cd680e5697bfd17dc
commit + 88787149cf89ee12e50156d2be193b2194e4052f
blob - /dev/null
blob + 072980388774b74cd3966553993ba1a2e4bd5e60 (mode 644)
--- /dev/null
+++ cmd/dohproxy/config.go
@@ -0,0 +1,47 @@
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"os"
+	"strings"
+)
+
// config holds the settings read from a configuration file.
// forwardaddr receives the value of the "forward" key and listenaddr the
// value of the "listen" key; both are stored verbatim, without validation.
type config struct {
	forwardaddr, listenaddr string
}
+
+func configFromFile(name string) (config, error) {
+	f, err := os.Open(name)
+	if err != nil {
+		return config{}, err
+	}
+	defer f.Close()
+	return parseConfig(f)
+}
+
+func parseConfig(r io.Reader) (config, error) {
+	sc := bufio.NewScanner(r)
+	var c config
+	for sc.Scan() {
+		line := strings.TrimSpace(sc.Text())
+		if strings.HasPrefix(line, "#") {
+			continue // skip config comments
+		}
+		fields := strings.Fields(line)
+		if len(fields) > 2 {
+			return c, fmt.Errorf("too many values for key %s", fields[0])
+		}
+		switch k := fields[0]; k {
+		case "listen":
+			c.listenaddr = fields[1]
+		case "forward":
+			c.forwardaddr = fields[1]
+		default:
+			return c, fmt.Errorf("unknown key %s", k)
+		}
+	}
+	return c, nil
+}
blob - 7beb094d71c911951b0ae0230ba101c8b113916b
blob + 3be6cfeb1aa70596289c0eaa957026708b346bea
--- dns.go
+++ dns.go
@@ -4,7 +4,6 @@ import (
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"io"
	"net"

	"golang.org/x/net/dns/dnsmessage"
@@ -42,36 +41,50 @@ func send(msg dnsmessage.Message, conn net.Conn) (dnsm
 		return dnsmessage.Message{}, err
 	}
 	if _, ok := conn.(net.PacketConn); ok {
-		if _, err = conn.Write(packed); err != nil {
-			return dnsmessage.Message{}, err
+		b, err := dnsPacketExchange(packed, conn)
+		if err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("exchange DNS packet: %v", err)
 		}
 	} else {
-		// DNS over TCP requires you to prepend the message with a
-		// 2-octet length field.
-		m := make([]byte, 2+len(packed))
-		binary.BigEndian.PutUint16(m, uint16(len(packed)))
-		copy(m[2:], packed)
-		if _, err = conn.Write(m); err != nil {
-			return dnsmessage.Message{}, err
-		}
+		b, err := dnsStreamExchange(packed, conn)
+		if err != nil {
+			return dnsmessage.Message{}, fmt.Errorf("exchange DNS TCP stream: %v", err)
 	}
-	buf := make([]byte, 1024)
-	n, err := conn.Read(buf)
-	if err != nil && err != io.EOF {
-		return dnsmessage.Message{}, err
-	}
-	if n == 0 {
-		return dnsmessage.Message{}, fmt.Errorf("empty response")
-	}
 	var rmsg dnsmessage.Message
-	if _, ok := conn.(net.PacketConn); ok {
-		if err := rmsg.Unpack(buf[:n]); err != nil {
-			return dnsmessage.Message{}, fmt.Errorf("parse response: %v", err)
-		}
-	} else {
-		if err := rmsg.Unpack(buf[2:n]); err != nil {
-			return dnsmessage.Message{}, fmt.Errorf("parse response: %v", err)
-		}
+	if err := rmsg.Unpack(b); err != nil {
+		return dnsmessage.Message{}, fmt.Errorf("parse response: %v", err)
 	}
 	return rmsg, nil
 }
+
// dnsPacketExchange sends the packed DNS message b over a datagram
// connection and returns the first reply read back from conn.
//
// RFC 1035 limits plain UDP DNS messages to 512 bytes, but with EDNS(0)
// (RFC 6891) a server may legitimately send larger UDP replies. A 512-byte
// buffer would truncate those (or fail, depending on the platform). The
// buffer is therefore sized for the largest possible DNS message.
func dnsPacketExchange(b []byte, conn net.Conn) ([]byte, error) {
	const maxMessageSize = 65535

	if _, err := conn.Write(b); err != nil {
		return nil, err
	}
	buf := make([]byte, maxMessageSize)
	n, err := conn.Read(buf)
	if err != nil {
		return nil, err
	}
	// A zero-length datagram cannot hold a DNS header; report it explicitly
	// rather than leaving it to the unpacker.
	if n == 0 {
		return nil, fmt.Errorf("empty response")
	}
	return buf[:n], nil
}
+
// dnsStreamExchange sends the packed DNS message b over a stream (TCP)
// connection and returns the packed reply without its length prefix.
//
// DNS over TCP (RFC 1035 §4.2.2) frames every message with a 2-octet
// big-endian length field. A single stream read may return only part of
// the reply, so it is read in two steps with io.ReadFull: first the length
// prefix, then exactly that many message bytes.
func dnsStreamExchange(b []byte, conn net.Conn) ([]byte, error) {
	// The 2-octet length field cannot describe anything longer than 65535
	// bytes; uint16(len(b)) would silently wrap.
	if len(b) > 0xffff {
		return nil, fmt.Errorf("message too large for TCP framing: %d bytes", len(b))
	}
	m := make([]byte, 2+len(b))
	binary.BigEndian.PutUint16(m, uint16(len(b)))
	copy(m[2:], b)
	if _, err := conn.Write(m); err != nil {
		return nil, err
	}

	var prefix [2]byte
	if _, err := io.ReadFull(conn, prefix[:]); err != nil {
		return nil, fmt.Errorf("read length prefix: %w", err)
	}
	size := binary.BigEndian.Uint16(prefix[:])
	if size == 0 {
		return nil, fmt.Errorf("empty response")
	}
	reply := make([]byte, size)
	if _, err := io.ReadFull(conn, reply); err != nil {
		return nil, fmt.Errorf("read response body: %w", err)
	}
	return reply, nil
}