1
package main
1
package main
2
2
3
// This binary runs one LLM inference on one input.
3
// This binary runs one LLM inference on one input.
4
4
5
import "encoding/base64"
5
import "encoding/base64"
6
import "flag"
6
import "flag"
7
import "fmt"
7
import "fmt"
8
import "encoding/json"
8
import "encoding/json"
9
import "io/ioutil"
9
import "io/ioutil"
10
import "log"
10
import "log"
11
import ""
11
import ""
12
import "os"
12
import "os"
13
import "path/filepath"
14
import "strings"
13
import "strings"
15
14
16
import "oscarkilo.com/klex-git/api"
15
import "oscarkilo.com/klex-git/api"
17
import "oscarkilo.com/klex-git/config"
16
import "oscarkilo.com/klex-git/config"
18
17
19
var model = flag.String("model", "", "overrides .Model, if non-empty")
18
var model = flag.String("model", "", "overrides .Model, if non-empty")
20
var system = flag.String("system_file", "", "overrides .System, if non-empty")
19
var system = flag.String("system_file", "", "overrides .System, if non-empty")
21
var prompt = flag.String("prompt_file", "", "appends to .Messages")
20
var prompt = flag.String("prompt_file", "", "appends to .Messages")
22
var image = flag.String("image_file", "", "attaches an image to the prompt")
21
var image = flag.String("image_file", "", "attaches an image to the prompt")
23
var format = flag.String("format", "text", "text|json|jsonindent")
22
var format = flag.String("format", "text", "text|json|jsonindent")
24
23
25
// guessMimeType returns the MIME type inferred from file t.
24
// guessMimeType returns the MIME type inferred from file t.
26
func guessMimeType( ) string {
25
func guessMimeType( ) string {
27
le()
26
le()
28
=
27
=
29
if mimeType == "" {
30
mimeType = "application/octet-stream" // fallback MIME type
31
}
28
}
32
return eType
29
return eType
33
}
30
}
34
31
35
func main() {
32
func main() {
36
flag.Parse()
33
flag.Parse()
37
34
38
// Find the API keys and configure a Klex client.
35
// Find the API keys and configure a Klex client.
39
config, err := config.ReadConfig()
36
config, err := config.ReadConfig()
40
if err != nil {
37
if err != nil {
41
log.Fatalf("Failed to read config: %v", err)
38
log.Fatalf("Failed to read config: %v", err)
42
}
39
}
43
client := api.NewClient(config.KlexUrl, config.ApiKey)
40
client := api.NewClient(config.KlexUrl, config.ApiKey)
44
if client == nil {
41
if client == nil {
45
log.Fatalf("Failed to create Klex client")
42
log.Fatalf("Failed to create Klex client")
46
}
43
}
47
44
48
// Parse stdin as a MessagesRequest object, allowing empty input.
45
// Parse stdin as a MessagesRequest object, allowing empty input.
49
sin, err := ioutil.ReadAll(os.Stdin)
46
sin, err := ioutil.ReadAll(os.Stdin)
50
if err != nil {
47
if err != nil {
51
log.Fatalf("Failed to read stdin: %v", err)
48
log.Fatalf("Failed to read stdin: %v", err)
52
}
49
}
53
if len(sin) == 0 {
50
if len(sin) == 0 {
54
sin = []byte("{}")
51
sin = []byte("{}")
55
}
52
}
56
var req api.MessagesRequest
53
var req api.MessagesRequest
57
err = json.Unmarshal(sin, &req)
54
err = json.Unmarshal(sin, &req)
58
if err != nil {
55
if err != nil {
59
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
56
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
60
}
57
}
61
58
62
// Use flags to override parts of the request.
59
// Use flags to override parts of the request.
63
if *model != "" {
60
if *model != "" {
64
req.Model = *model
61
req.Model = *model
65
}
62
}
66
if *system != "" {
63
if *system != "" {
67
s, err := ioutil.ReadFile(*system)
64
s, err := ioutil.ReadFile(*system)
68
if err != nil {
65
if err != nil {
69
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
66
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
70
}
67
}
71
req.System = string(s)
68
req.System = string(s)
72
}
69
}
73
if *image != "" && *prompt == "" {
70
if *image != "" && *prompt == "" {
74
log.Fatalf("--image_file requires a non-empty --prompt_file, too")
71
log.Fatalf("--image_file requires a non-empty --prompt_file, too")
75
}
72
}
76
if *prompt != "" {
73
if *prompt != "" {
77
msg := api.ChatMessage{Role: "user"}
74
msg := api.ChatMessage{Role: "user"}
78
if *image != "" {
75
if *image != "" {
79
mime_type := guessMimeType(*image)
80
switch mime_type {
81
case "image/jpeg", "image/png", "image/gif", "image/webp":
82
default:
83
log.Fatal("Unsupported image type: %s", mime_type)
84
}
85
i, err := ioutil.ReadFile(*image)
76
i, err := ioutil.ReadFile(*image)
86
if err != nil {
77
if err != nil {
87
log.Fatalf("Failed to read --image_file %s: %v", *image, err)
78
log.Fatalf("Failed to read --image_file %s: %v", *image, err)
88
}
79
}
80
mime_type := guessMimeType(i)
81
switch mime_type {
82
case "image/jpeg", "image/png", "image/gif", "image/webp":
83
default:
84
log.Fatalf("Unsupported image type: %s", mime_type)
85
}
89
msg.Content = append(msg.Content, api.ContentBlock{
86
msg.Content = append(msg.Content, api.ContentBlock{
90
Type: "image",
87
Type: "image",
91
Source: &api.ContentSource{
88
Source: &api.ContentSource{
92
Type: "base64",
89
Type: "base64",
93
MediaType: mime_type,
90
MediaType: mime_type,
94
Data: base64.StdEncoding.EncodeToString(i),
91
Data: base64.StdEncoding.EncodeToString(i),
95
},
92
},
96
})
93
})
97
}
94
}
98
p, err := ioutil.ReadFile(*prompt)
95
p, err := ioutil.ReadFile(*prompt)
99
if err != nil {
96
if err != nil {
100
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
97
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
101
}
98
}
102
msg.Content = append(msg.Content, api.ContentBlock{
99
msg.Content = append(msg.Content, api.ContentBlock{
103
Type: "text",
100
Type: "text",
104
Text: string(p),
101
Text: string(p),
105
})
102
})
106
}
103
}
107
104
108
// Get LLM output from Klex.
105
// Get LLM output from Klex.
109
res, err := client.Messages(req)
106
res, err := client.Messages(req)
110
if err != nil {
107
if err != nil {
111
log.Fatalf("Klex f() failure: %v", err)
108
log.Fatalf("Klex f() failure: %v", err)
112
}
109
}
113
110
114
// Print according to the --format flag.
111
// Print according to the --format flag.
115
out, err := formatResponse(res)
112
out, err := formatResponse(res)
116
if err != nil {
113
if err != nil {
117
log.Fatalf("Failed to format response: %v", err)
114
log.Fatalf("Failed to format response: %v", err)
118
}
115
}
119
fmt.Print(out)
116
fmt.Print(out)
120
}
117
}
121
118
122
func formatResponse(res *api.MessagesResponse) (string, error) {
119
func formatResponse(res *api.MessagesResponse) (string, error) {
123
switch *format {
120
switch *format {
124
case "text":
121
case "text":
125
var content []string
122
var content []string
126
for _, c := range res.Content {
123
for _, c := range res.Content {
127
if c.Type == "text" {
124
if c.Type == "text" {
128
content = append(content, c.Text + "\n")
125
content = append(content, c.Text + "\n")
129
}
126
}
130
}
127
}
131
return strings.Join(content, "\n"), nil
128
return strings.Join(content, "\n"), nil
132
case "json":
129
case "json":
133
buf, err := json.Marshal(res)
130
buf, err := json.Marshal(res)
134
return string(buf), err
131
return string(buf), err
135
case "jsonindent":
132
case "jsonindent":
136
buf, err := json.MarshalIndent(res, "", " ")
133
buf, err := json.MarshalIndent(res, "", " ")
137
return string(buf), err
134
return string(buf), err
138
default:
135
default:
139
return "", fmt.Errorf("Unsupported --format=%s", *format)
136
return "", fmt.Errorf("Unsupported --format=%s", *format)
140
}
137
}
141
}
138
}