1
package main
1
package main
2
2
3
import "encoding/base64"
3
import "encoding/base64"
4
import "encoding/json"
4
import "encoding/json"
5
import "flag"
5
import "flag"
6
import "io/ioutil"
6
import "io/ioutil"
7
import "log"
7
import "log"
8
import "net/http"
8
import "net/http"
9
import "os"
9
import "os"
10
import "path"
10
import "path"
11
import "sort"
11
import "sort"
12
import "strings"
12
import "strings"
13
13
14
import "oscarkilo.com/klex-git/api"
14
import "oscarkilo.com/klex-git/api"
15
import "oscarkilo.com/klex-git/config"
15
import "oscarkilo.com/klex-git/config"
16
16
17
var dir = flag.String("dir", ".", "Directory to scan and write to.")
17
var dir = flag.String("dir", ".", "Directory to scan and write to.")
18
var model = flag.String("model", "Gemini 3 Pro", "")
18
var model = flag.String("model", "Gemini 3 Pro", "")
19
var format = flag.String("format", "text", "text|json|jsonindent")
19
var format = flag.String("format", "text", "text|json|jsonindent")
20
var dry_run = flag.Bool("dry_run", false, "")
20
var dry_run = flag.Bool("dry_run", false, "")
21
var debug = flag.Bool("debug", false, "")
21
22
22
type Case struct {
23
type Case struct {
23
Name string
24
Name string
24
Before []string
25
Before []string
25
After string
26
After string
26
Request *api.MessagesRequest
27
Request *api.MessagesRequest
27
}
28
}
28
29
29
func scanForCases() []Case {
30
func scanForCases() []Case {
30
entries, err := os.ReadDir(*dir)
31
entries, err := os.ReadDir(*dir)
31
if err != nil {
32
if err != nil {
32
log.Fatalf("Failed to read dir %s: %v", *dir, err)
33
log.Fatalf("Failed to read dir %s: %v", *dir, err)
33
}
34
}
34
35
35
before := make(map[string][]string)
36
before := make(map[string][]string)
36
after := make(map[string]string)
37
after := make(map[string]string)
37
for _, entry := range entries {
38
for _, entry := range entries {
38
if entry.IsDir() {
39
if entry.IsDir() {
39
continue
40
continue
40
}
41
}
41
if entry.Name() == "system_prompt.txt" {
42
if entry.Name() == "system_prompt.txt" {
42
continue
43
continue
43
}
44
}
44
chunks := strings.Split(entry.Name(), ".")
45
chunks := strings.Split(entry.Name(), ".")
45
if len(chunks) < 2 {
46
if len(chunks) < 2 {
46
continue
47
continue
47
}
48
}
48
name := strings.Join(chunks[0:len(chunks)-1], ".")
49
name := strings.Join(chunks[0:len(chunks)-1], ".")
49
suffix := chunks[len(chunks)-1]
50
suffix := chunks[len(chunks)-1]
50
switch suffix {
51
switch suffix {
51
case "txt", "json", "jpg", "jpeg", "png":
52
case "txt", "json", "jpg", "jpeg", "png":
52
before[name] = append(before[name], suffix)
53
before[name] = append(before[name], suffix)
53
case "out":
54
case "out":
54
after[name] = suffix
55
after[name] = suffix
55
}
56
}
56
}
57
}
57
58
58
var cases []Case
59
var cases []Case
59
for name, b := range before {
60
for name, b := range before {
60
sort.Slice(b, func(i, j int) bool {
61
sort.Slice(b, func(i, j int) bool {
61
if b[i] == "txt" && b[j] != "txt" {
62
if b[i] == "txt" && b[j] != "txt" {
62
return true
63
return true
63
}
64
}
64
if b[i] != "txt" && b[j] == "txt" {
65
if b[i] != "txt" && b[j] == "txt" {
65
return false
66
return false
66
}
67
}
67
return b[i] < b[j]
68
return b[i] < b[j]
68
})
69
})
69
cases = append(cases, Case{
70
cases = append(cases, Case{
70
Name: name,
71
Name: name,
71
Before: b,
72
Before: b,
72
After: after[name],
73
After: after[name],
73
})
74
})
74
}
75
}
75
76
76
sort.Slice(cases, func(i, j int) bool {
77
sort.Slice(cases, func(i, j int) bool {
77
a, b := cases[i], cases[j]
78
a, b := cases[i], cases[j]
78
if a.After != "" && b.After == "" {
79
if a.After != "" && b.After == "" {
79
return true
80
return true
80
}
81
}
81
if a.After == "" && b.After != "" {
82
if a.After == "" && b.After != "" {
82
return false
83
return false
83
}
84
}
84
return a.Name < b.Name
85
return a.Name < b.Name
85
})
86
})
86
87
87
num_examples := 0
88
num_examples := 0
88
for ; num_examples < len(cases); num_examples++ {
89
for ; num_examples < len(cases); num_examples++ {
89
if cases[num_examples].After == "" {
90
if cases[num_examples].After == "" {
90
break
91
break
91
}
92
}
92
}
93
}
93
log.Printf(
94
log.Printf(
94
"%s:\n num_examples = %d\n num_inputs = %d",
95
"%s:\n num_examples = %d\n num_inputs = %d",
95
*dir,
96
*dir,
96
num_examples,
97
num_examples,
97
len(cases) - num_examples,
98
len(cases) - num_examples,
98
)
99
)
99
100
100
return cases
101
return cases
101
}
102
}
102
103
103
func readFile(fname string) []byte {
104
func readFile(fname string) []byte {
104
data, err := ioutil.ReadFile(path.Join(*dir, fname))
105
data, err := ioutil.ReadFile(path.Join(*dir, fname))
105
if err != nil {
106
if err != nil {
106
log.Fatalf("Failed to read file %s: %v", fname, err)
107
log.Fatalf("Failed to read file %s: %v", fname, err)
107
}
108
}
108
return data
109
return data
109
}
110
}
110
111
111
func writeFile(fname string, data []byte) error {
112
func writeFile(fname string, data []byte) error {
112
return ioutil.WriteFile(path.Join(*dir, fname), data, 0644)
113
return ioutil.WriteFile(path.Join(*dir, fname), data, 0644)
113
}
114
}
114
115
115
func copyRequest(req api.MessagesRequest) *api.MessagesRequest {
116
func copyRequest(req api.MessagesRequest) *api.MessagesRequest {
116
bytes, err := json.Marshal(req)
117
bytes, err := json.Marshal(req)
117
if err != nil {
118
if err != nil {
118
log.Fatalf("Failed to copy request: %+v", err)
119
log.Fatalf("Failed to copy request: %+v", err)
119
}
120
}
120
= .saee
121
= .saee
122
err = json.Unmarshal(bytes, req2)
121
if err != nil {
123
if err != nil {
122
log.Fatalf("Failed to copy request: %+v", err)
124
log.Fatalf("Failed to copy request: %+v", err)
123
}
125
}
124
return req
126
return req
125
}
127
}
126
128
127
func main() {
129
func main() {
128
flag.Parse()
130
flag.Parse()
129
131
130
// Find the API keys and configure a Klex client.
132
// Find the API keys and configure a Klex client.
131
config, err := config.ReadConfig()
133
config, err := config.ReadConfig()
132
if err != nil {
134
if err != nil {
133
log.Fatalf("Failed to read config: %v", err)
135
log.Fatalf("Failed to read config: %v", err)
134
}
136
}
135
client := api.NewClient(config.KlexUrl, config.ApiKey)
137
client := api.NewClient(config.KlexUrl, config.ApiKey)
136
if client == nil {
138
if client == nil {
137
log.Fatalf("Failed to create Klex client")
139
log.Fatalf("Failed to create Klex client")
138
}
140
}
139
141
140
// Create MessagesRequest objects.
142
// Create MessagesRequest objects.
141
req := api.MessagesRequest{
143
req := api.MessagesRequest{
142
Model: *model,
144
Model: *model,
143
System: string(readFile("system_prompt.txt")),
145
System: string(readFile("system_prompt.txt")),
144
}
146
}
145
cases := scanForCases()
147
cases := scanForCases()
146
for i, c := range cases {
148
for i, c := range cases {
147
user := api.ChatMessage{Role: "user"}
149
user := api.ChatMessage{Role: "user"}
148
text := "Case name: " + c.Name + "\n\n"
150
text := "Case name: " + c.Name + "\n\n"
149
for _, suffix := range c.Before {
151
for _, suffix := range c.Before {
150
switch suffix {
152
switch suffix {
151
case "txt":
153
case "txt":
152
text += string(readFile(c.Name + ".txt"))
154
text += string(readFile(c.Name + ".txt"))
153
case "json":
155
case "json":
154
text += "\n\n```json\n"
156
text += "\n\n```json\n"
155
text += string(readFile(c.Name + ".json"))
157
text += string(readFile(c.Name + ".json"))
156
text += "\n```\n"
158
text += "\n```\n"
157
case "jpg", "jpeg", "png", "webp":
159
case "jpg", "jpeg", "png", "webp":
158
bytes := readFile(c.Name + "." + suffix)
160
bytes := readFile(c.Name + "." + suffix)
159
user.Content = append(user.Content, api.ContentBlock{
161
user.Content = append(user.Content, api.ContentBlock{
160
Type: "image",
162
Type: "image",
161
Source: &api.ContentSource{
163
Source: &api.ContentSource{
162
Type: "base64",
164
Type: "base64",
163
MediaType: http.DetectContentType(bytes),
165
MediaType: http.DetectContentType(bytes),
164
Data: base64.StdEncoding.EncodeToString(bytes),
166
Data: base64.StdEncoding.EncodeToString(bytes),
165
},
167
},
166
})
168
})
167
default:
169
default:
168
log.Fatalf("Unsupported suffix %s in case %s", suffix, c.Name)
170
log.Fatalf("Unsupported suffix %s in case %s", suffix, c.Name)
169
}
171
}
170
}
172
}
171
user.Content = append(user.Content, api.ContentBlock{
173
user.Content = append(user.Content, api.ContentBlock{
172
Type: "text",
174
Type: "text",
173
Text: text,
175
Text: text,
174
})
176
})
175
req.Messages = append(req.Messages, user)
177
req.Messages = append(req.Messages, user)
176
if c.After != "" {
178
if c.After != "" {
177
asst := api.ChatMessage{Role: "assistant"}
179
asst := api.ChatMessage{Role: "assistant"}
178
asst.Content = append(asst.Content, api.ContentBlock{
180
asst.Content = append(asst.Content, api.ContentBlock{
179
Type: "text",
181
Type: "text",
180
Text: string(readFile(c.Name + ".out")),
182
Text: string(readFile(c.Name + ".out")),
181
})
183
})
182
req.Messages = append(req.Messages, asst)
184
req.Messages = append(req.Messages, asst)
183
} else {
185
} else {
184
cases[i].Request = copyRequest(req)
186
cases[i].Request = copyRequest(req)
185
req.Messages = req.Messages[:len(req.Messages)-1]
187
req.Messages = req.Messages[:len(req.Messages)-1]
186
}
188
}
187
}
189
}
188
190
189
if *dry_run {
191
if *dry_run {
190
log.Printf("Dry run; not sending requests.")
192
log.Printf("Dry run; not sending requests.")
191
return
193
return
192
}
194
}
193
195
194
// Send 'em.
196
// Send 'em.
195
// TODO: parallelize
197
// TODO: parallelize
196
for _, c := range cases {
198
for _, c := range cases {
197
if c.Request == nil {
199
if c.Request == nil {
198
continue
200
continue
199
}
201
}
202
if *debug {
203
log.Printf("Case %s: sending request:", c.Name)
204
enc := json.NewEncoder(os.Stderr)
205
enc.SetIndent("", " ")
206
enc.Encode(c.Request)
207
}
200
res, err := client.Messages(*c.Request)
208
res, err := client.Messages(*c.Request)
201
if err != nil {
209
if err != nil {
202
log.Printf("Case %s: request failed: %+v", c.Name, err)
210
log.Printf("Case %s: request failed: %+v", c.Name, err)
203
continue
211
continue
204
}
212
}
205
if len(res.Content) != 1 {
213
if len(res.Content) != 1 {
206
log.Printf("Case %s: empty response", c.Name)
214
log.Printf("Case %s: empty response", c.Name)
207
continue
215
continue
208
}
216
}
209
c0 := res.Content[0]
217
c0 := res.Content[0]
210
if c0.Type != "text" {
218
if c0.Type != "text" {
211
log.Printf("Case %s: Content[0].Type = %s", c.Name, c0.Type)
219
log.Printf("Case %s: Content[0].Type = %s", c.Name, c0.Type)
212
continue
220
continue
213
}
221
}
214
err = writeFile(c.Name + ".out", append([]byte(c0.Text), '\n'))
222
err = writeFile(c.Name + ".out", append([]byte(c0.Text), '\n'))
215
if err != nil {
223
if err != nil {
216
log.Printf("Case %s: failed to write %s.out: %+v", c.Name, c.Name, err)
224
log.Printf("Case %s: failed to write %s.out: %+v", c.Name, c.Name, err)
217
continue
225
continue
218
}
226
}
219
log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
227
log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
220
}
228
}
221
}
229
}