package main
import "encoding/json"
import "fmt"
import "net/url"
import "os"
import "os/exec"
import "strconv"
import "text/tabwriter"
import "time"
// PR is the JSON shape of a pull request. The pr subcommands decode server
// responses into it and, when --json is given, encode it back to stdout
// unchanged.
type PR struct {
// Number is printed as "#<Number>" in listings and status messages.
Number int `json:"number"`
Title string `json:"title"`
// Body is the free-form description; pr view prints it only when non-empty.
Body string `json:"body"`
// State is the state string reported by the server. This file sends
// "open" and "closed" when changing it, and "open" as the default list filter.
State string `json:"state"`
// Merged, when true, makes pr view display "merged" instead of State.
Merged bool `json:"merged"`
// Head and Base are the ref names of the PR: pr view prints
// "wants to merge Head into Base" and pr diff runs git diff Base...Head.
Head string `json:"head"`
Base string `json:"base"`
Author string `json:"author"`
// Created drives the AGE column of pr list and the "Created" line of pr view.
Created time.Time `json:"created"`
// Updated is only decoded and re-encoded here; presumably the
// last-modification time — confirm against the server API.
Updated time.Time `json:"updated"`
// MergedBy is printed by pr merge as "merged by @<MergedBy>".
MergedBy string `json:"merged_by,omitempty"`
// NOTE(review): encoding/json's omitempty never treats a struct value as
// empty, so a zero MergedAt is still written out when a PR is encoded
// (for example by --json). Confirm whether that is intended.
MergedAt time.Time `json:"merged_at,omitempty"`
}
// Comment is the JSON shape of a single comment on a pull request, as
// returned by the comments endpoints used by pr view and pr comment.
type Comment struct {
// ID is printed as "Comment #<ID>" after a comment has been posted.
ID int `json:"id"`
Author string `json:"author"`
Body string `json:"body"`
// Verdict is an optional review verdict. pr comment sends "approve" or
// "request_changes" (or an empty string for a plain comment); when
// non-empty it is displayed in square brackets next to the author.
Verdict string `json:"verdict,omitempty"`
// File and Line are not read anywhere in this file; presumably they
// locate an inline review comment — confirm against the server API.
File string `json:"file,omitempty"`
Line int `json:"line,omitempty"`
// Created is printed in RFC 3339 form by pr view.
Created time.Time `json:"created"`
}
// runPR is the entry point for "okg pr <subcommand> ...".
//
// args[0] names the subcommand; everything after it is handed to that
// subcommand's handler untouched. A usage error is returned when args is
// empty and an "unknown pr command" error when the name is not recognised;
// otherwise the handler's own result is returned.
func runPR(args []string) error {
	if len(args) == 0 {
		return fmt.Errorf(
			"usage: okg pr <list|create|view|diff" +
				"|comment|merge|close|reopen>")
	}

	// Subcommand name -> handler. Every handler takes the arguments that
	// follow the subcommand name.
	handlers := map[string]func([]string) error{
		"list":    runPRList,
		"create":  runPRCreate,
		"view":    runPRView,
		"diff":    runPRDiff,
		"comment": runPRComment,
		"merge":   runPRMerge,
		"close":   runPRClose,
		"reopen":  runPRReopen,
	}

	name, tail := args[0], args[1:]
	handler, known := handlers[name]
	if !known {
		return fmt.Errorf("unknown pr command: %s", name)
	}
	return handler(tail)
}
// prFlags holds the options shared by every pr subcommand.
type prFlags struct {
	repo   string // value given to --repo; empty when the flag is absent
	asJSON bool   // true when --json was present
}

