Commit Diff


commit - 74566e714d2448293a6c527f8cb2c0c7abafabf9
commit + 7d5f9562c75185f9fe13bbc326577dc860146124
blob - /dev/null
blob + 66acf4ecc4c76473fd24bbd22b6b548511520c4e (mode 644)
--- /dev/null
+++ src/llm/llm.go
@@ -0,0 +1,106 @@
+package main
+
+import (
+	"bufio"
+	"bytes"
+	"errors"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path"
+
+	"olowe.co/x/openai"
+)
+
// Command-line flags of the llm command; see main for how each is used.
var (
	// -m: name of the model asked for a completion.
	model = flag.String("m", "ministral-8b-latest", "model")
	// -u: base URL of the API server; endpoint paths are appended to it.
	baseURL = flag.String("u", "https://api.mistral.ai", "openai API base URL")
	// -s: system prompt placed first in the conversation.
	sysPrompt = flag.String("s", "You are a helpful assistant.", "system prompt")
	// -c: converse interactively instead of sending a single prompt.
	converse = flag.Bool("c", false, "start a back-and-forth chat")
)
+
// readToken reads the API token from the file openai/token inside the
// user's configuration directory (as reported by os.UserConfigDir) and
// returns it with leading and trailing whitespace, such as a final newline,
// stripped. An empty file yields an empty token and a nil error.
func readToken() (string, error) {
	dir, err := os.UserConfigDir()
	if err != nil {
		return "", err
	}
	raw, readErr := os.ReadFile(path.Join(dir, "openai", "token"))
	// Whatever ReadFile handed back is trimmed and returned together with
	// its error, if any.
	token := string(bytes.TrimSpace(raw))
	return token, readErr
}
+
// copyAll writes the contents of the named files to w, in order. When paths
// is empty it copies standard input instead.
//
// It returns the total number of bytes written. A file that cannot be opened
// or read does not stop the remaining files from being copied; every such
// failure is collected and returned together via errors.Join.
func copyAll(w io.Writer, paths []string) (int64, error) {
	if len(paths) == 0 {
		return io.Copy(w, os.Stdin)
	}
	var total int64
	var errs []error
	for _, name := range paths {
		n, err := copyFile(w, name)
		total += n
		if err != nil {
			errs = append(errs, err)
		}
	}
	return total, errors.Join(errs...)
}

// copyFile copies the file called name to w, closing it before returning.
func copyFile(w io.Writer, name string) (int64, error) {
	f, err := os.Open(name)
	if err != nil {
		return 0, err // *os.PathError already names the file
	}
	// The file is only read, so a Close error carries no lost data.
	defer f.Close()
	n, err := io.Copy(w, f)
	if err != nil {
		return n, fmt.Errorf("copy %s: %w", name, err)
	}
	return n, nil
}
+
+func init() {
+	log.SetFlags(0)
+	log.SetPrefix("llm: ")
+	flag.Parse()
+}
+
+func main() {
+	token, err := readToken()
+	if err != nil {
+		log.Fatalf("read auth token: %v", err)
+	}
+	client := &openai.Client{http.DefaultClient, token, *baseURL}
+
+	chat := openai.Chat{
+		Messages: []openai.Message{
+			{openai.RoleSystem, *sysPrompt},
+		},
+		Model: *model,
+	}
+	buf := &bytes.Buffer{}
+	if !*converse {
+		_, err := copyAll(buf, flag.Args())
+		if err != nil {
+			log.Fatalln("construct prompt:", err)
+		}
+		msg := openai.Message{openai.RoleUser, buf.String()}
+		chat.Messages = append(chat.Messages, msg)
+		reply, err := client.Complete(&chat)
+		if err != nil {
+			log.Fatalln("llm complete:", err)
+		}
+		fmt.Println(reply.Content)
+		return
+	}
+
+	sc := bufio.NewScanner(os.Stdin)
+	if len(flag.Args()) > 0 {
+		log.Println("conversation mode, ignoring arguments")
+	}
+	for sc.Scan() {
+		if sc.Text() == "." {
+			msg := openai.Message{openai.RoleUser, buf.String()}
+			chat.Messages = append(chat.Messages, msg)
+			reply, err := client.Complete(&chat)
+			if err != nil {
+				fmt.Fprintln(os.Stderr, "chat not completed:", err)
+				continue // try again, allowing a retry with another "." line
+			}
+			buf.Reset()
+			fmt.Println(reply.Content)
+			chat.Messages = append(chat.Messages, *reply)
+			continue
+		}
+		fmt.Fprintln(buf, sc.Text())
+	}
+}
blob - /dev/null
blob + 50086755e5e95a4d364613b0e3e8358e87751979 (mode 644)
--- /dev/null
+++ src/openai/openai.go
@@ -0,0 +1,108 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strings"
+)
+
// Role identifies who authored a chat Message.
type Role string

