code.oscarkilo.com/okg/pr.go

.gitignore
README.md
auth.go
client.go
config.go
go.mod
main.go
okg_test.go
pr.go
repo.go
package main

import "encoding/json"
import "fmt"
import "os"
import "os/exec"
import "strconv"
import "text/tabwriter"
import "time"

// PR is a pull request as this CLI decodes it from the server's JSON
// responses (getJSON/postJSON/patchJSON below) and re-encodes it for
// --json output.
type PR struct {
  Number   int       `json:"number"`
  Title    string    `json:"title"`
  Body     string    `json:"body"`
  // State is the raw state string; "open" and "closed" are the values
  // this file sends when reopening/closing. runPRView shows "merged"
  // whenever Merged is true, regardless of State.
  State    string    `json:"state"`
  Merged   bool      `json:"merged"`
  // Head is the branch that "wants to merge" into Base.
  Head     string    `json:"head"`
  Base     string    `json:"base"`
  Author   string    `json:"author"`
  Created  time.Time `json:"created"`
  Updated  time.Time `json:"updated"`
  MergedBy string    `json:"merged_by,omitempty"`
  // NOTE(review): omitempty has no effect on struct-typed fields such as
  // time.Time, so an unmerged PR still encodes "merged_at" as
  // "0001-01-01T00:00:00Z". Use *time.Time, or `omitzero` (Go 1.24+),
  // if omission is intended.
  MergedAt time.Time `json:"merged_at,omitempty"`
}

// Comment is one comment on a PR, as decoded from the
// /REPO/pr/N/comments endpoint (list and create).
type Comment struct {
  ID      int       `json:"id"`
  Author  string    `json:"author"`
  Body    string    `json:"body"`
  // Verdict is empty for a plain comment; runPRComment sends either
  // "approve" or "request_changes".
  Verdict string    `json:"verdict,omitempty"`
  // File and Line are never set by this CLI; presumably they anchor an
  // inline review comment to a source location — TODO confirm.
  File    string    `json:"file,omitempty"`
  Line    int       `json:"line,omitempty"`
  Created time.Time `json:"created"`
}

// runPR dispatches `okg pr <subcommand> ...` to the matching handler,
// passing it every argument that follows the subcommand name.
func runPR(args []string) error {
  handlers := map[string]func([]string) error{
    "list":    runPRList,
    "create":  runPRCreate,
    "view":    runPRView,
    "diff":    runPRDiff,
    "comment": runPRComment,
    "merge":   runPRMerge,
    "close":   runPRClose,
    "reopen":  runPRReopen,
  }
  if len(args) == 0 {
    return fmt.Errorf(
      "usage: okg pr <list|create|view|diff|comment|merge|close|reopen>")
  }
  sub, subArgs := args[0], args[1:]
  handler, ok := handlers[sub]
  if !ok {
    return fmt.Errorf("unknown pr command: %s", sub)
  }
  return handler(subArgs)
}

// prFlags holds the options shared by every `okg pr` subcommand, as
// extracted from the command line by parsePRFlags.
type prFlags struct {
  // repo is the value of --repo, or "" when the flag was not given;
  // it is passed to resolveRepo via setupClient.
  repo   string
  // asJSON is true when --json was given.
  asJSON bool
}

// parsePRFlags pulls the shared `--repo VALUE` and `--json` options out
// of args wherever they appear, and returns them together with every
// other argument in its original order (nil if there are none). A
// repeated --repo keeps the last value. It fails only when --repo is
// the final argument and therefore has no value.
func parsePRFlags(args []string) (
  *prFlags, []string, error,
) {
  flags := new(prFlags)
  var positional []string
  for len(args) > 0 {
    arg := args[0]
    args = args[1:]
    switch arg {
    case "--json":
      flags.asJSON = true
    case "--repo":
      if len(args) == 0 {
        return nil, nil, fmt.Errorf(
          "--repo requires a value")
      }
      flags.repo, args = args[0], args[1:]
    default:
      positional = append(positional, arg)
    }
  }
  return flags, positional, nil
}

// setupClient loads the configuration via loadConfig, resolves the
// target repository from flagRepo (the --repo value, possibly empty)
// via resolveRepo, and returns a client built from that configuration
// together with the resolved repository name. The first error from
// either step is returned unchanged.
func setupClient(
  flagRepo string,
) (*Client, string, error) {
  cfg, cfgErr := loadConfig()
  if cfgErr != nil {
    return nil, "", cfgErr
  }
  repo, repoErr := resolveRepo(flagRepo)
  if repoErr != nil {
    return nil, "", repoErr
  }
  client := newClient(cfg)
  return client, repo, nil
}

