1
package api
1
package api
2
2
3
import "encoding/base64"
4
import "net/http"
5
6
// NewDocumentBlock wraps raw file bytes as a Type:"document"
7
// ContentBlock suitable for appending to a ChatMessage. The
8
// MIME type is inferred from the content via
9
// http.DetectContentType; the data is base64-encoded. Use this
10
// instead of constructing the ContentBlock manually so the
11
// wire-format details (Source.Type:"base64", base64-encoding,
12
// MIME detection) live in one place.
13
func NewDocumentBlock(data []byte) ContentBlock {
14
return ContentBlock{
15
Type: "document",
16
Source: &ContentSource{
17
Type: "base64",
18
MediaType: http.DetectContentType(data),
19
Data: base64.StdEncoding.EncodeToString(data),
20
},
21
}
22
}
23
3
type ChatMessage struct {
24
type ChatMessage struct {
4
Role string `json:"role"`
25
Role string `json:"role"`
5
Content []ContentBlock `json:"content"`
26
Content []ContentBlock `json:"content"`
6
}
27
}
7
28
8
// MessageRequest is generalizes across OpenAI, Anthropic, etc.
29
// MessageRequest is generalizes across OpenAI, Anthropic, etc.
9
type MessagesRequest struct {
30
type MessagesRequest struct {
10
Model string `json:"model"`
31
Model string `json:"model"`
11
Messages []ChatMessage `json:"messages"`
32
Messages []ChatMessage `json:"messages"`
12
MaxTokens int `json:"max_tokens"`
33
MaxTokens int `json:"max_tokens"`
13
System string `json:"system,omitempty"`
34
System string `json:"system,omitempty"`
14
Temperature float64 `json:"temperature"`
35
Temperature float64 `json:"temperature"`
15
Tools []Tool `json:"tools,omitempty"`
36
Tools []Tool `json:"tools,omitempty"`
16
}
37
}
17
38
18
type ContentBlock struct {
39
type ContentBlock struct {
19
// Type is "text", "document", "tool_use", or "tool_result".
40
// Type is "text", "document", "tool_use", or "tool_result".
20
Type string `json:"type"`
41
Type string `json:"type"`
21
42
22
// Text is for Type="text"
43
// Text is for Type="text"
23
Text string `json:"text,omitempty"`
44
Text string `json:"text,omitempty"`
24
45
25
// Source is for Type="document". MediaType disambiguates
46
// Source is for Type="document". MediaType disambiguates
26
// between image attachments and PDF attachments.
47
// between image attachments and PDF attachments.
27
Source *ContentSource `json:"source,omitempty"`
48
Source *ContentSource `json:"source,omitempty"`
28
49
29
// ID, Name, and Input are for Type="tool_use".
50
// ID, Name, and Input are for Type="tool_use".
30
ID string `json:"id,omitempty"`
51
ID string `json:"id,omitempty"`
31
Name string `json:"name,omitempty"`
52
Name string `json:"name,omitempty"`
32
Input interface{} `json:"input,omitempty"`
53
Input interface{} `json:"input,omitempty"`
33
54
34
// ToolUseID, Content, and Output are for
55
// ToolUseID, Content, and Output are for
35
// Type="tool_result".
56
// Type="tool_result".
36
ToolUseID string `json:"tool_use_id,omitempty"`
57
ToolUseID string `json:"tool_use_id,omitempty"`
37
Content string `json:"content,omitempty"`
58
Content string `json:"content,omitempty"`
38
Output string `json:"output,omitempty"`
59
Output string `json:"output,omitempty"`
39
}
60
}
40
61
41
type ContentSource struct {
62
type ContentSource struct {
42
// Type can only be "base64".
63
// Type can only be "base64".
43
Type string `json:"type"`
64
Type string `json:"type"`
44
65
45
// MediaType can be one of:
66
// MediaType can be one of:
46
// - "image/jpeg",
67
// - "image/jpeg",
47
// - "image/png",
68
// - "image/png",
48
// - "image/gif",
69
// - "image/gif",
49
// - "image/webp", or
70
// - "image/webp", or
50
// - "application/pdf".
71
// - "application/pdf".
51
// Whether a particular MIME is accepted depends on the
72
// Whether a particular MIME is accepted depends on the
52
// backing model; the func registry's CanSeeDocuments
73
// backing model; the func registry's CanSeeDocuments
53
// capability gates attachments at all but does not split
74
// capability gates attachments at all but does not split
54
// by MIME.
75
// by MIME.
55
MediaType string `json:"media_type,omitempty"`
76
MediaType string `json:"media_type,omitempty"`
56
77
57
Data string `json:"data,omitempty"`
78
Data string `json:"data,omitempty"`
58
}
79
}
59
80
60
type Usage struct {
81
type Usage struct {
61
InputTokens int `json:"input_tokens,omitempty"`
82
InputTokens int `json:"input_tokens,omitempty"`
62
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
83
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
63
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
84
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
64
OutputTokens int `json:"output_tokens,omitempty"`
85
OutputTokens int `json:"output_tokens,omitempty"`
65
}
86
}
66
87
67
type ErrorResponse struct {
88
type ErrorResponse struct {
68
Type string `json:"type"`
89
Type string `json:"type"`
69
Message string `json:"message"`
90
Message string `json:"message"`
70
}
91
}
71
92
72
type MessagesResponse struct {
93
type MessagesResponse struct {
73
Id string `json:"id"`
94
Id string `json:"id"`
74
Type string `json:"type"`
95
Type string `json:"type"`
75
Role string `json:"role"`
96
Role string `json:"role"`
76
Content []ContentBlock `json:"content"`
97
Content []ContentBlock `json:"content"`
77
Model string `json:"model"`
98
Model string `json:"model"`
78
StopReason *string `json:"stop_reason,omitempty"`
99
StopReason *string `json:"stop_reason,omitempty"`
79
StopSequence *string `json:"stop_sequence,omitempty"`
100
StopSequence *string `json:"stop_sequence,omitempty"`
80
Usage Usage `json:"usage"`
101
Usage Usage `json:"usage"`
81
Error *ErrorResponse `json:"error,omitempty"`
102
Error *ErrorResponse `json:"error,omitempty"`
82
}
103
}
83
104
84
type Tool struct {
105
type Tool struct {
85
Type string `json:"type"`
106
Type string `json:"type"`
86
Function *ToolFunction `json:"function"`
107
Function *ToolFunction `json:"function"`
87
}
108
}
88
109
89
type ToolFunction struct {
110
type ToolFunction struct {
90
Name string `json:"name"`
111
Name string `json:"name"`
91
Description string `json:"description"`
112
Description string `json:"description"`
92
InputSchema interface{} `json:"input_schema"`
113
InputSchema interface{} `json:"input_schema"`
93
}
114
}
1
package api
1
package api
2
2
3
// These tests pin the wire format of ContentBlock so callers across
3
// These tests pin the wire format of ContentBlock so callers across
4
// //funky, //ithfm, and other consumers can rely on stable JSON
4
// //funky, //ithfm, and other consumers can rely on stable JSON
5
// shapes. They fail loudly if the JSON tags or field names ever
5
// shapes. They fail loudly if the JSON tags or field names ever
6
// drift.
6
// drift.
7
7
8
import "encoding/json"
8
import "encoding/json"
9
import "strings"
9
import "strings"
10
import "testing"
10
import "testing"
11
11
12
func TestDocumentBlockMarshal(t *testing.T) {
12
func TestDocumentBlockMarshal(t *testing.T) {
13
block := ContentBlock{
13
block := ContentBlock{
14
Type: "document",
14
Type: "document",
15
Source: &ContentSource{
15
Source: &ContentSource{
16
Type: "base64",
16
Type: "base64",
17
MediaType: "application/pdf",
17
MediaType: "application/pdf",
18
Data: "JVBERi0xLjQK",
18
Data: "JVBERi0xLjQK",
19
},
19
},
20
}
20
}
21
buf, err := json.Marshal(block)
21
buf, err := json.Marshal(block)
22
if err != nil {
22
if err != nil {
23
t.Fatalf("marshal: %v", err)
23
t.Fatalf("marshal: %v", err)
24
}
24
}
25
want := `{"type":"document","source":{"type":"base64","media_type":"application/pdf","data":"JVBERi0xLjQK"}}`
25
want := `{"type":"document","source":{"type":"base64","media_type":"application/pdf","data":"JVBERi0xLjQK"}}`
26
if string(buf) != want {
26
if string(buf) != want {
27
t.Errorf(
27
t.Errorf(
28
"marshal mismatch\nwant: %s\n got: %s", want, buf)
28
"marshal mismatch\nwant: %s\n got: %s", want, buf)
29
}
29
}
30
}
30
}
31
31
32
func TestDocumentBlockUnmarshal(t *testing.T) {
32
func TestDocumentBlockUnmarshal(t *testing.T) {
33
in := []byte(
33
in := []byte(
34
`{"type":"document","source":{"type":"base64",` +
34
`{"type":"document","source":{"type":"base64",` +
35
`"media_type":"application/pdf","data":"JVBERi0xLjQK"}}`)
35
`"media_type":"application/pdf","data":"JVBERi0xLjQK"}}`)
36
var block ContentBlock
36
var block ContentBlock
37
if err := json.Unmarshal(in, &block); err != nil {
37
if err := json.Unmarshal(in, &block); err != nil {
38
t.Fatalf("unmarshal: %v", err)
38
t.Fatalf("unmarshal: %v", err)
39
}
39
}
40
if block.Type != "document" {
40
if block.Type != "document" {
41
t.Errorf("Type = %q, want %q", block.Type, "document")
41
t.Errorf("Type = %q, want %q", block.Type, "document")
42
}
42
}
43
if block.Source == nil {
43
if block.Source == nil {
44
t.Fatalf("Source is nil")
44
t.Fatalf("Source is nil")
45
}
45
}
46
if block.Source.MediaType != "application/pdf" {
46
if block.Source.MediaType != "application/pdf" {
47
t.Errorf(
47
t.Errorf(
48
"MediaType = %q, want application/pdf",
48
"MediaType = %q, want application/pdf",
49
block.Source.MediaType)
49
block.Source.MediaType)
50
}
50
}
51
if block.Source.Data != "JVBERi0xLjQK" {
51
if block.Source.Data != "JVBERi0xLjQK" {
52
t.Errorf("Data = %q", block.Source.Data)
52
t.Errorf("Data = %q", block.Source.Data)
53
}
53
}
54
}
54
}
55
55
56
func TestDocumentBlockImageMimeRoundTrip(t *testing.T) {
56
func TestDocumentBlockImageMimeRoundTrip(t *testing.T) {
57
// A "document" with an image/* MIME is how the wire format
57
// A "document" with an image/* MIME is how the wire format
58
// carries what used to live under Type:"image".
58
// carries what used to live under Type:"image".
59
block := ContentBlock{
59
block := ContentBlock{
60
Type: "document",
60
Type: "document",
61
Source: &ContentSource{
61
Source: &ContentSource{
62
Type: "base64",
62
Type: "base64",
63
MediaType: "image/png",
63
MediaType: "image/png",
64
Data: "iVBORw0KGgo=",
64
Data: "iVBORw0KGgo=",
65
},
65
},
66
}
66
}
67
buf, err := json.Marshal(block)
67
buf, err := json.Marshal(block)
68
if err != nil {
68
if err != nil {
69
t.Fatalf("marshal: %v", err)
69
t.Fatalf("marshal: %v", err)
70
}
70
}
71
if !strings.Contains(string(buf), `"type":"document"`) {
71
if !strings.Contains(string(buf), `"type":"document"`) {
72
t.Errorf("expected type:document, got %s", buf)
72
t.Errorf("expected type:document, got %s", buf)
73
}
73
}
74
if !strings.Contains(string(buf), `"media_type":"image/png"`) {
74
if !strings.Contains(string(buf), `"media_type":"image/png"`) {
75
t.Errorf("expected media_type image/png, got %s", buf)
75
t.Errorf("expected media_type image/png, got %s", buf)
76
}
76
}
77
var got ContentBlock
77
var got ContentBlock
78
if err := json.Unmarshal(buf, &got); err != nil {
78
if err := json.Unmarshal(buf, &got); err != nil {
79
t.Fatalf("unmarshal: %v", err)
79
t.Fatalf("unmarshal: %v", err)
80
}
80
}
81
if got.Type != "document" {
81
if got.Type != "document" {
82
t.Errorf("Type = %q", got.Type)
82
t.Errorf("Type = %q", got.Type)
83
}
83
}
84
if got.Source == nil ||
84
if got.Source == nil ||
85
got.Source.MediaType != "image/png" {
85
got.Source.MediaType != "image/png" {
86
t.Errorf("Source mismatch: %+v", got.Source)
86
t.Errorf("Source mismatch: %+v", got.Source)
87
}
87
}
88
}
88
}
89
89
90
func TestTextBlockUnchanged(t *testing.T) {
90
func TestTextBlockUnchanged(t *testing.T) {
91
// Documents the unchanged text-block wire format.
91
// Documents the unchanged text-block wire format.
92
block := ContentBlock{Type: "text", Text: "hello"}
92
block := ContentBlock{Type: "text", Text: "hello"}
93
buf, err := json.Marshal(block)
93
buf, err := json.Marshal(block)
94
if err != nil {
94
if err != nil {
95
t.Fatalf("marshal: %v", err)
95
t.Fatalf("marshal: %v", err)
96
}
96
}
97
if string(buf) != `{"type":"text","text":"hello"}` {
97
if string(buf) != `{"type":"text","text":"hello"}` {
98
t.Errorf("unexpected text-block JSON: %s", buf)
98
t.Errorf("unexpected text-block JSON: %s", buf)
99
}
99
}
100
}
100
}
101
102
func TestNewDocumentBlock(t *testing.T) {
103
png := []byte{
104
0x89, 'P', 'N', 'G', 0x0D, 0x0A, 0x1A, 0x0A,
105
// padding so DetectContentType has enough to work with
106
0, 0, 0, 0, 0, 0, 0, 0,
107
}
108
blk := NewDocumentBlock(png)
109
if blk.Type != "document" {
110
t.Errorf("Type = %q, want document", blk.Type)
111
}
112
if blk.Source == nil {
113
t.Fatal("Source is nil")
114
}
115
if blk.Source.Type != "base64" {
116
t.Errorf("Source.Type = %q, want base64",
117
blk.Source.Type)
118
}
119
if blk.Source.MediaType != "image/png" {
120
t.Errorf("MediaType = %q, want image/png",
121
blk.Source.MediaType)
122
}
123
// The encoded bytes round-trip back to the original.
124
buf, err := json.Marshal(blk)
125
if err != nil {
126
t.Fatalf("marshal: %v", err)
127
}
128
if !strings.Contains(string(buf), `"type":"document"`) {
129
t.Errorf("marshaled JSON missing type:document: %s", buf)
130
}
131
}
1
package main
1
package main
2
2
3
import "encoding/base64"
4
import "encoding/json"
3
import "encoding/json"
5
import "flag"
4
import "flag"
6
import "io/ioutil"
5
import "io/ioutil"
7
import "log"
6
import "log"
8
import "net/http"
9
import "os"
7
import "os"
10
import "path"
8
import "path"
11
import "sort"
9
import "sort"
12
import "strings"
10
import "strings"
13
11
14
import "oscarkilo.com/klex-git/api"
12
import "oscarkilo.com/klex-git/api"
15
import "oscarkilo.com/klex-git/config"
13
import "oscarkilo.com/klex-git/config"
16
14
17
var dir = flag.String("dir", ".", "Directory to scan and write to.")
15
var dir = flag.String("dir", ".", "Directory to scan and write to.")
18
var model = flag.String("model", "Gemini 3 Pro", "")
16
var model = flag.String("model", "Gemini 3 Pro", "")
19
var format = flag.String("format", "text", "text|json|jsonindent")
17
var format = flag.String("format", "text", "text|json|jsonindent")
20
var dry_run = flag.Bool("dry_run", false, "")
18
var dry_run = flag.Bool("dry_run", false, "")
21
var debug = flag.Bool("debug", false, "")
19
var debug = flag.Bool("debug", false, "")
22
20
23
type Case struct {
21
type Case struct {
24
Name string
22
Name string
25
Before []string
23
Before []string
26
After string
24
After string
27
Request *api.MessagesRequest
25
Request *api.MessagesRequest
28
}
26
}
29
27
30
func scanForCases() []Case {
28
func scanForCases() []Case {
31
entries, err := os.ReadDir(*dir)
29
entries, err := os.ReadDir(*dir)
32
if err != nil {
30
if err != nil {
33
log.Fatalf("Failed to read dir %s: %v", *dir, err)
31
log.Fatalf("Failed to read dir %s: %v", *dir, err)
34
}
32
}
35
33
36
before := make(map[string][]string)
34
before := make(map[string][]string)
37
after := make(map[string]string)
35
after := make(map[string]string)
38
for _, entry := range entries {
36
for _, entry := range entries {
39
if entry.IsDir() {
37
if entry.IsDir() {
40
continue
38
continue
41
}
39
}
42
if entry.Name() == "system_prompt.txt" {
40
if entry.Name() == "system_prompt.txt" {
43
continue
41
continue
44
}
42
}
45
chunks := strings.Split(entry.Name(), ".")
43
chunks := strings.Split(entry.Name(), ".")
46
if len(chunks) < 2 {
44
if len(chunks) < 2 {
47
continue
45
continue
48
}
46
}
49
name := strings.Join(chunks[0:len(chunks)-1], ".")
47
name := strings.Join(chunks[0:len(chunks)-1], ".")
50
suffix := chunks[len(chunks)-1]
48
suffix := chunks[len(chunks)-1]
51
switch suffix {
49
switch suffix {
52
case "txt", "json", "jpg", "jpeg", "png", "webp":
50
case "txt", "json", "jpg", "jpeg", "png", "webp":
53
before[name] = append(before[name], suffix)
51
before[name] = append(before[name], suffix)
54
case "out":
52
case "out":
55
after[name] = suffix
53
after[name] = suffix
56
}
54
}
57
}
55
}
58
56
59
var cases []Case
57
var cases []Case
60
for name, b := range before {
58
for name, b := range before {
61
sort.Slice(b, func(i, j int) bool {
59
sort.Slice(b, func(i, j int) bool {
62
if b[i] == "txt" && b[j] != "txt" {
60
if b[i] == "txt" && b[j] != "txt" {
63
return true
61
return true
64
}
62
}
65
if b[i] != "txt" && b[j] == "txt" {
63
if b[i] != "txt" && b[j] == "txt" {
66
return false
64
return false
67
}
65
}
68
return b[i] < b[j]
66
return b[i] < b[j]
69
})
67
})
70
cases = append(cases, Case{
68
cases = append(cases, Case{
71
Name: name,
69
Name: name,
72
Before: b,
70
Before: b,
73
After: after[name],
71
After: after[name],
74
})
72
})
75
}
73
}
76
74
77
sort.Slice(cases, func(i, j int) bool {
75
sort.Slice(cases, func(i, j int) bool {
78
a, b := cases[i], cases[j]
76
a, b := cases[i], cases[j]
79
if a.After != "" && b.After == "" {
77
if a.After != "" && b.After == "" {
80
return true
78
return true
81
}
79
}
82
if a.After == "" && b.After != "" {
80
if a.After == "" && b.After != "" {
83
return false
81
return false
84
}
82
}
85
return a.Name < b.Name
83
return a.Name < b.Name
86
})
84
})
87
85
88
num_examples := 0
86
num_examples := 0
89
for ; num_examples < len(cases); num_examples++ {
87
for ; num_examples < len(cases); num_examples++ {
90
if cases[num_examples].After == "" {
88
if cases[num_examples].After == "" {
91
break
89
break
92
}
90
}
93
}
91
}
94
log.Printf(
92
log.Printf(
95
"%s:\n num_examples = %d\n num_inputs = %d",
93
"%s:\n num_examples = %d\n num_inputs = %d",
96
*dir,
94
*dir,
97
num_examples,
95
num_examples,
98
len(cases) - num_examples,
96
len(cases) - num_examples,
99
)
97
)
100
98
101
return cases
99
return cases
102
}
100
}
103
101
104
func readFile(fname string) []byte {
102
func readFile(fname string) []byte {
105
data, err := ioutil.ReadFile(path.Join(*dir, fname))
103
data, err := ioutil.ReadFile(path.Join(*dir, fname))
106
if err != nil {
104
if err != nil {
107
log.Fatalf("Failed to read file %s: %v", fname, err)
105
log.Fatalf("Failed to read file %s: %v", fname, err)
108
}
106
}
109
return data
107
return data
110
}
108
}
111
109
112
func writeFile(fname string, data []byte) error {
110
func writeFile(fname string, data []byte) error {
113
return ioutil.WriteFile(path.Join(*dir, fname), data, 0644)
111
return ioutil.WriteFile(path.Join(*dir, fname), data, 0644)
114
}
112
}
115
113
116
func copyRequest(req api.MessagesRequest) *api.MessagesRequest {
114
func copyRequest(req api.MessagesRequest) *api.MessagesRequest {
117
bytes, err := json.Marshal(req)
115
bytes, err := json.Marshal(req)
118
if err != nil {
116
if err != nil {
119
log.Fatalf("Failed to copy request: %+v", err)
117
log.Fatalf("Failed to copy request: %+v", err)
120
}
118
}
121
req2 := &api.MessagesRequest{}
119
req2 := &api.MessagesRequest{}
122
err = json.Unmarshal(bytes, req2)
120
err = json.Unmarshal(bytes, req2)
123
if err != nil {
121
if err != nil {
124
log.Fatalf("Failed to copy request: %+v", err)
122
log.Fatalf("Failed to copy request: %+v", err)
125
}
123
}
126
return req2
124
return req2
127
}
125
}
128
126
129
func main() {
127
func main() {
130
flag.Parse()
128
flag.Parse()
131
129
132
// Find the API keys and configure a Klex client.
130
// Find the API keys and configure a Klex client.
133
config, err := config.ReadConfig()
131
config, err := config.ReadConfig()
134
if err != nil {
132
if err != nil {
135
log.Fatalf("Failed to read config: %v", err)
133
log.Fatalf("Failed to read config: %v", err)
136
}
134
}
137
client := api.NewClient(config.KlexUrl, config.ApiKey)
135
client := api.NewClient(config.KlexUrl, config.ApiKey)
138
if client == nil {
136
if client == nil {
139
log.Fatalf("Failed to create Klex client")
137
log.Fatalf("Failed to create Klex client")
140
}
138
}
141
139
142
// Create MessagesRequest objects.
140
// Create MessagesRequest objects.
143
req := api.MessagesRequest{
141
req := api.MessagesRequest{
144
Model: *model,
142
Model: *model,
145
System: string(readFile("system_prompt.txt")),
143
System: string(readFile("system_prompt.txt")),
146
}
144
}
147
cases := scanForCases()
145
cases := scanForCases()
148
for i, c := range cases {
146
for i, c := range cases {
149
user := api.ChatMessage{Role: "user"}
147
user := api.ChatMessage{Role: "user"}
150
text := "Case name: " + c.Name + "\n\n"
148
text := "Case name: " + c.Name + "\n\n"
151
for _, suffix := range c.Before {
149
for _, suffix := range c.Before {
152
switch suffix {
150
switch suffix {
153
case "txt":
151
case "txt":
154
text += string(readFile(c.Name + ".txt"))
152
text += string(readFile(c.Name + ".txt"))
155
case "json":
153
case "json":
156
text += "\n\n```json\n"
154
text += "\n\n```json\n"
157
text += string(readFile(c.Name + ".json"))
155
text += string(readFile(c.Name + ".json"))
158
text += "\n```\n"
156
text += "\n```\n"
159
case "jpg", "jpeg", "png", "webp":
157
case "jpg", "jpeg", "png", "webp":
160
bytes := readFile(c.Name + "." + suffix)
158
bytes := readFile(c.Name + "." + suffix)
161
user.Content = append(user.Content,
159
user.Content = append(user.Content,
162
eocument
160
eocument
163
Source: &api.ContentSource{
164
Type: "base64",
165
MediaType: http.DetectContentType(bytes),
166
Data: base64.StdEncoding.EncodeToString(bytes),
167
},
168
})
169
default:
161
default:
170
log.Fatalf("Unsupported suffix %s in case %s", suffix, c.Name)
162
log.Fatalf("Unsupported suffix %s in case %s", suffix, c.Name)
171
}
163
}
172
}
164
}
173
user.Content = append(user.Content, api.ContentBlock{
165
user.Content = append(user.Content, api.ContentBlock{
174
Type: "text",
166
Type: "text",
175
Text: text,
167
Text: text,
176
})
168
})
177
req.Messages = append(req.Messages, user)
169
req.Messages = append(req.Messages, user)
178
if c.After != "" {
170
if c.After != "" {
179
asst := api.ChatMessage{Role: "assistant"}
171
asst := api.ChatMessage{Role: "assistant"}
180
asst.Content = append(asst.Content, api.ContentBlock{
172
asst.Content = append(asst.Content, api.ContentBlock{
181
Type: "text",
173
Type: "text",
182
Text: string(readFile(c.Name + ".out")),
174
Text: string(readFile(c.Name + ".out")),
183
})
175
})
184
req.Messages = append(req.Messages, asst)
176
req.Messages = append(req.Messages, asst)
185
} else {
177
} else {
186
cases[i].Request = copyRequest(req)
178
cases[i].Request = copyRequest(req)
187
req.Messages = req.Messages[:len(req.Messages)-1]
179
req.Messages = req.Messages[:len(req.Messages)-1]
188
}
180
}
189
}
181
}
190
182
191
if *dry_run {
183
if *dry_run {
192
log.Printf("Dry run; not sending requests.")
184
log.Printf("Dry run; not sending requests.")
193
return
185
return
194
}
186
}
195
187
196
// Send 'em.
188
// Send 'em.
197
// TODO: parallelize
189
// TODO: parallelize
198
for _, c := range cases {
190
for _, c := range cases {
199
if c.Request == nil {
191
if c.Request == nil {
200
continue
192
continue
201
}
193
}
202
if *debug {
194
if *debug {
203
log.Printf("Case %s: sending request:", c.Name)
195
log.Printf("Case %s: sending request:", c.Name)
204
enc := json.NewEncoder(os.Stderr)
196
enc := json.NewEncoder(os.Stderr)
205
enc.SetIndent("", " ")
197
enc.SetIndent("", " ")
206
enc.Encode(c.Request)
198
enc.Encode(c.Request)
207
}
199
}
208
res, err := client.Messages(*c.Request)
200
res, err := client.Messages(*c.Request)
209
if err != nil {
201
if err != nil {
210
log.Printf("Case %s: request failed: %+v", c.Name, err)
202
log.Printf("Case %s: request failed: %+v", c.Name, err)
211
continue
203
continue
212
}
204
}
213
if len(res.Content) != 1 {
205
if len(res.Content) != 1 {
214
log.Printf("Case %s: empty response", c.Name)
206
log.Printf("Case %s: empty response", c.Name)
215
continue
207
continue
216
}
208
}
217
c0 := res.Content[0]
209
c0 := res.Content[0]
218
if c0.Type != "text" {
210
if c0.Type != "text" {
219
log.Printf("Case %s: Content[0].Type = %s", c.Name, c0.Type)
211
log.Printf("Case %s: Content[0].Type = %s", c.Name, c0.Type)
220
continue
212
continue
221
}
213
}
222
err = writeFile(c.Name + ".out", append([]byte(c0.Text), '\n'))
214
err = writeFile(c.Name + ".out", append([]byte(c0.Text), '\n'))
223
if err != nil {
215
if err != nil {
224
log.Printf("Case %s: failed to write %s.out: %+v", c.Name, c.Name, err)
216
log.Printf("Case %s: failed to write %s.out: %+v", c.Name, c.Name, err)
225
continue
217
continue
226
}
218
}
227
log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
219
log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
228
}
220
}
229
}
221
}
1
package main
1
package main
2
2
3
// This binary runs one LLM inference on one input.
3
// This binary runs one LLM inference on one input.
4
4
5
import "encoding/base64"
6
import "flag"
5
import "flag"
7
import "fmt"
6
import "fmt"
8
import "encoding/json"
7
import "encoding/json"
9
import "io/ioutil"
8
import "io/ioutil"
10
import "log"
9
import "log"
11
import "net/http"
12
import "os"
10
import "os"
13
import "strings"
11
import "strings"
14
12
15
import "oscarkilo.com/klex-git/api"
13
import "oscarkilo.com/klex-git/api"
16
import "oscarkilo.com/klex-git/config"
14
import "oscarkilo.com/klex-git/config"
17
15
18
var model = flag.String("model", "", "overrides .Model, if non-empty")
16
var model = flag.String("model", "", "overrides .Model, if non-empty")
19
var system = flag.String("system_file", "", "overrides .System, if non-empty")
17
var system = flag.String("system_file", "", "overrides .System, if non-empty")
20
var prompt = flag.String("prompt_file", "", "appends to .Messages")
18
var prompt = flag.String("prompt_file", "", "appends to .Messages")
21
var attach = flag.String("attach", "",
19
var attach = flag.String("attach", "",
22
"path to a file (image or PDF) to attach to the prompt")
20
"path to a file (image or PDF) to attach to the prompt")
23
var format = flag.String("format", "text", "text|json|jsonindent")
21
var format = flag.String("format", "text", "text|json|jsonindent")
24
var fastFail = flag.Bool("fast-fail", true,
22
var fastFail = flag.Bool("fast-fail", true,
25
"before sending, fetch the function's llm2 capabilities and "+
23
"before sending, fetch the function's llm2 capabilities and "+
26
"fail if the attached MIME type isn't supported. Set to "+
24
"fail if the attached MIME type isn't supported. Set to "+
27
"false in tight loops (e.g. MapReduce) where the extra "+
25
"false in tight loops (e.g. MapReduce) where the extra "+
28
"preflight HTTP round-trip per call is unwanted.")
26
"preflight HTTP round-trip per call is unwanted.")
29
27
30
// guessMimeType returns the MIME type inferred from file contents.
31
func guessMimeType(b []byte) string {
32
if len(b) > 512 {
33
b = b[:512]
34
}
35
return http.DetectContentType(b)
36
}
37
38
func main() {
28
func main() {
39
flag.Parse()
29
flag.Parse()
40
30
41
// Find the API keys and configure a Klex client.
31
// Find the API keys and configure a Klex client.
42
config, err := config.ReadConfig()
32
config, err := config.ReadConfig()
43
if err != nil {
33
if err != nil {
44
log.Fatalf("Failed to read config: %v", err)
34
log.Fatalf("Failed to read config: %v", err)
45
}
35
}
46
client := api.NewClient(config.KlexUrl, config.ApiKey)
36
client := api.NewClient(config.KlexUrl, config.ApiKey)
47
if client == nil {
37
if client == nil {
48
log.Fatalf("Failed to create Klex client")
38
log.Fatalf("Failed to create Klex client")
49
}
39
}
50
40
51
// Parse stdin as a MessagesRequest object, allowing empty input.
41
// Parse stdin as a MessagesRequest object, allowing empty input.
52
sin, err := ioutil.ReadAll(os.Stdin)
42
sin, err := ioutil.ReadAll(os.Stdin)
53
if err != nil {
43
if err != nil {
54
log.Fatalf("Failed to read stdin: %v", err)
44
log.Fatalf("Failed to read stdin: %v", err)
55
}
45
}
56
if len(sin) == 0 {
46
if len(sin) == 0 {
57
sin = []byte("{}")
47
sin = []byte("{}")
58
}
48
}
59
var req api.MessagesRequest
49
var req api.MessagesRequest
60
err = json.Unmarshal(sin, &req)
50
err = json.Unmarshal(sin, &req)
61
if err != nil {
51
if err != nil {
62
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
52
log.Fatalf("Failed to parse a MessagesRequest from stdin: %v", err)
63
}
53
}
64
54
65
// Use flags to override parts of the request.
55
// Use flags to override parts of the request.
66
if *model != "" {
56
if *model != "" {
67
req.Model = *model
57
req.Model = *model
68
}
58
}
69
if *system != "" {
59
if *system != "" {
70
s, err := ioutil.ReadFile(*system)
60
s, err := ioutil.ReadFile(*system)
71
if err != nil {
61
if err != nil {
72
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
62
log.Fatalf("Failed to read --system_file %s: %v", *system, err)
73
}
63
}
74
req.System = string(s)
64
req.System = string(s)
75
}
65
}
76
if *attach != "" && *prompt == "" {
66
if *attach != "" && *prompt == "" {
77
log.Fatalf("--attach requires a non-empty --prompt_file, too")
67
log.Fatalf("--attach requires a non-empty --prompt_file, too")
78
}
68
}
79
// Tracks the attachment MIME for the preflight below (empty if
69
// Tracks the attachment MIME for the preflight below (empty if
80
// no attachment was given).
70
// no attachment was given).
81
var attach_mime string
71
var attach_mime string
82
if *prompt != "" {
72
if *prompt != "" {
83
msg := api.ChatMessage{Role: "user"}
73
msg := api.ChatMessage{Role: "user"}
84
if *attach != "" {
74
if *attach != "" {
85
i, err := ioutil.ReadFile(*attach)
75
i, err := ioutil.ReadFile(*attach)
86
if err != nil {
76
if err != nil {
87
log.Fatalf("Failed to read --attach %s: %v", *attach, err)
77
log.Fatalf("Failed to read --attach %s: %v", *attach, err)
88
}
78
}
89
= eme(i)
79
= eme(i)
90
= .o.e
80
= .o.e
91
peoent
81
peoent
92
Source: &api.ContentSource{
93
Type: "base64",
94
MediaType: attach_mime,
95
Data: base64.StdEncoding.EncodeToString(i),
96
},
97
})
98
}
82
}
99
p, err := ioutil.ReadFile(*prompt)
83
p, err := ioutil.ReadFile(*prompt)
100
if err != nil {
84
if err != nil {
101
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
85
log.Fatalf("Failed to read --prompt_file %s: %v", *prompt, err)
102
}
86
}
103
msg.Content = append(msg.Content, api.ContentBlock{
87
msg.Content = append(msg.Content, api.ContentBlock{
104
Type: "text",
88
Type: "text",
105
Text: string(p),
89
Text: string(p),
106
})
90
})
107
req.Messages = append(req.Messages, msg)
91
req.Messages = append(req.Messages, msg)
108
}
92
}
109
93
110
// Pre-flight: catch unsupported attachment types before paying
94
// Pre-flight: catch unsupported attachment types before paying
111
// for the LLM call. Skip when --fast-fail=false (e.g. in
95
// for the LLM call. Skip when --fast-fail=false (e.g. in
112
// MapReduce loops that don't want one extra HTTP round-trip
96
// MapReduce loops that don't want one extra HTTP round-trip
113
// per call). No local MIME whitelist — server-side flags are
97
// per call). No local MIME whitelist — server-side flags are
114
// the source of truth.
98
// the source of truth.
115
if *fastFail && attach_mime != "" {
99
if *fastFail && attach_mime != "" {
116
preflight(client, req.Model, attach_mime)
100
preflight(client, req.Model, attach_mime)
117
}
101
}
118
102
119
// Get LLM output from Klex.
103
// Get LLM output from Klex.
120
res, err := client.Messages(req)
104
res, err := client.Messages(req)
121
if err != nil {
105
if err != nil {
122
log.Fatalf("Klex f() failure: %v", err)
106
log.Fatalf("Klex f() failure: %v", err)
123
}
107
}
124
108
125
// Print according to the --format flag.
109
// Print according to the --format flag.
126
out, err := formatResponse(res)
110
out, err := formatResponse(res)
127
if err != nil {
111
if err != nil {
128
log.Fatalf("Failed to format response: %v", err)
112
log.Fatalf("Failed to format response: %v", err)
129
}
113
}
130
fmt.Print(out)
114
fmt.Print(out)
131
}
115
}
132
116
133
func formatResponse(res *api.MessagesResponse) (string, error) {
117
func formatResponse(res *api.MessagesResponse) (string, error) {
134
switch *format {
118
switch *format {
135
case "text":
119
case "text":
136
var content []string
120
var content []string
137
for _, c := range res.Content {
121
for _, c := range res.Content {
138
if c.Type == "text" {
122
if c.Type == "text" {
139
content = append(content, c.Text + "\n")
123
content = append(content, c.Text + "\n")
140
}
124
}
141
}
125
}
142
return strings.Join(content, "\n"), nil
126
return strings.Join(content, "\n"), nil
143
case "json":
127
case "json":
144
buf, err := json.Marshal(res)
128
buf, err := json.Marshal(res)
145
return string(buf), err
129
return string(buf), err
146
case "jsonindent":
130
case "jsonindent":
147
buf, err := json.MarshalIndent(res, "", " ")
131
buf, err := json.MarshalIndent(res, "", " ")
148
return string(buf), err
132
return string(buf), err
149
default:
133
default:
150
return "", fmt.Errorf("Unsupported --format=%s", *format)
134
return "", fmt.Errorf("Unsupported --format=%s", *format)
151
}
135
}
152
}
136
}
153
137
154
// preflight fetches the function's llm2 capabilities and aborts
138
// preflight fetches the function's llm2 capabilities and aborts
155
// with a clear error if it can't accept the given attachment
139
// with a clear error if it can't accept the given attachment
156
// MIME type. Skips silently for MIME families Klex doesn't have
140
// MIME type. Skips silently for MIME families Klex doesn't have
157
// a capability flag for (i.e. anything that isn't image/* or
141
// a capability flag for (i.e. anything that isn't image/* or
158
// application/pdf) — the server-side adapter remains the final
142
// application/pdf) — the server-side adapter remains the final
159
// arbiter for those.
143
// arbiter for those.
160
func preflight(client *api.Client, model_name, mime_type string) {
144
func preflight(client *api.Client, model_name, mime_type string) {
161
resp, err := client.ListFuncs("latest")
145
resp, err := client.ListFuncs("latest")
162
if err != nil {
146
if err != nil {
163
log.Fatalf("Preflight ListFuncs failed (set "+
147
log.Fatalf("Preflight ListFuncs failed (set "+
164
"--fast-fail=false to bypass): %v", err)
148
"--fast-fail=false to bypass): %v", err)
165
}
149
}
166
var fn *api.Func
150
var fn *api.Func
167
for i := range resp.Funcs {
151
for i := range resp.Funcs {
168
if resp.Funcs[i].Name == model_name {
152
if resp.Funcs[i].Name == model_name {
169
fn = &resp.Funcs[i]
153
fn = &resp.Funcs[i]
170
break
154
break
171
}
155
}
172
}
156
}
173
if fn == nil {
157
if fn == nil {
174
log.Fatalf("Unknown model %q (set --fast-fail=false to "+
158
log.Fatalf("Unknown model %q (set --fast-fail=false to "+
175
"bypass)", model_name)
159
"bypass)", model_name)
176
}
160
}
177
if len(fn.Versions) == 0 || fn.Versions[len(fn.Versions)-1].LLM2 == nil {
161
if len(fn.Versions) == 0 || fn.Versions[len(fn.Versions)-1].LLM2 == nil {
178
log.Fatalf("Model %q has no llm2 config", model_name)
162
log.Fatalf("Model %q has no llm2 config", model_name)
179
}
163
}
180
llm := fn.Versions[len(fn.Versions)-1].LLM2
164
llm := fn.Versions[len(fn.Versions)-1].LLM2
181
switch {
165
switch {
182
case strings.HasPrefix(mime_type, "image/"):
166
case strings.HasPrefix(mime_type, "image/"):
183
if !llm.CanSeeImages {
167
if !llm.CanSeeImages {
184
log.Fatalf("Model %q does not accept images "+
168
log.Fatalf("Model %q does not accept images "+
185
"(can_see_images=false)", model_name)
169
"(can_see_images=false)", model_name)
186
}
170
}
187
case mime_type == "application/pdf":
171
case mime_type == "application/pdf":
188
if !llm.CanSeePDFs {
172
if !llm.CanSeePDFs {
189
log.Fatalf("Model %q does not accept PDFs "+
173
log.Fatalf("Model %q does not accept PDFs "+
190
"(can_see_pdfs=false)", model_name)
174
"(can_see_pdfs=false)", model_name)
191
}
175
}
192
}
176
}
193
}
177
}