code.oscarkilo.com/klex-git/exemplary/main.go

..
README.md
example/
main.go
main_test.go
package main

import "encoding/base64"
import "encoding/json"
import "flag"
import "io/ioutil"
import "log"
import "net/http"
import "os"
import "path"
import "sort"
import "strings"

import "oscarkilo.com/klex-git/api"
import "oscarkilo.com/klex-git/config"

// Command-line flags.
var (
  // dir holds system_prompt.txt and the case files; replies (<case>.out)
  // are written back into it.
  dir = flag.String("dir", ".", "Directory to scan and write to.")
  // model is copied into the Model field of every request.
  model = flag.String("model", "Gemini 3 Pro", "Name of the Klex model to send requests to.")
  // NOTE(review): format is not read anywhere in this file; confirm whether
  // another file in package main uses it before relying on it.
  format = flag.String("format", "text", "text|json|jsonindent")
  // dry_run builds every request but stops before sending any.
  dry_run = flag.Bool("dry_run", false, "Build the requests but don't send them.")
  // debug dumps each outgoing request to stderr as indented JSON.
  debug = flag.Bool("debug", false, "Log each outgoing request to stderr as indented JSON.")
)

// Case is one unit of work: the files in *dir that share a base name.
type Case struct {
  // Name is the shared base filename (everything before the last ".").
  Name string
  // Before lists the input extensions found for this case (any of txt,
  // json, jpg, jpeg, png, webp), ordered with "txt" first and the rest
  // alphabetically.
  Before []string
  // After is "out" when <Name>.out exists, which makes this case a worked
  // example; it is "" for cases that still need an answer.
  After  string
  // Request is the prepared request for a case without a .out file; it
  // stays nil for worked examples.
  Request *api.MessagesRequest
}

// scanForCases groups the regular files in *dir into cases keyed by base
// name (everything before the final "."). Files ending in txt, json, jpg,
// jpeg, png or webp are inputs; a matching ".out" file marks the case as a
// worked example. Subdirectories and system_prompt.txt are ignored, as are
// names with no "." and base names that have no input file at all.
//
// The result lists every worked example first, then the remaining cases,
// each group ordered by name. Counts of both groups are logged. The program
// exits if *dir cannot be listed.
func scanForCases() []Case {
  files, readErr := os.ReadDir(*dir)
  if readErr != nil {
    log.Fatalf("Failed to read dir %s: %v", *dir, readErr)
  }

  inputs := map[string][]string{}
  outputs := map[string]string{}
  for _, f := range files {
    fname := f.Name()
    if f.IsDir() || fname == "system_prompt.txt" {
      continue
    }
    dot := strings.LastIndex(fname, ".")
    if dot < 0 {
      continue
    }
    base, ext := fname[:dot], fname[dot+1:]
    switch ext {
    case "out":
      outputs[base] = ext
    case "txt", "json", "jpg", "jpeg", "png", "webp":
      inputs[base] = append(inputs[base], ext)
    }
  }

  var cases []Case
  for base, exts := range inputs {
    // The text file leads; everything else follows alphabetically.
    sort.Slice(exts, func(x, y int) bool {
      xTxt, yTxt := exts[x] == "txt", exts[y] == "txt"
      if xTxt != yTxt {
        return xTxt
      }
      return exts[x] < exts[y]
    })
    cases = append(cases, Case{Name: base, Before: exts, After: outputs[base]})
  }

  // Worked examples (non-empty After) sort ahead of pending inputs.
  sort.Slice(cases, func(x, y int) bool {
    xDone, yDone := cases[x].After != "", cases[y].After != ""
    if xDone != yDone {
      return xDone
    }
    return cases[x].Name < cases[y].Name
  })

  // Because examples come first, they form a prefix of cases.
  numExamples := 0
  for numExamples < len(cases) && cases[numExamples].After != "" {
    numExamples++
  }
  log.Printf(
    "%s:\n  num_examples = %d\n  num_inputs   = %d",
    *dir,
    numExamples,
    len(cases)-numExamples,
  )

  return cases
}

// readFile returns the contents of fname, resolved relative to *dir.
// A read failure is fatal: it is logged and the program exits.
func readFile(fname string) []byte {
  fullPath := path.Join(*dir, fname)
  contents, readErr := ioutil.ReadFile(fullPath)
  if readErr != nil {
    log.Fatalf("Failed to read file %s: %v", fname, readErr)
  }
  return contents
}

// writeFile stores data as fname inside *dir with permissions 0644 (before
// umask), truncating any existing file. Errors are returned, not logged.
func writeFile(fname string, data []byte) error {
  const perm = 0644
  target := path.Join(*dir, fname)
  return ioutil.WriteFile(target, data, perm)
}