// outputJSON writes v to standard output as JSON indented with two
// spaces and terminated by a newline. Nothing is written if v cannot
// be marshalled; the marshalling or write error is returned.
func outputJSON(v interface{}) error {
  data, err := json.MarshalIndent(v, "", "  ")
  if err != nil {
    return err
  }
  _, err = os.Stdout.Write(append(data, '\n'))
  return err
}

// age renders how long ago t was as a compact label. Anything under a
// minute (including timestamps in the future) is "just now"; otherwise
// it is whole minutes ("5m"), whole hours ("3h"), or whole days ("2d"),
// truncated rather than rounded.
func age(t time.Time) string {
  elapsed := time.Since(t)
  if elapsed < time.Minute {
    return "just now"
  }
  if elapsed < time.Hour {
    return strconv.Itoa(int(elapsed.Minutes())) + "m"
  }
  if elapsed < 24*time.Hour {
    return strconv.Itoa(int(elapsed.Hours())) + "h"
  }
  return strconv.Itoa(int(elapsed.Hours()/24)) + "d"
}

// --- pr list ---

// runPRList implements `okg pr list [--state STATE] [--repo R] [--json]`.
// It fetches the repository's PRs in the given state (default "open")
// and prints them as a table (#, title, author, head, base, age), or as
// indented JSON with --json.
//
// Any other argument is rejected with "unknown flag", matching
// `pr create` and `pr comment`; previously a typo such as
// `--stat closed` was silently ignored and open PRs were listed.
func runPRList(args []string) error {
  f, rest, err := parsePRFlags(args)
  if err != nil {
    return err
  }

  state := "open"
  for i := 0; i < len(rest); i++ {
    switch rest[i] {
    case "--state":
      i++
      if i >= len(rest) {
        return fmt.Errorf("--state requires a value")
      }
      state = rest[i]
    default:
      return fmt.Errorf("unknown flag: %s", rest[i])
    }
  }

  cl, repo, err := setupClient(f.repo)
  if err != nil {
    return err
  }

  // NOTE(review): state is interpolated without escaping, so a value
  // containing '&' or '#' changes the query; consider url.QueryEscape.
  path := fmt.Sprintf("/%s/prs?state=%s", repo, state)
  var prs []PR
  if err := cl.getJSON(path, &prs); err != nil {
    return err
  }

  if f.asJSON {
    return outputJSON(prs)
  }

  tw := tabwriter.NewWriter(
    os.Stdout, 0, 4, 2, ' ', 0)
  fmt.Fprintln(tw, "#\tTITLE\tAUTHOR\tHEAD\tBASE\tAGE")
  for _, p := range prs {
    fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%s\t%s\n",
      p.Number, p.Title, p.Author,
      p.Head, p.Base, age(p.Created))
  }
  return tw.Flush()
}

// --- pr view ---

// runPRView implements `okg pr view NUMBER [--repo R] [--json]`.
// It fetches the PR and its comments. With --json it prints both as
// {"comments": [...], "pr": {...}}; otherwise it prints a header
// (title, state, branches, creation and merge info), the PR body, and
// each comment with its author, verdict, file location (if any) and
// timestamp.
func runPRView(args []string) error {
  f, rest, err := parsePRFlags(args)
  if err != nil {
    return err
  }
  if len(rest) < 1 {
    return fmt.Errorf("usage: okg pr view NUMBER")
  }
  num, err := strconv.Atoi(rest[0])
  if err != nil {
    return fmt.Errorf("invalid PR number: %s", rest[0])
  }

  cl, repo, err := setupClient(f.repo)
  if err != nil {
    return err
  }

  var p PR
  path := fmt.Sprintf("/%s/pr/%d", repo, num)
  if err := cl.getJSON(path, &p); err != nil {
    return err
  }

  var comments []Comment
  cpath := fmt.Sprintf("/%s/pr/%d/comments", repo, num)
  if err := cl.getJSON(cpath, &comments); err != nil {
    return err
  }

  if f.asJSON {
    return outputJSON(map[string]interface{}{
      "pr":       p,
      "comments": comments,
    })
  }

  // Header. A merged PR is labelled "merged" instead of its raw State.
  stateStr := p.State
  if p.Merged {
    stateStr = "merged"
  }
  fmt.Printf("#%d %s (%s)\n", p.Number, p.Title, stateStr)
  fmt.Printf("  %s wants to merge %s into %s\n",
    p.Author, p.Head, p.Base)
  fmt.Printf("  Created %s\n", p.Created.Format(time.RFC3339))
  // MergedBy/MergedAt were fetched but never shown; MergedAt is only
  // printed when set, since an unmerged PR carries the zero time.
  if p.Merged && p.MergedBy != "" {
    fmt.Printf("  Merged by @%s", p.MergedBy)
    if !p.MergedAt.IsZero() {
      fmt.Printf(" at %s", p.MergedAt.Format(time.RFC3339))
    }
    fmt.Println()
  }
  if p.Body != "" {
    fmt.Printf("\n%s\n", p.Body)
  }

  // Comments. A comment with File (and optionally Line) is anchored to
  // a location; show it, since the remark may be meaningless without it.
  if len(comments) > 0 {
    fmt.Print("\n--- Comments ---\n")
    for _, c := range comments {
      verdict := ""
      if c.Verdict != "" {
        verdict = fmt.Sprintf(" [%s]", c.Verdict)
      }
      loc := ""
      if c.File != "" {
        loc = " on " + c.File
        if c.Line > 0 {
          loc += fmt.Sprintf(":%d", c.Line)
        }
      }
      fmt.Printf("\n@%s%s%s (%s):\n%s\n",
        c.Author, verdict, loc,
        c.Created.Format(time.RFC3339), c.Body)
    }
  }
  return nil
}

