1
package api
2
3
import "bufio"
4
import "bytes"
5
import "context"
6
import "encoding/json"
7
import "fmt"
8
import "io/ioutil"
9
import "net/http"
10
import "strings"
11
12
// MessagesStream calls the streaming endpoint and
13
// delivers text deltas via cb. Set req.Model to one
14
// of the Klex LLM function names.
15
func (c *Client) MessagesStream(
16
ctx context.Context,
17
req MessagesRequest,
18
cb func(string),
19
) error {
20
f := req.Model
21
req.Model = ""
22
if f == "" {
23
return fmt.Errorf(
24
"MessagesRequest.Model is empty")
25
}
26
27
type streamReq struct {
28
FName string `json:"f_name"`
29
MessagesRequest
30
}
31
body, err := json.Marshal(
32
streamReq{FName: f, MessagesRequest: req})
33
if err != nil {
34
return fmt.Errorf("marshal: %v", err)
35
}
36
37
r, err := http.NewRequestWithContext(
38
ctx, "POST",
39
c.KlexURL+"/chat/stream",
40
bytes.NewReader(body))
41
if err != nil {
42
return fmt.Errorf("new request: %v", err)
43
}
44
r.Header.Set(
45
"Authorization", "Bearer "+c.APIKey)
46
r.Header.Set(
47
"Content-Type", "application/json")
48
49
resp, err := http.DefaultClient.Do(r)
50
if err != nil {
51
return fmt.Errorf("http: %v", err)
52
}
53
defer resp.Body.Close()
54
55
if resp.StatusCode != 200 {
56
b, _ := ioutil.ReadAll(resp.Body)
57
return fmt.Errorf(
58
"status %d: %s", resp.StatusCode, b)
59
}
60
61
return parseStreamSSE(resp.Body, cb)
62
}
63
64
// parseStreamSSE reads SSE events from the funky
65
// /chat/stream endpoint. Events are JSON objects:
66
//
67
// {"text":"token"} — text delta
68
// {"done":true} — stream complete
69
// {"error":"msg"} — error
70
func parseStreamSSE(
71
r interface{ Read([]byte) (int, error) },
72
cb func(string),
73
) error {
74
scanner := bufio.NewScanner(r)
75
scanner.Buffer(
76
make([]byte, 0, 64*1024), 1024*1024)
77
for scanner.Scan() {
78
line := scanner.Text()
79
if !strings.HasPrefix(line, "data: ") {
80
continue
81
}
82
data := line[6:]
83
var event struct {
84
Text string `json:"text"`
85
Done bool `json:"done"`
86
Error string `json:"error"`
87
}
88
if err := json.Unmarshal(
89
[]byte(data), &event); err != nil {
90
continue
91
}
92
if event.Error != "" {
93
return fmt.Errorf("stream: %s", event.Error)
94
}
95
if event.Done {
96
return nil
97
}
98
if event.Text != "" {
99
cb(event.Text)
100
}
101
}
102
return scanner.Err()
103
}