1
package main
1
package main
2
2
3
// This binary runs one LLM inference on one input.
3
// This binary runs one LLM inference on one input.
4
4
5
import "encoding/base64"
5
import "encoding/base64"
6
import "flag"
6
import "flag"
7
import "fmt"
7
import "fmt"
8
import "encoding/json"
8
import "encoding/json"
9
import "io/ioutil"
9
import "io/ioutil"
10
import "log"
10
import "log"
11
import "net/http"
11
import "net/http"
12
import "os"
12
import "os"
13
import "strings"
13
import "strings"
14
14
15
import "oscarkilo.com/klex-git/api"
15
import "oscarkilo.com/klex-git/api"
16
import "oscarkilo.com/klex-git/config"
16
import "oscarkilo.com/klex-git/config"
17
17
18
var model = flag.String("model", "", "overrides .Model, if non-empty")
18
var model = flag.String("model", "", "overrides .Model, if non-empty")
19
var system = flag.String("system_file", "", "overrides .System, if non-empty")
19
var system = flag.String("system_file", "", "overrides .System, if non-empty")
20
var prompt = flag.String("prompt_file", "", "appends to .Messages")
20
var prompt = flag.String("prompt_file", "", "appends to .Messages")
21
var image = flag.String("image_file", "", "attaches an image to the prompt")
21
var image = flag.String("image_file", "", "attaches an image to the prompt")
22
var format = flag.String("format", "text", "text|json|jsonindent")
22
var format = flag.String("format", "text", "text|json|jsonindent")
23
23
24
// guessMimeType returns the MIME type inferred from file contents.
24
// guessMimeType returns the MIME type inferred from file contents.
25
func guessMimeType(b []byte) string {
25
func guessMimeType(b []byte) string {
26
if len(b) > 512 {
26
if len(b) > 512 {
27
b = b[:512]
27
b = b[:512]
28
}
28
}
29
return http.DetectContentType(b)
29
return http.DetectContentType(b)
30
}
30
}
31
31
32
func main() {
32
func main() {
33
flag.Parse()
33
flag.Parse()
34
34
35
// Find the API keys and configure a Klex client.
35
// Find the API keys and configure a Klex client.
36
config, err := config.ReadConfig()
36
config, err := config.ReadConfig()
37
if err != nil {
37
if err != nil {
38
log.Fatalf("Failed to read config: %v", err)
38
log.Fatalf("Failed to read config: %v", err)
39
}
39
}
40
client := api.NewClient(config.KlexUrl, config.ApiKey)
40
client := api.NewClient(config.KlexUrl, config.ApiKey)
41
if client == nil {
41
if client == nil {
42
log.Fatalf("Failed to create Klex client")
42
log.Fatalf("Failed to create Klex client")
43
}
43
}
44
44
45
// Parse stdin as a MessagesRequest object, allowing empty input.
45
// Parse stdin as a MessagesRequest object, allowing empty input.
46
sin, err := ioutil.ReadAll(os.Stdin)
46
sin, err := ioutil.ReadAll(os.Stdin)
47
if err != nil {
47
if err != nil {
48
log.Fatalf("Failed to read stdin: %v", err)
48
log.Fatalf("Failed to read stdin: %v", err)
49
}
49
}
50
if len(sin) == 0 {
50
if len(sin) == 0 {
51
sin = []byte("{}")
51
sin = []byte("{}")
52
}
52
}
53
var req api.MessagesRequest
53
var req api.MessagesRequest
54
err = json.Unmarshal(sin, &req)
54
err = json.Unmarshal(sin, &req)
55
if err != nil {
55
if err != nil {
56
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
56
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
57
}
57
}
58
58
59
// Use flags to override parts of the request.
59
// Use flags to override parts of the request.
60
if *model != "" {
60
if *model != "" {
61
req.Model = *model
61
req.Model = *model
62
}
62
}
63
if *system != "" {
63
if *system != "" {
64
s, err := ioutil.ReadFile(*system)
64
s, err := ioutil.ReadFile(*system)
65
if err != nil {
65
if err != nil {
66
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
66
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
67
}
67
}
68
req.System = string(s)
68
req.System = string(s)
69
}
69
}
70
if *image != "" && *prompt == "" {
70
if *image != "" && *prompt == "" {
71
log.Fatalf("--image_file requires a non-empty --prompt_file, too")
71
log.Fatalf("--image_file requires a non-empty --prompt_file, too")
72
}
72
}
73
if *prompt != "" {
73
if *prompt != "" {
74
msg := api.ChatMessage{Role: "user"}
74
msg := api.ChatMessage{Role: "user"}
75
if *image != "" {
75
if *image != "" {
76
i, err := ioutil.ReadFile(*image)
76
i, err := ioutil.ReadFile(*image)
77
if err != nil {
77
if err != nil {
78
log.Fatalf("Failed to read --image_file %s: %v", *image, err)
78
log.Fatalf("Failed to read --image_file %s: %v", *image, err)
79
}
79
}
80
mime_type := guessMimeType(i)
80
mime_type := guessMimeType(i)
81
switch mime_type {
81
switch mime_type {
82
case "image/jpeg", "image/png", "image/gif", "image/webp":
82
case "image/jpeg", "image/png", "image/gif", "image/webp":
83
default:
83
default:
84
log.Fatalf("Unsupported image type: %s", mime_type)
84
log.Fatalf("Unsupported image type: %s", mime_type)
85
}
85
}
86
msg.Content = append(msg.Content, api.ContentBlock{
86
msg.Content = append(msg.Content, api.ContentBlock{
87
Type: "image",
87
Type: "image",
88
Source: &api.ContentSource{
88
Source: &api.ContentSource{
89
Type: "base64",
89
Type: "base64",
90
MediaType: mime_type,
90
MediaType: mime_type,
91
Data: base64.StdEncoding.EncodeToString(i),
91
Data: base64.StdEncoding.EncodeToString(i),
92
},
92
},
93
})
93
})
94
}
94
}
95
p, err := ioutil.ReadFile(*prompt)
95
p, err := ioutil.ReadFile(*prompt)
96
if err != nil {
96
if err != nil {
97
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
97
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
98
}
98
}
99
msg.Content = append(msg.Content, api.ContentBlock{
99
msg.Content = append(msg.Content, api.ContentBlock{
100
Type: "text",
100
Type: "text",
101
Text: string(p),
101
Text: string(p),
102
})
102
})
103
req.Messages = append(req.Messages, msg)
103
}
104
}
104
105
105
// Get LLM output from Klex.
106
// Get LLM output from Klex.
106
res, err := client.Messages(req)
107
res, err := client.Messages(req)
107
if err != nil {
108
if err != nil {
108
log.Fatalf("Klex f() failure: %v", err)
109
log.Fatalf("Klex f() failure: %v", err)
109
}
110
}
110
111
111
// Print according to the --format flag.
112
// Print according to the --format flag.
112
out, err := formatResponse(res)
113
out, err := formatResponse(res)
113
if err != nil {
114
if err != nil {
114
log.Fatalf("Failed to format response: %v", err)
115
log.Fatalf("Failed to format response: %v", err)
115
}
116
}
116
fmt.Print(out)
117
fmt.Print(out)
117
}
118
}
118
119
119
func formatResponse(res *api.MessagesResponse) (string, error) {
120
func formatResponse(res *api.MessagesResponse) (string, error) {
120
switch *format {
121
switch *format {
121
case "text":
122
case "text":
122
var content []string
123
var content []string
123
for _, c := range res.Content {
124
for _, c := range res.Content {
124
if c.Type == "text" {
125
if c.Type == "text" {
125
content = append(content, c.Text + "\n")
126
content = append(content, c.Text + "\n")
126
}
127
}
127
}
128
}
128
return strings.Join(content, "\n"), nil
129
return strings.Join(content, "\n"), nil
129
case "json":
130
case "json":
130
buf, err := json.Marshal(res)
131
buf, err := json.Marshal(res)
131
return string(buf), err
132
return string(buf), err
132
case "jsonindent":
133
case "jsonindent":
133
buf, err := json.MarshalIndent(res, "", " ")
134
buf, err := json.MarshalIndent(res, "", " ")
134
return string(buf), err
135
return string(buf), err
135
default:
136
default:
136
return "", fmt.Errorf("Unsupported --format=%s", *format)
137
return "", fmt.Errorf("Unsupported --format=%s", *format)
137
}
138
}
138
}
139
}