// --- pr diff ---

// runPRDiff implements `okg pr diff NUMBER [--repo R]`. It fetches the
// PR to learn its base and head branch names, then runs
// `git diff BASE...HEAD --` in the current directory, streaming git's
// output and errors to this process's stdout/stderr. Because git
// resolves the names locally, the branches presumably must exist in
// the local clone. --json is accepted but has no effect here.
func runPRDiff(args []string) error {
  f, rest, err := parsePRFlags(args)
  if err != nil {
    return err
  }
  if len(rest) < 1 {
    return fmt.Errorf("usage: okg pr diff NUMBER")
  }
  num, err := strconv.Atoi(rest[0])
  if err != nil {
    return fmt.Errorf("invalid PR number: %s", rest[0])
  }

  cl, repo, err := setupClient(f.repo)
  if err != nil {
    return err
  }

  var p PR
  path := fmt.Sprintf("/%s/pr/%d", repo, num)
  if err := cl.getJSON(path, &p); err != nil {
    return err
  }

  // Base and Head are server-supplied strings handed straight to git.
  // Reject a name starting with '-' (git would parse e.g.
  // "--output=FILE...x" as an option) and an empty name (git silently
  // substitutes HEAD for an omitted side of A...B, giving a misleading
  // diff). git refuses to create branch names beginning with '-', so
  // no genuine PR is refused.
  for _, ref := range []struct{ role, name string }{
    {"base", p.Base},
    {"head", p.Head},
  } {
    if ref.name == "" {
      return fmt.Errorf("PR #%d has no %s branch", num, ref.role)
    }
    if ref.name[0] == '-' {
      return fmt.Errorf(
        "PR #%d %s branch %q looks like a git option; refusing to diff",
        num, ref.role, ref.name)
    }
  }

  // The trailing "--" makes git treat the range strictly as a revision,
  // never as a path name.
  cmd := exec.Command(
    "git", "diff", p.Base+"..."+p.Head, "--")
  cmd.Stdout = os.Stdout
  cmd.Stderr = os.Stderr
  return cmd.Run()
}

// --- pr create ---

// runPRCreate implements `okg pr create --head BRANCH --title TITLE
// [--base BRANCH] [--body TEXT] [--repo R] [--json]`. --base defaults
// to "master"; --head and --title are mandatory. Every value flag
// consumes the next argument, and any other argument is rejected. On
// success it prints the PR returned by the server (or that PR as JSON
// with --json).
func runPRCreate(args []string) error {
  f, rest, err := parsePRFlags(args)
  if err != nil {
    return err
  }

  var head, title, body string
  base := "master"
  valueFlags := map[string]*string{
    "--head":  &head,
    "--base":  &base,
    "--title": &title,
    "--body":  &body,
  }
  for i := 0; i < len(rest); i++ {
    dst, known := valueFlags[rest[i]]
    if !known {
      return fmt.Errorf("unknown flag: %s", rest[i])
    }
    if i+1 >= len(rest) {
      return fmt.Errorf("%s requires a value", rest[i])
    }
    i++
    *dst = rest[i]
  }

  switch {
  case head == "":
    return fmt.Errorf("--head is required")
  case title == "":
    return fmt.Errorf("--title is required")
  }

  cl, repo, err := setupClient(f.repo)
  if err != nil {
    return err
  }

  var created PR
  endpoint := fmt.Sprintf("/%s/prs", repo)
  payload := map[string]string{
    "head": head, "base": base, "title": title, "body": body,
  }
  if err := cl.postJSON(endpoint, payload, &created); err != nil {
    return err
  }

  if f.asJSON {
    return outputJSON(created)
  }
  fmt.Printf("Created PR #%d: %s\n", created.Number, created.Title)
  fmt.Printf("  %s -> %s\n", created.Head, created.Base)
  return nil
}

