commit 7d5f9562c75185f9fe13bbc326577dc860146124
from: Oliver Lowe
date: Fri Jan 10 01:34:05 2025 UTC

let the slop flow

commit - 74566e714d2448293a6c527f8cb2c0c7abafabf9
commit + 7d5f9562c75185f9fe13bbc326577dc860146124
blob - /dev/null
blob + 66acf4ecc4c76473fd24bbd22b6b548511520c4e (mode 644)
--- /dev/null
+++ src/llm/llm.go
@@ -0,0 +1,120 @@
+package main
+
+import (
+	"bufio"
+	"bytes"
+	"errors"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path"
+
+	"olowe.co/x/openai"
+)
+
+var model = flag.String("m", "ministral-8b-latest", "model")
+var baseURL = flag.String("u", "https://api.mistral.ai", "openai API base URL")
+var sysPrompt = flag.String("s", "You are a helpful assistant.", "system prompt")
+var converse = flag.Bool("c", false, "start a back-and-forth chat")
+
+// readToken reads the API token from the user's config directory
+// (for example $XDG_CONFIG_HOME/openai/token), trimming surrounding whitespace.
+func readToken() (string, error) {
+	confDir, err := os.UserConfigDir()
+	if err != nil {
+		return "", err
+	}
+	b, err := os.ReadFile(path.Join(confDir, "openai/token"))
+	return string(bytes.TrimSpace(b)), err
+}
+
+// copyAll concatenates the contents of the named files to w.
+// With no paths, it copies standard input instead.
+// Copy errors are collected so one unreadable file does not discard
+// the output of the others; errors opening a file abort immediately.
+func copyAll(w io.Writer, paths []string) (n int64, err error) {
+	if len(paths) == 0 {
+		return io.Copy(w, os.Stdin)
+	}
+	var errs []error
+	for _, name := range paths {
+		f, err := os.Open(name)
+		if err != nil {
+			return n, err
+		}
+		nn, err := io.Copy(w, f)
+		// Close promptly; a defer in the loop would hold every
+		// descriptor open until return.
+		f.Close()
+		if err != nil {
+			errs = append(errs, fmt.Errorf("copy %s: %w", name, err))
+		}
+		n += nn
+	}
+	return n, errors.Join(errs...)
+}
+
+func init() {
+	log.SetFlags(0)
+	log.SetPrefix("llm: ")
+}
+
+func main() {
+	flag.Parse()
+	token, err := readToken()
+	if err != nil {
+		log.Fatalf("read auth token: %v", err)
+	}
+	client := &openai.Client{Client: http.DefaultClient, Token: token, BaseURL: *baseURL}
+
+	chat := openai.Chat{
+		Messages: []openai.Message{
+			{Role: openai.RoleSystem, Content: *sysPrompt},
+		},
+		Model: *model,
+	}
+	buf := &bytes.Buffer{}
+	if !*converse {
+		// One-shot mode: the prompt is the named files (or stdin).
+		if _, err := copyAll(buf, flag.Args()); err != nil {
+			log.Fatalln("construct prompt:", err)
+		}
+		msg := openai.Message{Role: openai.RoleUser, Content: buf.String()}
+		chat.Messages = append(chat.Messages, msg)
+		reply, err := client.Complete(&chat)
+		if err != nil {
+			log.Fatalln("llm complete:", err)
+		}
+		fmt.Println(reply.Content)
+		return
+	}
+
+	// Conversation mode: lines accumulate until a lone "." sends
+	// them as the next user message.
+	sc := bufio.NewScanner(os.Stdin)
+	if flag.NArg() > 0 {
+		log.Println("conversation mode, ignoring arguments")
+	}
+	for sc.Scan() {
+		if sc.Text() == "." {
+			msg := openai.Message{Role: openai.RoleUser, Content: buf.String()}
+			chat.Messages = append(chat.Messages, msg)
+			reply, err := client.Complete(&chat)
+			if err != nil {
+				fmt.Fprintln(os.Stderr, "chat not completed:", err)
+				continue // try again, allowing a retry with another "." line
+			}
+			buf.Reset()
+			fmt.Println(reply.Content)
+			chat.Messages = append(chat.Messages, *reply)
+			continue
+		}
+		fmt.Fprintln(buf, sc.Text())
+	}
+	if err := sc.Err(); err != nil {
+		log.Fatalln("read stdin:", err)
+	}
+}
blob - /dev/null
blob + 50086755e5e95a4d364613b0e3e8358e87751979 (mode 644)
--- /dev/null
+++ src/openai/openai.go
@@ -0,0 +1,123 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+)
+
+// Role identifies the author of a chat message.
+type Role string
+
+const (
+	RoleSystem    Role = "system"
+	RoleUser      Role = "user"
+	RoleAssistant Role = "assistant"
+)
+
+// Message is a single entry in a chat transcript.
+type Message struct {
+	Role    Role   `json:"role"`
+	Content string `json:"content"`
+}
+
+// Chat is the request body for the chat completions endpoint.
+type Chat struct {
+	Messages       []Message `json:"messages"`
+	Model          string    `json:"model"`
+	ResponseFormat *struct {
+		Type string `json:"type"`
+	} `json:"response_format,omitempty"`
+}
+
+// Client makes authenticated requests to an OpenAI-compatible API
+// rooted at BaseURL. A nil Client field means http.DefaultClient;
+// an empty Token sends no Authorization header.
+type Client struct {
+	*http.Client
+	Token   string
+	BaseURL string
+}
+
+// apiError mirrors the error body returned by the API on 4xx responses.
+type apiError struct {
+	Message struct {
+		Detail []struct {
+			Msg string
+		}
+	}
+	Type string
+}
+
+func (e apiError) Error() string {
+	messages := make([]string, len(e.Message.Detail))
+	for i := range e.Message.Detail {
+		messages[i] = e.Message.Detail[i].Msg
+	}
+	return fmt.Sprintf("%s: %s", e.Type, strings.Join(messages, ", "))
+}
+
+// do sends req with authentication and JSON content negotiation
+// headers set. It does not mutate c, so a Client is safe to share.
+func (c *Client) do(req *http.Request) (*http.Response, error) {
+	httpClient := c.Client
+	if httpClient == nil {
+		httpClient = http.DefaultClient
+	}
+	if c.Token != "" {
+		req.Header.Set("Authorization", "Bearer "+c.Token)
+	}
+	req.Header.Set("Accept", "application/json")
+	if req.Body != nil {
+		req.Header.Set("Content-Type", "application/json")
+	}
+	return httpClient.Do(req)
+}
+
+type completeResponse struct {
+	Choices []struct {
+		Message Message
+	}
+}
+
+// Complete sends chat to the chat completions endpoint and returns
+// the first completion choice. 4xx responses are decoded into a
+// detailed error when possible; other failures report the HTTP status.
+func (c *Client) Complete(chat *Chat) (*Message, error) {
+	b, err := json.Marshal(chat)
+	if err != nil {
+		return nil, fmt.Errorf("encode messages: %w", err)
+	}
+	u := c.BaseURL + "/v1/chat/completions"
+	req, err := http.NewRequest(http.MethodPost, u, bytes.NewReader(b))
+	if err != nil {
+		return nil, err
+	}
+	resp, err := c.do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	switch {
+	case resp.StatusCode >= 400 && resp.StatusCode <= 499:
+		var aerr apiError
+		if err := json.NewDecoder(resp.Body).Decode(&aerr); err != nil {
+			return nil, errors.New(resp.Status)
+		}
+		return nil, aerr
+	case resp.StatusCode >= 500:
+		return nil, errors.New(resp.Status)
+	}
+
+	var cresp completeResponse
+	if err := json.NewDecoder(resp.Body).Decode(&cresp); err != nil {
+		return nil, fmt.Errorf("decode response: %w", err)
+	}
+	if len(cresp.Choices) == 0 {
+		return nil, errors.New("no completions in response")
+	}
+	return &cresp.Choices[0].Message, nil
+}