// copyRequest returns a copy of req produced by a JSON round trip, so the
// result shares no slices or pointers with req. Only what survives
// api.MessagesRequest's JSON encoding is carried over. Any marshal or
// unmarshal failure is fatal.
func copyRequest(req api.MessagesRequest) *api.MessagesRequest {
  encoded, err := json.Marshal(req)
  if err != nil {
    log.Fatalf("Failed to copy request: %+v", err)
  }
  var clone api.MessagesRequest
  if err := json.Unmarshal(encoded, &clone); err != nil {
    log.Fatalf("Failed to copy request: %+v", err)
  }
  return &clone
}

// main turns the cases in *dir into Klex Messages requests and writes each
// reply to <case>.out.
//
// Every case that already has a .out file is replayed as a user/assistant
// example pair. Each case without one becomes a request whose conversation is
// the examples preceding it (scanForCases puts all examples first), followed
// by that case's own user turn. Unless -dry_run is set, requests are sent one
// at a time. Per-case failures are logged and skipped, and the program exits
// non-zero at the end if any of them failed.
func main() {
  flag.Parse()

  // Find the API keys and configure a Klex client. The local is named cfg so
  // it does not shadow the imported config package.
  cfg, err := config.ReadConfig()
  if err != nil {
    log.Fatalf("Failed to read config: %v", err)
  }
  client := api.NewClient(cfg.KlexUrl, cfg.ApiKey)
  if client == nil {
    log.Fatalf("Failed to create Klex client")
  }

  // Create MessagesRequest objects. req accumulates the example turns; each
  // input case's user turn is appended, snapshotted and removed again.
  req := api.MessagesRequest{
    Model:  *model,
    System: string(readFile("system_prompt.txt")),
  }
  cases := scanForCases()
  for i, c := range cases {
    user := api.ChatMessage{Role: "user"}
    var text strings.Builder
    text.WriteString("Case name: " + c.Name + "\n\n")
    for _, suffix := range c.Before {
      switch suffix {
      case "txt":
        text.Write(readFile(c.Name + ".txt"))
      case "json":
        text.WriteString("\n\n```json\n")
        text.Write(readFile(c.Name + ".json"))
        text.WriteString("\n```\n")
      case "jpg", "jpeg", "png", "webp":
        data := readFile(c.Name + "." + suffix)
        user.Content = append(user.Content, api.ContentBlock{
          Type: "image",
          Source: &api.ContentSource{
            Type:      "base64",
            MediaType: http.DetectContentType(data),
            Data:      base64.StdEncoding.EncodeToString(data),
          },
        })
      default:
        log.Fatalf("Unsupported suffix %s in case %s", suffix, c.Name)
      }
    }
    // The text block always comes after any image blocks.
    user.Content = append(user.Content, api.ContentBlock{
      Type: "text",
      Text: text.String(),
    })
    req.Messages = append(req.Messages, user)
    if c.After != "" {
      // Worked example: keep the user turn and add the recorded answer.
      asst := api.ChatMessage{Role: "assistant"}
      asst.Content = append(asst.Content, api.ContentBlock{
        Type: "text",
        Text: string(readFile(c.Name + ".out")),
      })
      req.Messages = append(req.Messages, asst)
      continue
    }
    // Pending input: deep-copy the conversation ending in this user turn,
    // then drop the turn so later inputs don't see it.
    cases[i].Request = copyRequest(req)
    req.Messages = req.Messages[:len(req.Messages)-1]
  }

  if *dry_run {
    log.Printf("Dry run; not sending requests.")
    return
  }

  // Send 'em.
  // TODO: parallelize
  sent, failed := 0, 0
  for _, c := range cases {
    if c.Request == nil {
      continue
    }
    sent++
    if *debug {
      log.Printf("Case %s: sending request:", c.Name)
      enc := json.NewEncoder(os.Stderr)
      enc.SetIndent("", "  ")
      if err := enc.Encode(c.Request); err != nil {
        log.Printf("Case %s: failed to encode request for debug output: %v", c.Name, err)
      }
    }
    res, err := client.Messages(*c.Request)
    if err != nil {
      log.Printf("Case %s: request failed: %+v", c.Name, err)
      failed++
      continue
    }
    if len(res.Content) == 0 {
      log.Printf("Case %s: empty response", c.Name)
      failed++
      continue
    }
    if len(res.Content) != 1 {
      log.Printf("Case %s: expected 1 content block, got %d", c.Name, len(res.Content))
      failed++
      continue
    }
    c0 := res.Content[0]
    if c0.Type != "text" {
      log.Printf("Case %s: Content[0].Type = %s", c.Name, c0.Type)
      failed++
      continue
    }
    err = writeFile(c.Name + ".out", append([]byte(c0.Text), '\n'))
    if err != nil {
      log.Printf("Case %s: failed to write %s.out: %+v", c.Name, c.Name, err)
      failed++
      continue
    }
    log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
  }

  // Every case has been attempted; report overall failure via exit status so
  // scripts can detect it.
  if failed > 0 {
    log.Fatalf("%d of %d requests failed", failed, sent)
  }
}