commit - 20fd186cb05aed778db0ee7d84c3e413e3ad68fe
commit + 7e5b66014d8e51608b3c909ab643e1f91f12b8db
blob - b00a41244c0c16699f6f406da8c7ef89533d306d
blob + da47d0f5ade635b9e5992f18e81ff88699f52c1c
--- dns_test.go
+++ dns_test.go
package dns
import (
- "errors"
- "fmt"
- "io"
"math/rand"
"net"
"testing"
- "time"
"golang.org/x/net/dns/dnsmessage"
)
-type fakeDNSConn struct {
- net.Conn
- server fakeDNSServer
- buf []byte
- tcp bool
-}
+// testq is the question used by the tests in this package
+// (www.example.com., type A, class IN).
+var testq = dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
-type fakeDNSPacketConn struct {
- net.PacketConn
- fakeDNSConn
+// resolveBadly is a handler that answers every query with a malformed reply:
+// the ID is offset from the query's ID, the response bit is unset and the
+// RCode is a name error. The question section is echoed back unchanged.
+// TestBadResolver expects Ask to return an error when it receives this reply.
+func resolveBadly(w ResponseWriter, qmsg *dnsmessage.Message) {
+ rmsg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ // Deliberately does not match the query's ID.
+ ID: qmsg.Header.ID + 69,
+ Response: false,
+ RCode: dnsmessage.RCodeNameError,
+ },
+ Questions: qmsg.Questions,
+ }
+ // NOTE(review): any result returned by WriteMsg is ignored here.
+ w.WriteMsg(rmsg)
}
-func (f *fakeDNSPacketConn) Close() error {
- return nil
-}
-
-type fakeDNSServer struct {
- resolve func(q dnsmessage.Message) (dnsmessage.Message, error)
-}
-
-func resolveWell(q dnsmessage.Message) (dnsmessage.Message, error) {
- return dnsmessage.Message{
+// resolveWrongQuestion is a handler that replies with the query's ID and a
+// successful, authoritative response header, but with a question section
+// (blabla.example.org. NS CHAOS) that differs from the question asked.
+// A client should not accept this reply as the answer to its query.
+func resolveWrongQuestion(w ResponseWriter, qmsg *dnsmessage.Message) {
+ wrongq := dnsmessage.Question{Name: dnsmessage.MustNewName("blabla.example.org."), Type: dnsmessage.TypeNS, Class: dnsmessage.ClassCHAOS}
+ rmsg := dnsmessage.Message{
Header: dnsmessage.Header{
- ID: q.Header.ID,
+ ID: qmsg.Header.ID,
Response: true,
RCode: dnsmessage.RCodeSuccess,
+ Authoritative: true,
},
- Questions: q.Questions,
- Answers: []dnsmessage.Resource{
- {
- Header: dnsmessage.ResourceHeader{
- Name: q.Questions[0].Name,
- Type: q.Questions[0].Type,
- Class: q.Questions[0].Class,
- },
- Body: &dnsmessage.AResource{
- A: [4]byte{0xc0, 0x00, 0x02, 0x01},
- },
- },
- },
- }, nil
-}
-
-var errCrashed = errors.New("crashed")
-
-func resolveBadly(q dnsmessage.Message) (dnsmessage.Message, error) {
- return dnsmessage.Message{}, errCrashed
-}
-
-func (f fakeDNSConn) Close() error {
- return nil
-}
-
-func (f *fakeDNSConn) Write(b []byte) (int, error) {
- time.Sleep(50 * time.Millisecond)
- if len(f.buf) > 0 {
- return 0, fmt.Errorf("connection buffer full, refusing overwrite")
+ // Deliberately not qmsg.Questions: the reply answers a different question.
+ Questions: []dnsmessage.Question{wrongq},
}
- var qmsg dnsmessage.Message
- if f.tcp {
- if err := qmsg.Unpack(b[2:]); err != nil {
- return len(b), err
- }
- } else {
- if err := qmsg.Unpack(b); err != nil {
- return len(b), err
- }
- }
- rmsg, err := f.server.resolve(qmsg)
- if err != nil {
- return len(b), err
- }
- packed, err := rmsg.Pack()
- if err != nil {
- return len(b), err
- }
- if f.tcp {
- l := len(packed)
- buf := make([]byte, 2+len(packed))
- buf[0] = byte(l >> 8)
- buf[1] = byte(l)
- copy(buf[2:], packed)
- f.buf = buf
- return len(b), nil
- }
- f.buf = packed
- return len(b), err
+ // NOTE(review): any result returned by WriteMsg is ignored here.
+ w.WriteMsg(rmsg)
}
-func (f *fakeDNSConn) Read(b []byte) (int, error) {
- if len(f.buf) > 0 {
- n := copy(b, f.buf)
- f.buf = f.buf[n:]
- return n, nil
- }
- return 0, io.EOF
-}
-
-func TestGoodConn(t *testing.T) {
- qmsg, err := buildmsg("www.example.com.")
- if err != nil {
- t.Fatal(err)
- }
- var goodconn fakeDNSConn
- goodconn.server.resolve = resolveWell
- goodconn.tcp = true
- _, err = exchange(qmsg, &goodconn)
- if err != nil {
- t.Error(err)
- }
-}
-func TestShitConn(t *testing.T) {
- qmsg, err := buildmsg("www.example.com.")
- if err != nil {
- t.Fatal(err)
- }
- var shitconn fakeDNSConn
- shitconn.server.resolve = resolveBadly
- shitconn.tcp = true
- _, err = exchange(qmsg, &shitconn)
- if !errors.Is(err, errCrashed) {
- t.Errorf("wanted error %v, got %v", errCrashed, err)
- }
-}
-
-func TestBadMessage(t *testing.T) {
- q, err := buildmsg("www.example.com.")
- if err != nil {
- t.Fatal(err)
- }
- var shitconn fakeDNSPacketConn
- shitconn.server.resolve = func(q dnsmessage.Message) (dnsmessage.Message, error) {
- return dnsmessage.Message{
- Header: dnsmessage.Header{
- ID: q.Header.ID + 69,
- Response: false,
- RCode: dnsmessage.RCodeNameError,
- },
- Questions: q.Questions,
- }, nil
- }
- r, err := exchange(q, &shitconn)
- if !errors.Is(err, errMismatchedID) {
- t.Log(err)
- t.Errorf("should error on receiving mismatched message IDs; sent %d, received %d", q.Header.ID, r.Header.ID)
- }
-}
+// TestBadResolver checks that Ask returns an error instead of accepting the
+// reply when the server answers with a message that does not match the query.
+func TestBadResolver(t *testing.T) {
+	cases := []struct {
+		name    string
+		handler func(ResponseWriter, *dnsmessage.Message)
+	}{
+		// Reply with a different ID and the response bit unset.
+		{"bad header", resolveBadly},
+		// Reply with the right ID but a different question section.
+		{"wrong question", resolveWrongQuestion},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Each case gets its own server. Reassigning srv.Handler while
+			// ServePacket is running would either race with it or have no
+			// effect, depending on ServePacket's receiver. Port 0 lets the
+			// OS pick a free port instead of a fixed one that may be in use.
+			conn, err := net.ListenPacket("udp", "127.0.0.1:0")
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer conn.Close()
+			srv := Server{network: "udp", addr: "127.0.0.1", Handler: tc.handler}
+			// ServePacket's result is deliberately dropped. t.Fatal may only
+			// be called from the test goroutine, and this goroutine can
+			// outlive the subtest.
+			go srv.ServePacket(conn)
+			rmsg, err := Ask(testq, conn.LocalAddr().String())
+			if err == nil {
+				t.Error("wanted error, got nil")
+			} else {
+				t.Log(err)
+			}
+			t.Log("sent:", testq, "received:", rmsg)
+		})
+	}
+}
func buildmsg(s string) (dnsmessage.Message, error) {
blob - 6d63e59a3f88930b69912f9553f3149377ed7bcc
blob + a505e6c00f2e448077ec690c4b7b13318e4ec389
--- server_test.go
+++ server_test.go
package dns
import (
- "golang.org/x/net/dns/dnsmessage"
+ "crypto/rand"
+ "io"
+ "net"
+ "time"
"testing"
)
go func() {
t.Fatal(ListenAndServe("udp", "127.0.0.1:51111", nil))
}()
- q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
- rmsg, err := Ask(q, "127.0.0.1:51111")
+ time.Sleep(time.Millisecond)
+ rmsg, err := Ask(testq, "127.0.0.1:51111")
if err != nil {
t.Errorf("exchange: %v", err)
}
go func() {
t.Fatal(ListenAndServe("tcp", "127.0.0.1:51112", nil))
}()
- q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
- rmsg, err := AskTCP(q, "127.0.0.1:51112")
+ time.Sleep(time.Millisecond)
+ rmsg, err := AskTCP(testq, "127.0.0.1:51112")
if err != nil {
t.Errorf("exchange: %v", err)
}
t.Fatal(srv.ListenAndServe())
t.Log(srv.addr)
}()
- q := dnsmessage.Question{Name: dnsmessage.MustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}
- rmsg, err := Ask(q, "127.0.0.1:domain")
+ rmsg, err := Ask(testq, "127.0.0.1:domain")
if err != nil {
t.Errorf("exchange: %v", err)
}
t.Log("response:", rmsg)
}
+
+// TestJunk writes random bytes to the TCP server over a series of
+// connections, then checks that the server is still running and still
+// answers a well-formed query.
+func TestJunk(t *testing.T) {
+	addr := "127.0.0.1:5361"
+	errc := make(chan error, 1)
+	go func() {
+		// t.Fatal may only be called from the test goroutine, so the
+		// server's result is handed back through errc instead.
+		errc <- ListenAndServe("tcp", addr, nil)
+	}()
+	time.Sleep(time.Millisecond)
+	for i := 0; i <= 30; i++ {
+		if err := sendJunk(addr, 8192); err != nil {
+			t.Fatalf("junk connection %d: %v", i, err)
+		}
+	}
+	select {
+	case err := <-errc:
+		t.Fatalf("server exited: %v", err)
+	default:
+	}
+	if _, err := AskTCP(testq, addr); err != nil {
+		t.Errorf("query after junk: %v", err)
+	}
+}
+
+// sendJunk dials addr over TCP, writes n bytes read from crypto/rand, and
+// closes the connection before returning. A defer inside the caller's loop
+// would instead keep every connection open until the whole test returned.
+func sendJunk(addr string, n int64) error {
+	conn, err := net.Dial("tcp", addr)
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+	_, err = io.CopyN(conn, rand.Reader, n)
+	return err
+}
+
+// BenchmarkPacketVsStream compares a query sent with Ask over UDP against the
+// same query sent with AskTCP over TCP. Both go to local servers that use the
+// nil handler.
+func BenchmarkPacketVsStream(b *testing.B) {
+	addr := "127.0.0.1:51113"
+	networks := []string{"udp", "tcp"}
+	errc := make(chan error, len(networks))
+	for _, network := range networks {
+		// Copy the loop variable so each goroutine sees its own value. Before
+		// Go 1.22, the loop variable is shared across iterations.
+		network := network
+		go func() {
+			// b.Fatal may only be called from the benchmark goroutine, so a
+			// server exit is reported through errc instead.
+			errc <- ListenAndServe(network, addr, nil)
+		}()
+	}
+	// Give the servers a moment to start listening, as the tests do.
+	time.Sleep(time.Millisecond)
+	select {
+	case err := <-errc:
+		b.Fatal(err)
+	default:
+	}
+	for _, network := range networks {
+		b.Run(network, func(b *testing.B) {
+			// Exactly b.N iterations; the original `i <= b.N` ran one extra.
+			for i := 0; i < b.N; i++ {
+				if network == "udp" {
+					if rmsg, err := Ask(testq, addr); err != nil {
+						b.Log(rmsg)
+						b.Fatal(err)
+					}
+				} else {
+					if rmsg, err := AskTCP(testq, addr); err != nil {
+						b.Log(rmsg)
+						b.Fatal(err)
+					}
+				}
+			}
+		})
+	}
+}