code.oscarkilo.com/okg/exemplary.go

.gitignore
README.md
auth.go
authz.go
chat.go
chat/
client.go
config.go
embed.go
exemplary.go
go.mod
go.sum
group.go
internal/
klee/
klex.go
main.go
okg_test.go
one.go
pr.go
repo.go
who/
package main

import "encoding/json"
import "flag"
import "fmt"
import "log"
import "os"
import "path"
import "sort"
import "strings"

import "oscarkilo.com/klex-git/api"

// runExemplary runs few-shot LLM inference on a directory full
// of "<case>.<ext>" input files. Cases that have a <case>.out
// file are treated as examples (their inputs become a user
// message and the .out contents the assistant reply); every
// other case is sent to the LLM and the response is written
// back as <case>.out.
//
// args holds the subcommand's flags: -dir, -model, -dry-run
// and -debug. The directory must contain system_prompt.txt,
// which becomes the request's System field.
//
// An error is returned for flag, configuration and file-reading
// problems found while the requests are being built. Once
// sending starts, a failure on one case is logged and the
// remaining cases are still attempted, so a nil return does not
// guarantee that every case produced an output file.
//
// Mirrors the (now-deprecated) `klex-git/exemplary` binary.
func runExemplary(cfg *Config, args []string) error {
  fs := flag.NewFlagSet("exemplary", flag.ContinueOnError)
  dir := fs.String("dir", ".",
    "directory to scan for cases and write outputs to")
  model := fs.String("model", "Gemini 3 Pro",
    "LLM model name")
  dryRun := fs.Bool("dry-run", false,
    "scan + build requests but don't send them")
  debug := fs.Bool("debug", false,
    "log each request body to stderr before sending")
  if err := fs.Parse(args); err != nil {
    return err
  }

  if cfg.ApiKey == "" {
    return fmt.Errorf(
      "no API key — run `okg auth login --key sk-...`")
  }
  client := newKlexClient(cfg)

  // readFile loads one file from the case directory and labels
  // a failure with the file's name, keeping the cause wrapped.
  readFile := func(name string) ([]byte, error) {
    b, err := os.ReadFile(path.Join(*dir, name))
    if err != nil {
      return nil, fmt.Errorf("read %s: %w", name, err)
    }
    return b, nil
  }

  // Build a request that starts with the system prompt and
  // grows by one user message per case (plus an assistant
  // message for examples). Inputs (cases without .out) get a
  // snapshot of the request taken just before their user
  // message would otherwise be added to the running prefix.
  sysBytes, err := readFile("system_prompt.txt")
  if err != nil {
    return err
  }
  req := api.MessagesRequest{
    Model:  *model,
    System: string(sysBytes),
  }

  cases, err := scanForCases(*dir)
  if err != nil {
    return err
  }

  for i, c := range cases {
    user := api.ChatMessage{Role: "user"}
    var text strings.Builder
    text.WriteString("Case name: " + c.Name + "\n\n")
    for _, suffix := range c.Before {
      switch suffix {
      case "txt":
        b, err := readFile(c.Name + ".txt")
        if err != nil {
          return err
        }
        text.Write(b)
      case "json":
        b, err := readFile(c.Name + ".json")
        if err != nil {
          return err
        }
        text.WriteString("\n\n```json\n")
        text.Write(b)
        text.WriteString("\n```\n")
      case "jpg", "jpeg", "png", "webp":
        b, err := readFile(c.Name + "." + suffix)
        if err != nil {
          return err
        }
        user.Content = append(user.Content,
          api.NewDocumentBlock(b))
      default:
        return fmt.Errorf(
          "unsupported suffix %s in case %s",
          suffix, c.Name)
      }
    }
    // The text block goes last, after any document blocks.
    user.Content = append(user.Content, api.ContentBlock{
      Type: "text",
      Text: text.String(),
    })
    req.Messages = append(req.Messages, user)

    if c.After != "" {
      // Example: keep the user message in the running prefix
      // and follow it with the recorded answer.
      out, err := readFile(c.Name + ".out")
      if err != nil {
        return err
      }
      req.Messages = append(req.Messages,
        api.ChatMessage{
          Role: "assistant",
          Content: []api.ContentBlock{
            {Type: "text", Text: string(out)},
          },
        })
      continue
    }

    // Input: snapshot the prefix plus this user message, then
    // drop the user message again so it does not leak into the
    // requests built for later inputs.
    snapshot, err := copyRequest(req)
    if err != nil {
      return err
    }
    cases[i].Request = snapshot
    req.Messages = req.Messages[:len(req.Messages)-1]
  }

  if *dryRun {
    log.Printf("dry run; not sending requests")
    return nil
  }

  // TODO: parallelize.
  for _, c := range cases {
    if c.Request == nil {
      // Examples carry no request of their own.
      continue
    }
    if *debug {
      log.Printf("Case %s: sending request:", c.Name)
      enc := json.NewEncoder(os.Stderr)
      enc.SetIndent("", "  ")
      if err := enc.Encode(c.Request); err != nil {
        log.Printf(
          "Case %s: failed to dump request: %v",
          c.Name, err)
      }
    }
    res, err := client.Messages(*c.Request)
    if err != nil {
      log.Printf(
        "Case %s: request failed: %v", c.Name, err)
      continue
    }
    // Exactly one text block is expected; anything else is
    // skipped rather than guessed at.
    if n := len(res.Content); n != 1 {
      if n == 0 {
        log.Printf(
          "Case %s: empty response", c.Name)
      } else {
        log.Printf(
          "Case %s: expected 1 content block, got %d",
          c.Name, n)
      }
      continue
    }
    c0 := res.Content[0]
    if c0.Type != "text" {
      log.Printf(
        "Case %s: Content[0].Type = %s",
        c.Name, c0.Type)
      continue
    }
    out := path.Join(*dir, c.Name+".out")
    err = os.WriteFile(
      out, append([]byte(c0.Text), '\n'), 0644)
    if err != nil {
      log.Printf(
        "Case %s: failed to write %s.out: %v",
        c.Name, c.Name, err)
      continue
    }
    log.Printf("Case %s: wrote %s.out", c.Name, c.Name)
  }
  return nil
}