// parsePRFlags pulls the shared --repo VALUE and --json options out of args,
// wherever they appear, and returns them together with the remaining
// arguments in their original order.
//
// When --repo is given more than once the last value wins. If --repo is the
// final argument and therefore has no value, an error is returned and both
// other results are nil. The caller's slice is not modified.
func parsePRFlags(args []string) (*prFlags, []string, error) {
	var (
		flags      prFlags
		positional []string
	)

	// Consume the arguments from the front; --repo takes one extra element.
	for len(args) > 0 {
		arg := args[0]
		args = args[1:]

		if arg == "--json" {
			flags.asJSON = true
			continue
		}
		if arg != "--repo" {
			positional = append(positional, arg)
			continue
		}
		if len(args) == 0 {
			return nil, nil, fmt.Errorf(
				"--repo requires a value")
		}
		flags.repo = args[0]
		args = args[1:]
	}
	return &flags, positional, nil
}
// setupClient prepares what a pr subcommand needs to talk to the server: a
// client built from the loaded configuration and the repository to operate
// on.
//
// flagRepo is the value of --repo (possibly empty) and is handed to
// resolveRepo, which decides the repository to use. The configuration is
// loaded first, so a configuration error is reported in preference to a
// repository-resolution error. On any error the client is nil and the
// repository is "".
func setupClient(flagRepo string) (*Client, string, error) {
	cfg, cfgErr := loadConfig()
	if cfgErr != nil {
		return nil, "", cfgErr
	}

	repo, repoErr := resolveRepo(flagRepo)
	if repoErr != nil {
		return nil, "", repoErr
	}

	client := newClient(cfg)
	return client, repo, nil
}
// outputJSON writes v to standard output as indented JSON followed by a
// newline.
//
// The value is marshalled in full before anything is written, so an
// encoding error produces no output. The returned error is either the
// encoding error or the error from writing to os.Stdout.
func outputJSON(v interface{}) error {
	data, err := json.MarshalIndent(v, "", " ")
	if err != nil {
		return err
	}
	// Emit document and trailing newline in a single write.
	_, err = os.Stdout.Write(append(data, '\n'))
	return err
}
// age renders how long ago t was as a short label measured against the
// current time: "just now" for anything under a minute (which includes
// instants in the future), then whole minutes ("5m"), whole hours ("3h") and
// whole days ("12d"). Partial units are truncated, not rounded.
func age(t time.Time) string {
	elapsed := time.Since(t)

	if elapsed < time.Minute {
		return "just now"
	}
	if elapsed < time.Hour {
		return fmt.Sprintf("%dm", int(elapsed.Minutes()))
	}
	if elapsed < 24*time.Hour {
		return fmt.Sprintf("%dh", int(elapsed.Hours()))
	}
	return fmt.Sprintf("%dd", int(elapsed.Hours()/24))
}
// --- pr list ---

// runPRList implements "okg pr list [--state STATE] [--repo REPO] [--json]".
//
// It requests the pull requests of the resolved repository filtered by STATE
// (default "open") and prints them either as indented JSON (--json) or as an
// aligned table with the columns #, TITLE, AUTHOR, HEAD, BASE and AGE.
// Arguments other than the recognised flags are ignored.
//
// STATE comes straight from the command line, so it is query-escaped before
// being placed in the request URL; characters such as '&', '#' or spaces
// therefore cannot change the shape of the query string.
//
// It returns an error when a flag is missing its value, when the client
// cannot be set up, when the request fails, or when writing the output
// fails.
func runPRList(args []string) error {
	f, rest, err := parsePRFlags(args)
	if err != nil {
		return err
	}

	state := "open"
	for i := 0; i < len(rest); i++ {
		if rest[i] == "--state" {
			i++
			if i >= len(rest) {
				return fmt.Errorf("--state requires a value")
			}
			state = rest[i]
		}
	}

	cl, repo, err := setupClient(f.repo)
	if err != nil {
		return err
	}

	path := fmt.Sprintf(
		"/%s/prs?state=%s", repo, url.QueryEscape(state))
	var prs []PR
	if err := cl.getJSON(path, &prs); err != nil {
		return err
	}
	if f.asJSON {
		return outputJSON(prs)
	}

	// The tabwriter buffers everything; write errors surface from Flush.
	tw := tabwriter.NewWriter(
		os.Stdout, 0, 4, 2, ' ', 0)
	fmt.Fprintln(tw, "#\tTITLE\tAUTHOR\tHEAD\tBASE\tAGE")
	for _, p := range prs {
		fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%s\t%s\n",
			p.Number, p.Title, p.Author,
			p.Head, p.Base, age(p.Created))
	}
	return tw.Flush()
}
// --- pr view ---