// Message roles. All three are typed as Role; previously RoleAssistant was
// an untyped string constant because the type was omitted from its spec.
const (
	// RoleSystem marks the system prompt that opens a conversation.
	RoleSystem Role = "system"
	// RoleUser marks messages written by the user.
	RoleUser Role = "user"
	// RoleAssistant marks messages generated by the model.
	RoleAssistant Role = "assistant"
)
+
// Message is one entry of a chat conversation: the Role of its author and
// its text Content. It is encoded as a JSON object with "role" and
// "content" keys. Complete also returns the model's reply as a Message.
type Message struct {
	Role    Role   `json:"role"`
	Content string `json:"content"`
}
+
// Chat is the request body that Client.Complete encodes as JSON and sends.
//
// Messages is the conversation so far, oldest first. Model names the model
// asked to produce the next message. ResponseFormat is left out of the
// request when nil; when set, its Type is sent as "response_format.type".
// NOTE(review): the accepted Type values depend on the API server
// (presumably e.g. "json_object"); confirm against the provider's docs.
type Chat struct {
	Messages       []Message `json:"messages"`
	Model          string    `json:"model"`
	ResponseFormat *struct {
		Type string `json:"type"`
	} `json:"response_format,omitempty"`
}
+
// Client talks to an OpenAI-style chat completions HTTP API.
//
// The embedded *http.Client sends the requests; when it is nil,
// http.DefaultClient is used instead. A non-empty Token is sent as a bearer
// token in the Authorization header. BaseURL is the prefix that endpoint
// paths such as "/v1/chat/completions" are appended to, for example
// "https://api.mistral.ai".
type Client struct {
	*http.Client
	Token   string
	BaseURL string
}
+
// apiError is the JSON error report that Complete decodes from 4xx
// responses. The fields carry no json tags, so encoding/json matches the
// keys "message", "detail", "msg" and "type" case-insensitively.
// NOTE(review): this matches a {"message":{"detail":[{"msg":…}]},"type":…}
// layout; a body shaped like {"error":{…}} decodes to a zero apiError.
type apiError struct {
	Message struct {
		Detail []struct {
			Msg string
		}
	}
	Type string
}

// Error reports the error type followed by the detail messages joined with
// ", ". Empty parts are left out, so a partially filled report still reads
// cleanly and a completely empty one says so explicitly.
func (e apiError) Error() string {
	msgs := make([]string, 0, len(e.Message.Detail))
	for _, d := range e.Message.Detail {
		if d.Msg != "" {
			msgs = append(msgs, d.Msg)
		}
	}
	detail := strings.Join(msgs, ", ")
	switch {
	case e.Type != "" && detail != "":
		return fmt.Sprintf("%s: %s", e.Type, detail)
	case e.Type != "":
		return e.Type
	case detail != "":
		return detail
	default:
		return "unknown API error"
	}
}
+
+func (c *Client) do(req *http.Request) (*http.Response, error) {
+	if c.Client == nil {
+		c.Client = http.DefaultClient
+	}
+	if c.Token != "" {
+		req.Header.Set("Authorization", "Bearer "+c.Token)
+	}
+	req.Header.Set("Accept", "application/json")
+	if req.Body != nil {
+		req.Header.Set("Content-Type", "application/json")
+	}
+	return c.Do(req)
+}
+
// completeResponse is the part of a chat completions response body that
// Complete decodes: the list of choices, each carrying a generated Message.
// The fields have no json tags, so encoding/json matches "choices" and
// "message" case-insensitively. Other keys in the response are ignored.
type completeResponse struct {
	Choices []struct {
		Message Message
	}
}
+
+func (c *Client) Complete(chat *Chat) (*Message, error) {
+	b, err := json.Marshal(chat)
+	if err != nil {
+		return nil, fmt.Errorf("encode messages: %w", err)
+	}
+	u := c.BaseURL + "/v1/chat/completions"
+	req, err := http.NewRequest(http.MethodPost, u, bytes.NewReader(b))
+	if err != nil {
+		return nil, err
+	}
+	resp, err := c.do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode >= 400 && resp.StatusCode <= 499 {
+		var aerr apiError
+		if err := json.NewDecoder(resp.Body).Decode(&aerr); err != nil {
+			return nil, fmt.Errorf(resp.Status)
+		}
+		return nil, aerr
+	} else if resp.StatusCode >= 500 {
+		return nil, fmt.Errorf(resp.Status)
+	}
+
+	var cresp completeResponse
+	if err := json.NewDecoder(resp.Body).Decode(&cresp); err != nil {
+		return nil, fmt.Errorf("decode response: %w", err)
+	}
+	if len(cresp.Choices) == 0 {
+		return nil, fmt.Errorf("no completions in response")
+	}
+	return &cresp.Choices[0].Message, nil
+}