package main
import "encoding/base64"
import "encoding/json"
import "flag"
import "io/ioutil"
import "log"
import "net/http"
import "os"
import "path"
import "sort"
import "strings"
import "oscarkilo.com/klex-git/api"
import "oscarkilo.com/klex-git/config"
// Command-line flags.
var (
	// dir is the directory scanForCases reads case files from and that
	// readFile/writeFile resolve file names against (including the
	// generated *.out files).
	dir = flag.String("dir", ".", "Directory to scan and write to.")
	// model is copied into the Model field of every request.
	model = flag.String("model", "Gemini 3 Pro", "Model name to put in every request.")
	// format is declared here but not read anywhere in this file.
	// NOTE(review): confirm whether output formatting was meant to be wired up.
	format = flag.String("format", "text", "text|json|jsonindent")
	// dry_run makes main stop after building requests, before sending any.
	dry_run = flag.Bool("dry_run", false, "Build the requests but do not send them.")
	// debug makes main print each outgoing request as indented JSON to stderr.
	debug = flag.Bool("debug", false, "Print each request as indented JSON to stderr before sending.")
)
// Case is one prompt built from files sharing a name stem in -dir, as grouped
// by scanForCases.
type Case struct {
	// Name is the file-name stem: everything before the last '.'.
	Name string
	// Before holds the input suffixes found for Name ("txt", "json", or an
	// image extension), ordered with "txt" first and the rest alphabetically.
	Before []string
	// After is "out" when Name.out exists (the case is an example), else "".
	After string
	// Request is set by main only for cases with no .out file; nil otherwise.
	Request *api.MessagesRequest
}
// scanForCases lists the files in -dir and groups them into Cases by
// file-name stem (the text before the last '.').
//
// Suffixes "txt", "json", "jpg", "jpeg", "png" and "webp" are inputs and are
// recorded in Case.Before; an "out" suffix marks a stored reply and sets
// Case.After to "out". Directories, system_prompt.txt, names without a '.',
// and any other suffix are ignored. Only stems with at least one input
// become Cases.
//
// The result lists examples (After != "") before inputs, each group sorted
// by Name. A failure to read -dir is fatal.
func scanForCases() []Case {
	entries, err := os.ReadDir(*dir)
	if err != nil {
		log.Fatalf("Failed to read dir %s: %v", *dir, err)
	}

	inputs := make(map[string][]string) // stem -> input suffixes
	hasOut := make(map[string]bool)     // stem -> a .out file exists
	for _, e := range entries {
		fname := e.Name()
		if e.IsDir() || fname == "system_prompt.txt" {
			continue
		}
		dot := strings.LastIndex(fname, ".")
		if dot < 0 {
			continue
		}
		stem, ext := fname[:dot], fname[dot+1:]
		switch ext {
		case "out":
			hasOut[stem] = true
		case "txt", "json", "jpg", "jpeg", "png", "webp":
			inputs[stem] = append(inputs[stem], ext)
		}
	}

	// rank puts "txt" ahead of every other suffix; ties fall back to
	// alphabetical order so the text part is always read first.
	rank := func(s string) int {
		if s == "txt" {
			return 0
		}
		return 1
	}
	var cases []Case
	for stem, exts := range inputs {
		sort.Slice(exts, func(i, j int) bool {
			if ri, rj := rank(exts[i]), rank(exts[j]); ri != rj {
				return ri < rj
			}
			return exts[i] < exts[j]
		})
		c := Case{Name: stem, Before: exts}
		if hasOut[stem] {
			c.After = "out"
		}
		cases = append(cases, c)
	}

	// Examples first, then inputs; names are unique map keys, so the
	// order is total and independent of map iteration order.
	isExample := func(c Case) bool { return c.After != "" }
	sort.Slice(cases, func(i, j int) bool {
		if ei, ej := isExample(cases[i]), isExample(cases[j]); ei != ej {
			return ei
		}
		return cases[i].Name < cases[j].Name
	})

	examples := 0
	for _, c := range cases {
		if isExample(c) {
			examples++
		}
	}
	log.Printf(
		"%s:\n num_examples = %d\n num_inputs = %d",
		*dir,
		examples,
		len(cases)-examples,
	)
	return cases
}
// readFile returns the contents of fname, resolved relative to -dir.
// Any read error is fatal, because the program cannot continue without its
// inputs.
func readFile(fname string) []byte {
	full := path.Join(*dir, fname)
	contents, err := ioutil.ReadFile(full)
	if err == nil {
		return contents
	}
	log.Fatalf("Failed to read file %s: %v", fname, err)
	return nil // unreachable: log.Fatalf exits the process
}
// writeFile stores data as fname inside -dir with permissions 0644 and
// returns any error to the caller instead of exiting.
func writeFile(fname string, data []byte) error {
	const perm = 0644
	target := path.Join(*dir, fname)
	return ioutil.WriteFile(target, data, perm)
}
// copyRequest returns a deep copy of req, made by a JSON marshal/unmarshal
// round trip, so the result shares no slice backing arrays with req.
// Either step failing is fatal.
//
// NOTE(review): fields that encoding/json does not round-trip (unexported,
// or tagged `json:"-"`) would be dropped — confirm api.MessagesRequest has none.
func copyRequest(req api.MessagesRequest) *api.MessagesRequest {
	clone := new(api.MessagesRequest)
	encoded, err := json.Marshal(req)
	if err == nil {
		err = json.Unmarshal(encoded, clone)
	}
	if err != nil {
		log.Fatalf("Failed to copy request: %+v", err)
	}
	return clone
}
// main builds one few-shot Messages request per input case found in -dir and
// writes each model reply, plus a trailing newline, to <case>.out.
//
// Cases that already have a .out file are examples. scanForCases sorts them
// ahead of the inputs, so they are all replayed, in name order, as
// user/assistant turns after the system prompt. Each case without a .out
// file is an input: its request is the system prompt, every example, and
// the input's own user turn. With -dry_run the requests are built but not
// sent.
func main() {
	flag.Parse()

	// Find the API keys and configure a Klex client. The variable is named
	// cfg so that it does not shadow the imported config package.
	cfg, err := config.ReadConfig()
	if err != nil {
		log.Fatalf("Failed to read config: %v", err)
	}
	client := api.NewClient(cfg.KlexUrl, cfg.ApiKey)
	if client == nil {
		log.Fatalf("Failed to create Klex client")
	}

	// Create MessagesRequest objects. req accumulates the shared prefix
	// (system prompt plus example turns) as the sorted cases are walked.
	req := api.MessagesRequest{
		Model:  *model,
		System: string(readFile("system_prompt.txt")),
	}
	cases := scanForCases()
	for i, c := range cases {
		user := api.ChatMessage{Role: "user"}
		var text strings.Builder
		text.WriteString("Case name: " + c.Name + "\n\n")
		for _, suffix := range c.Before {
			switch suffix {
			case "txt":
				text.Write(readFile(c.Name + ".txt"))
			case "json":
				text.WriteString("\n\n```json\n")
				text.Write(readFile(c.Name + ".json"))
				text.WriteString("\n```\n")
			case "jpg", "jpeg", "png", "webp":
				data := readFile(c.Name + "." + suffix)
				user.Content = append(user.Content, api.ContentBlock{
					Type: "image",
					Source: &api.ContentSource{
						Type: "base64",
						// Sniff the media type from the bytes, not the extension.
						MediaType: http.DetectContentType(data),
						Data:      base64.StdEncoding.EncodeToString(data),
					},
				})
			default:
				log.Fatalf("Unsupported suffix %s in case %s", suffix, c.Name)
			}
		}
		// The text block follows any image blocks in the user turn.
		user.Content = append(user.Content, api.ContentBlock{
			Type: "text",
			Text: text.String(),
		})
		req.Messages = append(req.Messages, user)
		if c.After != "" {
			// Example: keep the user turn and add the stored reply.
			asst := api.ChatMessage{Role: "assistant"}
			asst.Content = append(asst.Content, api.ContentBlock{
				Type: "text",
				Text: string(readFile(c.Name + ".out")),
			})
			req.Messages = append(req.Messages, asst)
		} else {
			// Input: snapshot the request, then drop this user turn so the
			// next input starts from the shared prefix. The snapshot must be
			// a deep copy, because truncating and re-appending req.Messages
			// reuses the same backing array.
			cases[i].Request = copyRequest(req)
			req.Messages = req.Messages[:len(req.Messages)-1]
		}
	}
	if *dry_run {
		log.Printf("Dry run; not sending requests.")
		return
	}

	// Send 'em.
	// TODO: parallelize
	for _, c := range cases {
		if c.Request == nil {
			continue // examples have no request to send
		}
		if *debug {
			log.Printf("Case %s: sending request:", c.Name)
			enc := json.NewEncoder(os.Stderr)
			enc.SetIndent("", " ")
			if err := enc.Encode(c.Request); err != nil {
				log.Printf("Case %s: failed to encode request for debug output: %v", c.Name, err)
			}
		}
		res, err := client.Messages(*c.Request)
		if err != nil {
			log.Printf("Case %s: request failed: %+v", c.Name, err)
			continue
		}
		// Exactly one text block is expected; report zero and several
		// blocks separately so the log says what actually happened.
		switch n := len(res.Content); {
		case n == 0:
			log.Printf("Case %s: empty response", c.Name)
			continue
		case n > 1:
			log.Printf("Case %s: expected 1 content block, got %d", c.Name, n)
			continue
		}
		c0 := res.Content[0]
		if c0.Type != "text" {
			log.Printf("Case %s: Content[0].Type = %s", c.Name, c0.Type)
			continue
		}
		err = writeFile(c.Name+".out", append([]byte(c0.Text), '\n'))
		if err != nil {
			log.Printf("Case %s: failed to write %s.out: %+v", c.Name, c.Name, err)
			continue
		}
		log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
	}
}