// runPRView implements "okg pr view NUMBER [--repo REPO] [--json]".
//
// It fetches the pull request and then its comments. With --json both are
// printed as one JSON object with the keys "pr" and "comments". Otherwise a
// text summary is printed: a header line with number, title and state
// ("merged" replaces the state when the PR is merged), the author and
// branches, the creation time in RFC 3339 form, the body if there is one,
// and finally each comment with its author, optional verdict and timestamp.
//
// Only the first positional argument is read; anything after NUMBER is
// ignored. Errors are returned for a missing or non-numeric NUMBER and for
// client setup or request failures.
func runPRView(args []string) error {
	flags, positional, err := parsePRFlags(args)
	if err != nil {
		return err
	}
	if len(positional) == 0 {
		return fmt.Errorf("usage: okg pr view NUMBER")
	}
	rawNum := positional[0]
	number, err := strconv.Atoi(rawNum)
	if err != nil {
		return fmt.Errorf("invalid PR number: %s", rawNum)
	}

	client, repo, err := setupClient(flags.repo)
	if err != nil {
		return err
	}

	// The comments endpoint lives directly below the PR endpoint.
	prPath := fmt.Sprintf("/%s/pr/%d", repo, number)
	var pr PR
	if err := client.getJSON(prPath, &pr); err != nil {
		return err
	}
	var thread []Comment
	if err := client.getJSON(prPath+"/comments", &thread); err != nil {
		return err
	}

	if flags.asJSON {
		combined := map[string]interface{}{
			"pr":       pr,
			"comments": thread,
		}
		return outputJSON(combined)
	}

	// Header: a merged PR is labelled "merged" regardless of its state.
	label := pr.State
	if pr.Merged {
		label = "merged"
	}
	fmt.Printf("#%d %s (%s)\n", pr.Number, pr.Title, label)
	fmt.Printf(" %s wants to merge %s into %s\n",
		pr.Author, pr.Head, pr.Base)
	fmt.Printf(" Created %s\n", pr.Created.Format(time.RFC3339))
	if pr.Body != "" {
		fmt.Printf("\n%s\n", pr.Body)
	}

	// Comments section is omitted entirely when there are none.
	if len(thread) == 0 {
		return nil
	}
	fmt.Printf("\n--- Comments ---\n")
	for _, c := range thread {
		var tag string
		if c.Verdict != "" {
			tag = fmt.Sprintf(" [%s]", c.Verdict)
		}
		posted := c.Created.Format(time.RFC3339)
		fmt.Printf("\n@%s%s (%s):\n%s\n",
			c.Author, tag, posted, c.Body)
	}
	return nil
}
// --- pr diff ---

// runPRDiff implements "okg pr diff NUMBER [--repo REPO]".
//
// It fetches the pull request to learn its base and head ref names and then
// runs "git diff BASE...HEAD --" in the current working directory, wiring
// git's stdout and stderr to this process's. Nothing is fetched from a
// remote here, so both refs must already be resolvable locally. --json is
// accepted by the shared flag parser but has no effect on this command.
//
// The ref names come from the server response and end up on git's command
// line. An empty name, or one starting with "-" (which git would parse as an
// option rather than a revision), is rejected before git is started, and the
// trailing "--" makes git treat the range strictly as a revision, never as a
// path.
//
// It returns an error for a missing or non-numeric NUMBER, for client setup
// or request failures, for an unusable ref name, and whatever cmd.Run
// reports (including a non-zero git exit status).
func runPRDiff(args []string) error {
	f, rest, err := parsePRFlags(args)
	if err != nil {
		return err
	}
	if len(rest) < 1 {
		return fmt.Errorf("usage: okg pr diff NUMBER")
	}
	num, err := strconv.Atoi(rest[0])
	if err != nil {
		return fmt.Errorf("invalid PR number: %s", rest[0])
	}

	cl, repo, err := setupClient(f.repo)
	if err != nil {
		return err
	}

	var p PR
	path := fmt.Sprintf("/%s/pr/%d", repo, num)
	if err := cl.getJSON(path, &p); err != nil {
		return err
	}

	// Never let a server-supplied ref name be read as a git option.
	for _, ref := range []string{p.Base, p.Head} {
		if ref == "" || ref[0] == '-' {
			return fmt.Errorf(
				"cannot diff PR #%d: invalid ref name %q", num, ref)
		}
	}

	// Run git diff locally.
	cmd := exec.Command(
		"git", "diff", p.Base+"..."+p.Head, "--")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	return cmd.Run()
}
// --- pr create ---

