1
package main
2
3
// This binary runs one LLM inference on one input.
4
5
import "flag"
6
import "fmt"
7
import "encoding/json"
8
import "io/ioutil"
9
import "log"
10
import "os"
11
import "strings"
12
13
import "oscarkilo.com/klex-git/api"
14
import "oscarkilo.com/klex-git/config"
15
16
var model = flag.String("model", "", "overrides .Model, if non-empty")
17
var system = flag.String("system_file", "", "overrides .System, if non-empty")
18
var prompt = flag.String("prompt_file", "", "appends to .Messages")
19
var format = flag.String("format", "text", "text|json|jsonindent")
20
21
func main() {
22
config, err := config.ReadConfig()
23
if err != nil {
24
log.Fatalf("Failed to read config: %v", err)
25
}
26
client := api.NewClient(config.KlexUrl, config.ApiKey)
27
if client == nil {
28
log.Fatalf("Failed to create Klex client")
29
}
30
31
var req api.MessagesRequest
32
err = json.NewDecoder(os.Stdin).Decode(&req)
33
if err != nil {
34
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
35
}
36
37
if *model != "" {
38
req.Model = *model
39
}
40
if *system != "" {
41
s, err := ioutil.ReadFile(*system)
42
if err != nil {
43
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
44
}
45
req.System = string(s)
46
}
47
if *prompt != "" {
48
p, err := ioutil.ReadFile(*prompt)
49
if err != nil {
50
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
51
}
52
req.Messages = append(req.Messages, api.ChatMessage{
53
Role: "user",
54
Content: []api.ContentBlock{{
55
Type: "text",
56
Text: string(p),
57
}},
58
})
59
}
60
61
res, err := client.Messages(req)
62
if err != nil {
63
log.Fatalf("Klex f() failure: %v", err)
64
}
65
66
out, err := formatResponse(res)
67
if err != nil {
68
log.Fatalf("Failed to format response: %v", err)
69
}
70
fmt.Print(out)
71
}
72
73
func formatResponse(res *api.MessagesResponse) (string, error) {
74
switch *format {
75
case "text":
76
var content []string
77
for _, c := range res.Content {
78
if c.Type == "text" {
79
content = append(content, c.Text + "\n")
80
}
81
}
82
return strings.Join(content, "\n"), nil
83
case "json":
84
buf, err := json.Marshal(res)
85
return string(buf), err
86
case "jsonindent":
87
buf, err := json.MarshalIndent(res, "", " ")
88
return string(buf), err
89
default:
90
return "", fmt.Errorf("Unsupported --format=%s", *format)
91
}
92
}