// Case is one group of files in the directory exemplary scans
// that share a base name: one or more input files, plus an
// optional .out file that turns the group into an example.
type Case struct {
  // Name is the shared file name with its last extension
  // removed.
  Name    string
  // Before lists the input extensions found for Name (txt,
  // json, jpg, jpeg, png, webp), "txt" first and the rest in
  // alphabetical order.
  Before  []string
  // After is "out" when a paired .out file exists and ""
  // otherwise.
  After   string
  // Request is the request snapshot runExemplary builds for
  // cases that have no .out file; it stays nil for examples.
  Request *api.MessagesRequest
}

// scanForCases reads dir (without descending into
// subdirectories) and groups its files into Cases, keyed by
// file name minus the last extension.
//
// Files ending in txt, json, jpg, jpeg, png or webp are inputs;
// a file ending in out marks its case as an example. Skipped
// entirely: directories, system_prompt.txt (consumed separately
// as the request's System field), files with no extension,
// files that are nothing but an extension (such as ".txt"), and
// files with any other extension. A .out file with no matching
// input file produces no Case.
//
// Cases with an .out file sort before cases without, and each
// group is ordered by name, so the request-building loop sees
// every example before the first input. A summary of the counts
// is logged before returning.
//
// The returned error wraps the os.ReadDir failure, if any.
func scanForCases(dir string) ([]Case, error) {
  entries, err := os.ReadDir(dir)
  if err != nil {
    return nil, fmt.Errorf("read dir %s: %w", dir, err)
  }

  before := make(map[string][]string)
  after := make(map[string]string)
  for _, entry := range entries {
    if entry.IsDir() || entry.Name() == "system_prompt.txt" {
      continue
    }
    file := entry.Name()
    dot := strings.LastIndex(file, ".")
    if dot <= 0 {
      // dot < 0: no extension at all. dot == 0: the whole
      // name is an extension, which would otherwise yield a
      // case with an empty name.
      continue
    }
    name, suffix := file[:dot], file[dot+1:]
    switch suffix {
    case "txt", "json", "jpg", "jpeg", "png", "webp":
      before[name] = append(before[name], suffix)
    case "out":
      after[name] = suffix
    }
  }

  var cases []Case
  for name, exts := range before {
    // "txt" goes first; the other extensions follow in
    // alphabetical order.
    sort.Slice(exts, func(i, j int) bool {
      iTxt, jTxt := exts[i] == "txt", exts[j] == "txt"
      if iTxt != jTxt {
        return iTxt
      }
      return exts[i] < exts[j]
    })
    cases = append(cases, Case{
      Name:   name,
      Before: exts,
      After:  after[name],
    })
  }

  // Map iteration order is random, so impose a total order:
  // examples first, then by name within each group.
  sort.Slice(cases, func(i, j int) bool {
    iExample := cases[i].After != ""
    jExample := cases[j].After != ""
    if iExample != jExample {
      return iExample
    }
    return cases[i].Name < cases[j].Name
  })

  // Examples form a prefix of the sorted slice, so counting up
  // to the first input gives the number of examples.
  numExamples := 0
  for numExamples < len(cases) &&
    cases[numExamples].After != "" {
    numExamples++
  }
  log.Printf(
    "%s:\n  num_examples = %d\n  num_inputs   = %d",
    dir, numExamples, len(cases)-numExamples)

  return cases, nil
}

// copyRequest deep-copies req via JSON marshal/unmarshal so
// each snapshot taken inside the build loop is independent of
// later mutations of the original.
//
// It returns a pointer to the fresh copy, or an error wrapping
// the encoding/json failure if either step fails.
//
// NOTE(review): only state that survives a JSON round trip is
// copied; confirm api.MessagesRequest has no fields that JSON
// encoding would drop.
func copyRequest(
  req api.MessagesRequest,
) (*api.MessagesRequest, error) {
  buf, err := json.Marshal(req)
  if err != nil {
    return nil, fmt.Errorf("copy request: %w", err)
  }
  out := &api.MessagesRequest{}
  if err := json.Unmarshal(buf, out); err != nil {
    return nil, fmt.Errorf("copy request: %w", err)
  }
  return out, nil
}