// runPRCreate implements
// "okg pr create --head HEAD --title TITLE [--base BASE] [--body BODY]
// [--repo REPO] [--json]".
//
// --head and --title are mandatory; --base defaults to "master" and --body
// to the empty string. Any other argument is rejected as an unknown flag.
// The new PR is created by posting head, base, title and body to the
// repository's prs endpoint, and the PR returned by the server is printed
// either as JSON (--json) or as a two-line confirmation.
func runPRCreate(args []string) error {
	f, rest, err := parsePRFlags(args)
	if err != nil {
		return err
	}

	var head, title, body string
	base := "master"

	// Every flag of this command takes exactly one value; map each flag to
	// the variable that receives it.
	targets := map[string]*string{
		"--head":  &head,
		"--base":  &base,
		"--title": &title,
		"--body":  &body,
	}
	for i := 0; i < len(rest); i++ {
		flag := rest[i]
		dst, known := targets[flag]
		if !known {
			return fmt.Errorf("unknown flag: %s", flag)
		}
		i++
		if i >= len(rest) {
			return fmt.Errorf("%s requires a value", flag)
		}
		*dst = rest[i]
	}

	switch {
	case head == "":
		return fmt.Errorf("--head is required")
	case title == "":
		return fmt.Errorf("--title is required")
	}

	cl, repo, err := setupClient(f.repo)
	if err != nil {
		return err
	}

	request := map[string]string{
		"head":  head,
		"base":  base,
		"title": title,
		"body":  body,
	}
	var created PR
	endpoint := fmt.Sprintf("/%s/prs", repo)
	if err := cl.postJSON(endpoint, request, &created); err != nil {
		return err
	}

	if f.asJSON {
		return outputJSON(created)
	}
	fmt.Printf("Created PR #%d: %s\n", created.Number, created.Title)
	fmt.Printf(" %s -> %s\n", created.Head, created.Base)
	return nil
}
// --- pr comment ---

// runPRComment implements
// "okg pr comment NUMBER --body BODY [--approve|--request-changes]
// [--repo REPO] [--json]".
//
// NUMBER must be the first positional argument. --body is mandatory, even
// when a verdict flag is given. --approve and --request-changes set the
// verdict to "approve" and "request_changes" respectively; if both appear
// the later one wins. Any other argument is rejected as an unknown flag.
// The verdict is always sent, as an empty string when no verdict flag was
// used. The comment returned by the server is printed as JSON (--json) or as
// a short confirmation followed by the comment body.
func runPRComment(args []string) error {
	f, rest, err := parsePRFlags(args)
	if err != nil {
		return err
	}
	if len(rest) == 0 {
		return fmt.Errorf(
			"usage: okg pr comment NUMBER --body BODY")
	}
	rawNum, opts := rest[0], rest[1:]
	num, err := strconv.Atoi(rawNum)
	if err != nil {
		return fmt.Errorf("invalid PR number: %s", rawNum)
	}

	var text, verdict string
	// Consume the options from the front; --body takes one extra element.
	for len(opts) > 0 {
		opt := opts[0]
		opts = opts[1:]
		switch opt {
		case "--approve":
			verdict = "approve"
		case "--request-changes":
			verdict = "request_changes"
		case "--body":
			if len(opts) == 0 {
				return fmt.Errorf("--body requires a value")
			}
			text, opts = opts[0], opts[1:]
		default:
			return fmt.Errorf("unknown flag: %s", opt)
		}
	}
	if text == "" {
		return fmt.Errorf("--body is required")
	}

	cl, repo, err := setupClient(f.repo)
	if err != nil {
		return err
	}

	request := map[string]string{
		"body":    text,
		"verdict": verdict,
	}
	var posted Comment
	endpoint := fmt.Sprintf("/%s/pr/%d/comments", repo, num)
	if err := cl.postJSON(endpoint, request, &posted); err != nil {
		return err
	}

	if f.asJSON {
		return outputJSON(posted)
	}
	fmt.Printf("Comment #%d by @%s", posted.ID, posted.Author)
	if posted.Verdict != "" {
		fmt.Printf(" [%s]", posted.Verdict)
	}
	fmt.Printf(":\n%s\n", posted.Body)
	return nil
}
// --- pr merge ---

