commit - 74566e714d2448293a6c527f8cb2c0c7abafabf9
commit + 7d5f9562c75185f9fe13bbc326577dc860146124
blob - /dev/null
blob + 66acf4ecc4c76473fd24bbd22b6b548511520c4e (mode 644)
--- /dev/null
+++ src/llm/llm.go
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "path"
+
+ "olowe.co/x/openai"
+)
+
+// Command-line flags; the usage strings double as their documentation.
+var (
+	model     = flag.String("m", "ministral-8b-latest", "model")
+	baseURL   = flag.String("u", "https://api.mistral.ai", "openai API base URL")
+	sysPrompt = flag.String("s", "You are a helpful assistant.", "system prompt")
+	converse  = flag.Bool("c", false, "start a back-and-forth chat")
+)
+
+// readToken loads the API token from the file openai/token inside the user's
+// configuration directory (as reported by os.UserConfigDir). Leading and
+// trailing whitespace, such as a final newline, is stripped.
+func readToken() (string, error) {
+	var raw []byte
+	dir, err := os.UserConfigDir()
+	if err == nil {
+		raw, err = os.ReadFile(path.Join(dir, "openai/token"))
+	}
+	return string(bytes.TrimSpace(raw)), err
+}
+
+// copyAll writes the contents of the named files to w, in order, or copies
+// standard input when paths is empty. It returns the number of bytes written.
+//
+// A file that cannot be opened stops the copy at once. Errors while copying an
+// opened file are recorded, wrapped with the file name, and the remaining files
+// are still copied. All recorded errors are joined into err.
+func copyAll(w io.Writer, paths []string) (n int64, err error) {
+	if len(paths) == 0 {
+		return io.Copy(w, os.Stdin)
+	}
+	var errs []error
+	for _, name := range paths {
+		f, err := os.Open(name)
+		if err != nil {
+			// Report errors from earlier files too instead of dropping them.
+			return n, errors.Join(append(errs, err)...)
+		}
+		nn, err := io.Copy(w, f)
+		// Close each file as soon as it has been copied. Otherwise every
+		// descriptor stays open until the process exits. The file was opened
+		// read-only, so a close error cannot lose data and is ignored.
+		f.Close()
+		n += nn
+		if err != nil {
+			errs = append(errs, fmt.Errorf("copy %s: %w", name, err))
+		}
+	}
+	return n, errors.Join(errs...)
+}
+
+// init configures the standard logger to prefix messages with "llm: " and to
+// omit timestamps.
+func init() {
+	log.SetFlags(0)
+	log.SetPrefix("llm: ")
+}
+
+// main sends a prompt to the chat completions API and prints the reply.
+//
+// By default the prompt is the contents of the files named as arguments, or of
+// standard input when there are none, and a single reply is printed. With -c,
+// standard input is read as a conversation. Lines accumulate into the next user
+// message, and a line consisting solely of "." sends it. Every request carries
+// the whole history, including earlier replies.
+func main() {
+	// Parse flags here rather than in init. Package init functions run before
+	// flags registered after package initialization exist (e.g. those
+	// testing.Init adds in test binaries), so parsing there can reject valid
+	// arguments.
+	flag.Parse()
+
+	token, err := readToken()
+	if err != nil {
+		log.Fatalf("read auth token: %v", err)
+	}
+	// Keyed fields: go vet flags unkeyed literals of imported struct types,
+	// and they silently break if the struct's fields are reordered.
+	client := &openai.Client{Client: http.DefaultClient, Token: token, BaseURL: *baseURL}
+
+	chat := openai.Chat{
+		Messages: []openai.Message{
+			{Role: openai.RoleSystem, Content: *sysPrompt},
+		},
+		Model: *model,
+	}
+	var buf bytes.Buffer
+	if !*converse {
+		if _, err := copyAll(&buf, flag.Args()); err != nil {
+			log.Fatalln("construct prompt:", err)
+		}
+		chat.Messages = append(chat.Messages, openai.Message{Role: openai.RoleUser, Content: buf.String()})
+		reply, err := client.Complete(&chat)
+		if err != nil {
+			log.Fatalln("llm complete:", err)
+		}
+		fmt.Println(reply.Content)
+		return
+	}
+
+	if len(flag.Args()) > 0 {
+		log.Println("conversation mode, ignoring arguments")
+	}
+	sc := bufio.NewScanner(os.Stdin)
+	for sc.Scan() {
+		if sc.Text() != "." {
+			fmt.Fprintln(&buf, sc.Text())
+			continue
+		}
+		chat.Messages = append(chat.Messages, openai.Message{Role: openai.RoleUser, Content: buf.String()})
+		reply, err := client.Complete(&chat)
+		if err != nil {
+			// Remove the unanswered message so a retry (another "." line)
+			// does not send it to the model twice. buf still holds the text,
+			// so the retry resends it once.
+			chat.Messages = chat.Messages[:len(chat.Messages)-1]
+			fmt.Fprintln(os.Stderr, "chat not completed:", err)
+			continue
+		}
+		buf.Reset()
+		fmt.Println(reply.Content)
+		chat.Messages = append(chat.Messages, *reply)
+	}
+	// Scan also stops on read errors and on lines longer than the scanner's
+	// buffer. Report those instead of exiting as if input had simply ended.
+	if err := sc.Err(); err != nil {
+		log.Fatalln("read input:", err)
+	}
+}
blob - /dev/null
blob + 50086755e5e95a4d364613b0e3e8358e87751979 (mode 644)
--- /dev/null
+++ src/openai/openai.go
+package openai
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+)
+
+// Role identifies the author of a chat Message.
+type Role string
+
+// Known values for Message.Role. Each constant is explicitly typed: in a
+// const block only an omitted expression repeats the previous type, so a
+// constant with its own "= value" and no type would be an untyped string.
+const (
+	RoleSystem    Role = "system"
+	RoleUser      Role = "user"
+	RoleAssistant Role = "assistant"
+)
+
+// Message is one entry in a chat conversation. It is encoded to, and decoded
+// from, a JSON object with "role" and "content" members.
+type Message struct {
+ Role Role `json:"role"`
+ Content string `json:"content"`
+}
+
+// Chat is the JSON request body that Client.Complete sends. It holds the
+// conversation so far and the name of the model to use.
+type Chat struct {
+ Messages []Message `json:"messages"`
+ Model string `json:"model"`
+ // ResponseFormat, when non-nil, is sent as the "response_format" member.
+ // A nil pointer omits the member entirely (omitempty).
+ ResponseFormat *struct {
+ Type string `json:"type"`
+ } `json:"response_format,omitempty"`
+}
+
+// Client sends requests to the chat completions API rooted at BaseURL.
+type Client struct {
+ // Client performs the HTTP requests; when nil, http.DefaultClient is used.
+ *http.Client
+ // Token, when non-empty, is sent in an "Authorization: Bearer" header.
+ Token string
+ // BaseURL is the scheme and host (e.g. "https://api.mistral.ai") that API
+ // paths such as "/v1/chat/completions" are appended to.
+ BaseURL string
+}
+
+// apiError is the error body that Client.Complete decodes from 4xx responses.
+// The fields have no json tags, so encoding/json matches the members
+// "message", "detail", "msg" and "type" case-insensitively.
+// NOTE(review): the nested message.detail[].msg shape appears to follow
+// Mistral's validation errors. Bodies in other shapes may fail to decode or
+// decode to empty fields, so confirm against the API documentation.
+type apiError struct {
+ Message struct {
+ Detail []struct {
+ Msg string
+ }
+ }
+ Type string
+}
+
+// Error formats the error as "<type>: <msg>, <msg>, ...", listing every
+// detail message in the order the server sent them.
+func (e apiError) Error() string {
+	var msgs []string
+	for _, d := range e.Message.Detail {
+		msgs = append(msgs, d.Msg)
+	}
+	return e.Type + ": " + strings.Join(msgs, ", ")
+}
+
+// do sends req using c.Client, or http.DefaultClient when c.Client is nil.
+// It adds a bearer Authorization header when c.Token is non-empty, always
+// asks for a JSON response, and labels any request body as JSON.
+func (c *Client) do(req *http.Request) (*http.Response, error) {
+	hc := c.Client
+	if hc == nil {
+		// Fall back locally instead of assigning c.Client. Writing to the
+		// receiver here would be a data race when one Client is shared by
+		// concurrent requests.
+		hc = http.DefaultClient
+	}
+	if c.Token != "" {
+		req.Header.Set("Authorization", "Bearer "+c.Token)
+	}
+	req.Header.Set("Accept", "application/json")
+	if req.Body != nil {
+		req.Header.Set("Content-Type", "application/json")
+	}
+	return hc.Do(req)
+}
+
+// completeResponse holds the part of a chat completion response body that
+// Client.Complete uses: the list of choices, each carrying a Message. Other
+// members of the response are ignored when decoding.
+type completeResponse struct {
+ Choices []struct {
+ Message Message
+ }
+}
+
+// Complete posts chat to the /v1/chat/completions endpoint under c.BaseURL
+// and returns the message of the first choice in the response.
+//
+// A 4xx response whose body decodes to a non-empty apiError returns that
+// apiError. Any other response outside the 2xx range returns an error holding
+// the HTTP status line. A success response without choices is also an error.
+func (c *Client) Complete(chat *Chat) (*Message, error) {
+	b, err := json.Marshal(chat)
+	if err != nil {
+		return nil, fmt.Errorf("encode messages: %w", err)
+	}
+	// Tolerate a BaseURL written with a trailing slash. Otherwise the request
+	// path would start with "//v1", which servers typically do not route.
+	u := strings.TrimRight(c.BaseURL, "/") + "/v1/chat/completions"
+	req, err := http.NewRequest(http.MethodPost, u, bytes.NewReader(b))
+	if err != nil {
+		return nil, err
+	}
+	resp, err := c.do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	switch {
+	case resp.StatusCode >= 400 && resp.StatusCode <= 499:
+		var aerr apiError
+		derr := json.NewDecoder(resp.Body).Decode(&aerr)
+		if derr != nil || (aerr.Type == "" && len(aerr.Message.Detail) == 0) {
+			// The body is not in the expected error shape, or carries
+			// nothing, so the status line is the most useful report.
+			// "%s" keeps any '%' in the status from being read as a verb.
+			return nil, fmt.Errorf("%s", resp.Status)
+		}
+		return nil, aerr
+	case resp.StatusCode < 200 || resp.StatusCode > 299:
+		// Server errors, and any other non-success status, must not be
+		// decoded as if the body held a completion.
+		return nil, fmt.Errorf("%s", resp.Status)
+	}
+
+	var cresp completeResponse
+	if err := json.NewDecoder(resp.Body).Decode(&cresp); err != nil {
+		return nil, fmt.Errorf("decode response: %w", err)
+	}
+	if len(cresp.Choices) == 0 {
+		return nil, fmt.Errorf("no completions in response")
+	}
+	return &cresp.Choices[0].Message, nil
+}