// --- pr comment ---

// runPRComment implements `okg pr comment NUMBER --body BODY
// [--approve | --request-changes] [--repo R] [--json]`. The verdict is
// sent as "approve", "request_changes", or "" (plain comment); if both
// verdict flags are given, the later one wins. It prints the comment
// the server returns, or that comment as JSON with --json.
func runPRComment(args []string) error {
  f, rest, err := parsePRFlags(args)
  if err != nil {
    return err
  }
  if len(rest) == 0 {
    return fmt.Errorf(
      "usage: okg pr comment NUMBER --body BODY")
  }
  num, err := strconv.Atoi(rest[0])
  if err != nil {
    return fmt.Errorf("invalid PR number: %s", rest[0])
  }

  var body, verdict string
  opts := rest[1:]
  for len(opts) > 0 {
    flag := opts[0]
    opts = opts[1:]
    switch flag {
    case "--approve":
      verdict = "approve"
    case "--request-changes":
      verdict = "request_changes"
    case "--body":
      if len(opts) == 0 {
        return fmt.Errorf("--body requires a value")
      }
      body, opts = opts[0], opts[1:]
    default:
      return fmt.Errorf("unknown flag: %s", flag)
    }
  }
  if body == "" {
    return fmt.Errorf("--body is required")
  }

  cl, repo, err := setupClient(f.repo)
  if err != nil {
    return err
  }

  var posted Comment
  endpoint := fmt.Sprintf("/%s/pr/%d/comments", repo, num)
  payload := map[string]string{"body": body, "verdict": verdict}
  if err := cl.postJSON(endpoint, payload, &posted); err != nil {
    return err
  }

  if f.asJSON {
    return outputJSON(posted)
  }
  tag := ""
  if posted.Verdict != "" {
    tag = fmt.Sprintf(" [%s]", posted.Verdict)
  }
  fmt.Printf("Comment #%d by @%s%s:\n%s\n",
    posted.ID, posted.Author, tag, posted.Body)
  return nil
}

// --- pr merge ---

// runPRMerge implements `okg pr merge NUMBER [--repo R] [--json]`. It
// posts to the PR's /merge endpoint with a nil payload and reports who
// merged it, or prints the returned PR as JSON with --json.
func runPRMerge(args []string) error {
  flags, positional, err := parsePRFlags(args)
  if err != nil {
    return err
  }
  if len(positional) == 0 {
    return fmt.Errorf("usage: okg pr merge NUMBER")
  }
  num, convErr := strconv.Atoi(positional[0])
  if convErr != nil {
    return fmt.Errorf("invalid PR number: %s", positional[0])
  }

  client, repo, err := setupClient(flags.repo)
  if err != nil {
    return err
  }

  var merged PR
  if err := client.postJSON(
    fmt.Sprintf("/%s/pr/%d/merge", repo, num), nil, &merged,
  ); err != nil {
    return err
  }

  if flags.asJSON {
    return outputJSON(merged)
  }
  fmt.Printf("PR #%d merged by @%s\n", merged.Number, merged.MergedBy)
  return nil
}

// --- pr close ---

// runPRClose implements `okg pr close NUMBER` by asking
// runPRStateChange to set the PR's state to "closed".
func runPRClose(args []string) error {
  const target = "closed"
  return runPRStateChange(target, args)
}

// --- pr reopen ---

// runPRReopen implements `okg pr reopen NUMBER` by asking
// runPRStateChange to set the PR's state back to "open".
func runPRReopen(args []string) error {
  const target = "open"
  return runPRStateChange(target, args)
}

// runPRStateChange backs `okg pr close` and `okg pr reopen`. It sends
// {"state": newState} for PR NUMBER via patchJSON and prints the state
// the server reports back, or the returned PR as JSON with --json.
func runPRStateChange(
  newState string, args []string,
) error {
  flags, positional, err := parsePRFlags(args)
  if err != nil {
    return err
  }
  if len(positional) == 0 {
    return fmt.Errorf(
      "usage: okg pr close|reopen NUMBER")
  }
  num, err := strconv.Atoi(positional[0])
  if err != nil {
    return fmt.Errorf("invalid PR number: %s", positional[0])
  }

  client, repo, err := setupClient(flags.repo)
  if err != nil {
    return err
  }

  var updated PR
  endpoint := fmt.Sprintf("/%s/pr/%d", repo, num)
  if err := client.patchJSON(
    endpoint, map[string]string{"state": newState}, &updated,
  ); err != nil {
    return err
  }

  if flags.asJSON {
    return outputJSON(updated)
  }
  fmt.Printf("PR #%d is now %s\n", updated.Number, updated.State)
  return nil
}