Commit Diff


commit - c90cf61f8fabcb90ab2a56309fb0697b71314698
commit + 3c37ea2157d302402d0c4611b09270185e028290
blob - 95282373aa5fe50714181c3824064e1c35071d36
blob + 6c07dd06521d1729753dfba313f0ee1e9f1462ae
--- client/client.go
+++ client/client.go
@@ -2,8 +2,11 @@
 package nntpclient
 
 import (
+	"bytes"
+	"crypto/tls"
 	"errors"
 	"io"
+	"net"
 	"net/textproto"
 	"strconv"
 	"strings"
@@ -14,25 +17,49 @@ import (
 // Client is an NNTP client.
 type Client struct {
 	conn   *textproto.Conn
+	// netconn is the transport that conn reads from and writes to; it is
+	// kept so StartTLS can wrap it in a TLS client and rebuild conn.
+	netconn net.Conn
+	// tls records whether netconn is known to be TLS-protected (set by
+	// NewTLS, by NewConn for a *tls.Conn, and by StartTLS).
+	tls bool
+	// Banner is the text of the server's 200 greeting line.
 	Banner string
+	// capabilities caches the lines of the last successful CAPABILITIES
+	// reply; nil means Capabilities has not populated it yet.
+	capabilities []string
 }
 
 // New connects a client to an NNTP server.
-func New(net, addr string) (*Client, error) {
-	conn, err := textproto.Dial(net, addr)
+//
+// network and addr are passed to net.Dial. If the server's greeting cannot
+// be read, the freshly dialed connection is closed before the error is
+// returned, so a failed New never leaks a socket.
+func New(network, addr string) (*Client, error) {
+	netconn, err := net.Dial(network, addr)
 	if err != nil {
 		return nil, err
 	}
-
-	return connect(conn)
+	client, err := connect(netconn)
+	if err != nil {
+		// connect returns without closing netconn on error, and we dialed
+		// it, so we own it. The greeting error is the one worth reporting.
+		_ = netconn.Close()
+		return nil, err
+	}
+	return client, nil
 }
 
 // NewConn wraps an existing connection, for example one opened with tls.Dial
-func NewConn(conn io.ReadWriteCloser) (*Client, error) {
-	return connect(textproto.NewConn(conn))
+// and reads the server's greeting from it.
+//
+// The client is marked as TLS-protected only when netconn is itself a
+// *tls.Conn; any other net.Conn (including other wrappers around TLS) is
+// treated as plaintext. On error netconn is not closed and remains the
+// caller's responsibility.
+func NewConn(netconn net.Conn) (*Client, error) {
+	c, err := connect(netconn)
+	if err != nil {
+		return nil, err
+	}
+	_, c.tls = netconn.(*tls.Conn)
+	return c, nil
 }
 
-func connect(conn *textproto.Conn) (*Client, error) {
+// NewTLS connects to an NNTP server over a dedicated TLS port like 563
+// (implicit TLS, as opposed to upgrading a plaintext connection with
+// StartTLS).
+//
+// network, addr and config are passed to tls.Dial. If the server's greeting
+// cannot be read, the TLS connection is closed before the error is
+// returned, so a failed NewTLS never leaks a socket.
+func NewTLS(network, addr string, config *tls.Config) (*Client, error) {
+	netconn, err := tls.Dial(network, addr, config)
+	if err != nil {
+		return nil, err
+	}
+	client, err := connect(netconn)
+	if err != nil {
+		// connect returns without closing netconn on error, and we dialed
+		// it, so we own it. The greeting error is the one worth reporting.
+		_ = netconn.Close()
+		return nil, err
+	}
+	client.tls = true
+	return client, nil
+}
+
+func connect(netconn net.Conn) (*Client, error) {
+	conn := textproto.NewConn(netconn)
 	_, msg, err := conn.ReadCodeLine(200)
 	if err != nil {
 		return nil, err
@@ -40,6 +67,7 @@ func connect(conn *textproto.Conn) (*Client, error) {
 
 	return &Client{
 		conn:   conn,
+		netconn: netconn,
 		Banner: msg,
 	}, nil
 }
@@ -213,3 +241,126 @@ func (c *Client) Command(cmd string, expectCode int) (
 	}
 	return c.conn.ReadCodeLine(expectCode)
 }
+
+// Capabilities sends CAPABILITIES and returns the advertised capability
+// lines, one entry per line, also caching them on the client for
+// GetCapability and HasCapabilityArgument.
+//
+// Dot-unstuffing and CRLF removal are done by the textproto reader. A reply
+// with no lines yields an empty result and an empty (non-nil) cache, so the
+// cache still counts as populated. The returned slice is a copy; modifying
+// it does not affect the cached list.
+//
+// See https://datatracker.ietf.org/doc/html/rfc3977#section-5.2.2
+func (c *Client) Capabilities() ([]string, error) {
+	if err := c.conn.PrintfLine("CAPABILITIES"); err != nil {
+		return nil, err
+	}
+	if _, _, err := c.conn.ReadCodeLine(101); err != nil {
+		return nil, err
+	}
+	// ReadDotLines splits the multi-line block directly instead of the
+	// TrimSpace+Split round trip, which turned an empty reply into [""].
+	lines, err := c.conn.ReadDotLines()
+	if err != nil {
+		return nil, err
+	}
+	// HasCapabilityArgument treats a nil cache as "never fetched".
+	if lines == nil {
+		lines = []string{}
+	}
+	c.capabilities = lines
+	caps := make([]string, len(lines))
+	copy(caps, lines)
+	return caps, nil
+}
+
+// GetCapability returns the complete cached capability line whose label
+// (its first token) equals capability, or "" if there is no such line or
+// Capabilities has not been called yet.
+//
+// Tokens may be separated by runs of spaces or TABs, as RFC 3977 allows, so
+// a line such as "OVER\tMSGID" is found by GetCapability("OVER").
+//
+// See https://datatracker.ietf.org/doc/html/rfc3977#section-9.5
+func (c *Client) GetCapability(capability string) string {
+	for _, line := range c.capabilities {
+		if fields := strings.Fields(line); len(fields) > 0 && fields[0] == capability {
+			return line
+		}
+	}
+	return ""
+}
+
+// HasCapabilityArgument reports whether the cached capability line labelled
+// capability lists argument among its arguments, i.e. the tokens after the
+// label. For example, with "LIST ACTIVE NEWSGROUPS OVERVIEW.FMT" cached,
+// HasCapabilityArgument("LIST", "OVERVIEW.FMT") is true while
+// HasCapabilityArgument("LIST", "LIST") is false.
+//
+// Tokens may be separated by runs of spaces or TABs. Only the first line
+// carrying the label is consulted. An error is returned if Capabilities has
+// not populated the cache yet, or if no cached line carries the label.
+//
+// See https://datatracker.ietf.org/doc/html/rfc3977#section-9.5
+func (c *Client) HasCapabilityArgument(
+	capability, argument string,
+) (bool, error) {
+	if c.capabilities == nil {
+		return false, errors.New("Capabilities unpopulated")
+	}
+	for _, line := range c.capabilities {
+		fields := strings.Fields(line)
+		if len(fields) == 0 || fields[0] != capability {
+			continue
+		}
+		// fields[0] is the label itself, not an argument.
+		for _, arg := range fields[1:] {
+			if arg == argument {
+				return true, nil
+			}
+		}
+		return false, nil
+	}
+	return false, errors.New("No such capability")
+}
+
+// ListOverviewFmt sends LIST OVERVIEW.FMT and returns the lines of the
+// server's reply. Whitespace around the reply as a whole is trimmed before
+// it is split on newlines, so an empty reply yields a single "" entry.
+//
+// According to the spec, the presence of an "OVER" line in the capabilities
+// response means this LIST variant is supported, so there's no reason to
+// check for it among the keywords in the LIST line, strictly speaking.
+//
+// See https://datatracker.ietf.org/doc/html/rfc3977#section-3.3.2 and
+// https://datatracker.ietf.org/doc/html/rfc3977#section-8.4
+func (c *Client) ListOverviewFmt() ([]string, error) {
+	if err := c.conn.PrintfLine("LIST OVERVIEW.FMT"); err != nil {
+		return nil, err
+	}
+	if _, _, err := c.conn.ReadCodeLine(215); err != nil {
+		return nil, err
+	}
+	body, err := io.ReadAll(c.conn.DotReader())
+	if err != nil {
+		return nil, err
+	}
+	trimmed := bytes.TrimSpace(body)
+	return strings.Split(string(trimmed), "\n"), nil
+}
+
+// Over sends OVER and returns the raw overview lines, one per article, with
+// their tab-separated fields left intact (including empty trailing fields).
+//
+// specifier is sent verbatim as the command argument (for example an
+// article number, a range such as "100-200", or a message-id). An empty
+// specifier sends a bare "OVER", which RFC 3977 defines as a request for
+// the current article. A reply with no lines yields a nil slice.
+//
+// See https://datatracker.ietf.org/doc/html/rfc3977#section-8.3
+func (c *Client) Over(specifier string) ([]string, error) {
+	cmd := "OVER"
+	if specifier != "" {
+		cmd += " " + specifier
+	}
+	// "%s" keeps any '%' inside specifier from being read as a verb.
+	if err := c.conn.PrintfLine("%s", cmd); err != nil {
+		return nil, err
+	}
+	if _, _, err := c.conn.ReadCodeLine(224); err != nil {
+		return nil, err
+	}
+	// Read line by line rather than TrimSpace-ing the whole block:
+	// TrimSpace also strips trailing TABs, i.e. trailing empty fields, from
+	// the last overview line.
+	lines, err := c.conn.ReadDotLines()
+	if err != nil {
+		return nil, err
+	}
+	return lines, nil
+}
+
+// HasTLS reports whether the client's connection is known to be
+// TLS-protected. This is true when it was created by NewTLS, when it was
+// created by NewConn from a *tls.Conn, or after it was upgraded by StartTLS.
+func (c *Client) HasTLS() bool {
+	return c.tls
+}
+
+// StartTLS upgrades the connection in place with the STARTTLS command and
+// then refreshes the cached capabilities over the encrypted connection.
+//
+// config is passed to tls.Client, so it must set either ServerName or
+// InsecureSkipVerify. An error is returned in any of these cases:
+//   - TLS is already active.
+//   - The command cannot be sent.
+//   - The server does not answer 382.
+//   - The TLS handshake fails.
+//   - The follow-up CAPABILITIES request fails.
+//
+// Once the server agrees (382), the previously cached capabilities are
+// discarded, as RFC 4642 requires. A later failure therefore never leaves
+// pre-TLS capabilities in the cache. If the handshake fails, the client
+// keeps its old plaintext connection and is not marked as TLS.
+//
+// See https://datatracker.ietf.org/doc/html/rfc4642 and net/smtp.go, from
+// which this was adapted, and maybe NNTP.starttls in Python's nntplib also.
+func (c *Client) StartTLS(config *tls.Config) error {
+	if c.tls {
+		return errors.New("TLS already active")
+	}
+	// Check the write error; previously it was overwritten by the read.
+	if err := c.conn.PrintfLine("STARTTLS"); err != nil {
+		return err
+	}
+	if _, _, err := c.conn.ReadCodeLine(382); err != nil {
+		return err
+	}
+	c.capabilities = nil
+	tlsConn := tls.Client(c.netconn, config)
+	// Handshake explicitly, so TLS failures are reported as such instead of
+	// surfacing from the first CAPABILITIES write, and before any client
+	// state is switched over to the new connection.
+	if err := tlsConn.Handshake(); err != nil {
+		return err
+	}
+	c.netconn = tlsConn
+	c.conn = textproto.NewConn(tlsConn)
+	c.tls = true
+	_, err := c.Capabilities()
+	return err
+}