1
package main
1
package main
2
2
3
// This binary runs one LLM inference on one input.
3
// This binary runs one LLM inference on one input.
4
4
5
import "flag"
5
import "flag"
6
import "fmt"
6
import "fmt"
7
import "encoding/json"
7
import "encoding/json"
8
import "io/ioutil"
8
import "io/ioutil"
9
import "log"
9
import "log"
10
import "os"
10
import "os"
11
import "strings"
11
import "strings"
12
12
13
import "oscarkilo.com/klex-git/api"
13
import "oscarkilo.com/klex-git/api"
14
import "oscarkilo.com/klex-git/config"
14
import "oscarkilo.com/klex-git/config"
15
15
16
var model = flag.String("model", "", "overrides .Model, if non-empty")
16
var model = flag.String("model", "", "overrides .Model, if non-empty")
17
var system = flag.String("system_file", "", "overrides .System, if non-empty")
17
var system = flag.String("system_file", "", "overrides .System, if non-empty")
18
var prompt = flag.String("prompt_file", "", "appends to .Messages")
18
var prompt = flag.String("prompt_file", "", "appends to .Messages")
19
var format = flag.String("format", "text", "text|json|jsonindent")
19
var format = flag.String("format", "text", "text|json|jsonindent")
20
20
21
func main() {
21
func main() {
22
flag.Parse()
23
22
// Find the API keys and configure a Klex client.
24
// Find the API keys and configure a Klex client.
23
config, err := config.ReadConfig()
25
config, err := config.ReadConfig()
24
if err != nil {
26
if err != nil {
25
log.Fatalf("Failed to read config: %v", err)
27
log.Fatalf("Failed to read config: %v", err)
26
}
28
}
27
client := api.NewClient(config.KlexUrl, config.ApiKey)
29
client := api.NewClient(config.KlexUrl, config.ApiKey)
28
if client == nil {
30
if client == nil {
29
log.Fatalf("Failed to create Klex client")
31
log.Fatalf("Failed to create Klex client")
30
}
32
}
31
33
32
// Parse stdin as a MessagesRequest object, allowing empty input.
34
// Parse stdin as a MessagesRequest object, allowing empty input.
33
sin, err := ioutil.ReadAll(os.Stdin)
35
sin, err := ioutil.ReadAll(os.Stdin)
34
if err != nil {
36
if err != nil {
35
log.Fatalf("Failed to read stdin: %v", err)
37
log.Fatalf("Failed to read stdin: %v", err)
36
}
38
}
37
if len(sin) == 0 {
39
if len(sin) == 0 {
38
sin = []byte("{}")
40
sin = []byte("{}")
39
}
41
}
40
var req api.MessagesRequest
42
var req api.MessagesRequest
41
err = json.Unmarshal(sin, &req)
43
err = json.Unmarshal(sin, &req)
42
if err != nil {
44
if err != nil {
43
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
45
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
44
}
46
}
45
47
46
// Use flags to override parts of the request.
48
// Use flags to override parts of the request.
47
if *model != "" {
49
if *model != "" {
48
req.Model = *model
50
req.Model = *model
49
}
51
}
50
if *system != "" {
52
if *system != "" {
51
s, err := ioutil.ReadFile(*system)
53
s, err := ioutil.ReadFile(*system)
52
if err != nil {
54
if err != nil {
53
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
55
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
54
}
56
}
55
req.System = string(s)
57
req.System = string(s)
56
}
58
}
57
if *prompt != "" {
59
if *prompt != "" {
58
p, err := ioutil.ReadFile(*prompt)
60
p, err := ioutil.ReadFile(*prompt)
59
if err != nil {
61
if err != nil {
60
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
62
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
61
}
63
}
62
req.Messages = append(req.Messages, api.ChatMessage{
64
req.Messages = append(req.Messages, api.ChatMessage{
63
Role: "user",
65
Role: "user",
64
Content: []api.ContentBlock{{
66
Content: []api.ContentBlock{{
65
Type: "text",
67
Type: "text",
66
Text: string(p),
68
Text: string(p),
67
}},
69
}},
68
})
70
})
69
}
71
}
70
72
71
// Get LLM output from Klex.
73
// Get LLM output from Klex.
72
res, err := client.Messages(req)
74
res, err := client.Messages(req)
73
if err != nil {
75
if err != nil {
74
log.Fatalf("Klex f() failure: %v", err)
76
log.Fatalf("Klex f() failure: %v", err)
75
}
77
}
76
78
77
// Print according to the --format flag.
79
// Print according to the --format flag.
78
out, err := formatResponse(res)
80
out, err := formatResponse(res)
79
if err != nil {
81
if err != nil {
80
log.Fatalf("Failed to format response: %v", err)
82
log.Fatalf("Failed to format response: %v", err)
81
}
83
}
82
fmt.Print(out)
84
fmt.Print(out)
83
}
85
}
84
86
85
func formatResponse(res *api.MessagesResponse) (string, error) {
87
func formatResponse(res *api.MessagesResponse) (string, error) {
86
switch *format {
88
switch *format {
87
case "text":
89
case "text":
88
var content []string
90
var content []string
89
for _, c := range res.Content {
91
for _, c := range res.Content {
90
if c.Type == "text" {
92
if c.Type == "text" {
91
content = append(content, c.Text + "\n")
93
content = append(content, c.Text + "\n")
92
}
94
}
93
}
95
}
94
return strings.Join(content, "\n"), nil
96
return strings.Join(content, "\n"), nil
95
case "json":
97
case "json":
96
buf, err := json.Marshal(res)
98
buf, err := json.Marshal(res)
97
return string(buf), err
99
return string(buf), err
98
case "jsonindent":
100
case "jsonindent":
99
buf, err := json.MarshalIndent(res, "", " ")
101
buf, err := json.MarshalIndent(res, "", " ")
100
return string(buf), err
102
return string(buf), err
101
default:
103
default:
102
return "", fmt.Errorf("Unsupported --format=%s", *format)
104
return "", fmt.Errorf("Unsupported --format=%s", *format)
103
}
105
}
104
}
106
}