// runPRMerge implements "okg pr merge NUMBER [--repo REPO] [--json]".
//
// It posts an empty (nil) payload to the PR's merge endpoint and prints the
// PR returned by the server, either as JSON (--json) or as a one-line
// "merged by" confirmation. Arguments after NUMBER are ignored. Errors are
// returned for a missing or non-numeric NUMBER and for client setup or
// request failures.
func runPRMerge(args []string) error {
	flags, positional, err := parsePRFlags(args)
	if err != nil {
		return err
	}
	if len(positional) == 0 {
		return fmt.Errorf("usage: okg pr merge NUMBER")
	}
	rawNum := positional[0]
	number, err := strconv.Atoi(rawNum)
	if err != nil {
		return fmt.Errorf("invalid PR number: %s", rawNum)
	}

	client, repo, err := setupClient(flags.repo)
	if err != nil {
		return err
	}

	var merged PR
	endpoint := fmt.Sprintf("/%s/pr/%d/merge", repo, number)
	if err := client.postJSON(endpoint, nil, &merged); err != nil {
		return err
	}

	if !flags.asJSON {
		fmt.Printf("PR #%d merged by @%s\n", merged.Number, merged.MergedBy)
		return nil
	}
	return outputJSON(merged)
}
// --- pr close ---

// runPRClose implements "okg pr close NUMBER" by asking runPRStateChange to
// move the pull request to the "closed" state. args is passed through
// unchanged.
func runPRClose(args []string) error {
	const target = "closed"
	return runPRStateChange(target, args)
}
// --- pr reopen ---

// runPRReopen implements "okg pr reopen NUMBER" by asking runPRStateChange
// to move the pull request back to the "open" state. args is passed through
// unchanged.
func runPRReopen(args []string) error {
	const target = "open"
	return runPRStateChange(target, args)
}
// runPRStateChange is the shared implementation of "okg pr close" and
// "okg pr reopen".
//
// new_state is the state to request ("closed" or "open" from the callers in
// this file); args are the subcommand's arguments, of which the first
// positional one must be the PR number. The change is sent to the PR
// endpoint through the client's patchJSON as {"state": new_state}, and the
// PR returned by the server is printed as JSON (--json) or as a one-line
// confirmation showing the state the server reports.
//
// Errors are returned for a missing or non-numeric NUMBER and for client
// setup or request failures.
func runPRStateChange(
	new_state string, args []string,
) error {
	flags, positional, err := parsePRFlags(args)
	if err != nil {
		return err
	}
	if len(positional) == 0 {
		return fmt.Errorf(
			"usage: okg pr close|reopen NUMBER")
	}
	rawNum := positional[0]
	number, err := strconv.Atoi(rawNum)
	if err != nil {
		return fmt.Errorf("invalid PR number: %s", rawNum)
	}

	client, repo, err := setupClient(flags.repo)
	if err != nil {
		return err
	}

	var updated PR
	change := map[string]string{"state": new_state}
	endpoint := fmt.Sprintf("/%s/pr/%d", repo, number)
	if err := client.patchJSON(endpoint, change, &updated); err != nil {
		return err
	}

	if !flags.asJSON {
		fmt.Printf("PR #%d is now %s\n", updated.Number, updated.State)
		return nil
	}
	return outputJSON(updated)
}