package llm

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)

var providers = map[string]string{
	"deepseek": "https://api.deepseek.com/v1",
	"kimi":     "https://api.moonshot.cn/v1",
	"ollama":   "http://localhost:11434/v1",
	"openai":   "https://api.openai.com/v1",
}

var defaultModels = map[string]string{
	"deepseek": "deepseek-chat",
	"kimi":     "moonshot-v1-8k",
	"ollama":   "qwen2.5",
	"openai":   "gpt-4o",
}

// Client wraps an OpenAI-compatible chat client bound to a single model.
type Client struct {
	c     *openai.Client
	model string
}

// Tool and the related aliases re-export go-openai's tool types so callers
// don't have to import the SDK directly.
type Tool = openai.Tool
type ToolCall = openai.ToolCall
type FunctionDefinition = openai.FunctionDefinition

var ToolTypeFunction = openai.ToolTypeFunction

// New builds a Client for the given provider. An empty baseURL or model
// falls back to the provider defaults, and to deepseek if the provider is
// unknown. The API key is read from the environment variable named by
// apiKeyEnv.
func New(provider, model, baseURL, apiKeyEnv string) (*Client, error) {
	if baseURL == "" {
		var ok bool
		baseURL, ok = providers[provider]
		if !ok {
			baseURL = providers["deepseek"]
		}
	}
	if model == "" {
		model = defaultModels[provider]
		if model == "" {
			model = "deepseek-chat"
		}
	}
	apiKey := os.Getenv(apiKeyEnv)
	if apiKey == "" {
		apiKey = "ollama" // ollama doesn't need a real key
	}
	cfg := openai.DefaultConfig(apiKey)
	cfg.BaseURL = baseURL
	return &Client{c: openai.NewClientWithConfig(cfg), model: model}, nil
}

type Message = openai.ChatCompletionMessage

// NewMsg builds a plain chat message.
func NewMsg(role, content string) Message {
	return Message{Role: role, Content: content}
}

// NewToolResultMsg builds a tool-result message tied to a prior tool call.
func NewToolResultMsg(toolCallID, content string) Message {
	return Message{Role: openai.ChatMessageRoleTool, Content: content, ToolCallID: toolCallID}
}

// Usage records token usage.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// StreamResult holds the result of a streaming response.
type StreamResult struct {
	Content   string     // text reply
	ToolCalls []ToolCall // tool call requests
	Usage     Usage      // token usage
}
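// For illustration only: a Tool here is just the SDK's function-call schema.
// A minimal sketch of defining one follows; the "read_file" name and its
// parameter schema are hypothetical, not part of this package (and in some
// go-openai versions Function is a value rather than a pointer):
//
//	var readFileTool = Tool{
//		Type: ToolTypeFunction,
//		Function: &FunctionDefinition{
//			Name:        "read_file",
//			Description: "Read a file and return its contents",
//			Parameters: map[string]any{
//				"type": "object",
//				"properties": map[string]any{
//					"path": map[string]any{"type": "string", "description": "file path to read"},
//				},
//				"required": []string{"path"},
//			},
//		},
//	}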
// Stream calls the LLM and streams tokens to the callback. It returns the
// full response.
func (c *Client) Stream(ctx context.Context, msgs []Message, onToken func(string)) (*StreamResult, error) {
	return c.StreamWithTools(ctx, msgs, nil, onToken)
}

// StreamWithTools performs a streaming call with tool-calling support.
func (c *Client) StreamWithTools(ctx context.Context, msgs []Message, tools []Tool, onToken func(string)) (*StreamResult, error) {
	req := openai.ChatCompletionRequest{
		Model:    c.model,
		Messages: msgs,
		Stream:   true,
		StreamOptions: &openai.StreamOptions{
			IncludeUsage: true,
		},
	}
	if len(tools) > 0 {
		req.Tools = tools
	}

	stream, err := c.c.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("llm stream: %w", err)
	}
	defer stream.Close()

	var content strings.Builder
	toolCallMap := make(map[int]*ToolCall) // index -> ToolCall
	var usage Usage

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("llm stream recv: %w", err)
		}
		// Capture usage (included in the final chunk when IncludeUsage is set).
		if resp.Usage != nil {
			usage = Usage{
				PromptTokens:     resp.Usage.PromptTokens,
				CompletionTokens: resp.Usage.CompletionTokens,
				TotalTokens:      resp.Usage.TotalTokens,
			}
		}
		if len(resp.Choices) == 0 {
			continue
		}
		delta := resp.Choices[0].Delta

		// Handle text content.
		if delta.Content != "" {
			content.WriteString(delta.Content)
			if onToken != nil {
				onToken(delta.Content)
			}
		}

		// Handle tool calls, which arrive split across chunks: the first chunk
		// for an index carries the ID and name, later chunks append argument
		// fragments that must be concatenated.
		for _, tc := range delta.ToolCalls {
			idx := 0
			if tc.Index != nil {
				idx = *tc.Index
			}
			existing, ok := toolCallMap[idx]
			if !ok {
				toolCallMap[idx] = &ToolCall{
					ID:   tc.ID,
					Type: tc.Type,
					Function: openai.FunctionCall{
						Name:      tc.Function.Name,
						Arguments: tc.Function.Arguments,
					},
				}
			} else {
				if tc.ID != "" {
					existing.ID = tc.ID
				}
				if tc.Function.Name != "" {
					existing.Function.Name = tc.Function.Name
				}
				existing.Function.Arguments += tc.Function.Arguments
			}
		}
	}

	// Reassemble tool calls in index order.
	var toolCalls []ToolCall
	for i := 0; i < len(toolCallMap); i++ {
		if tc, ok := toolCallMap[i]; ok {
			toolCalls = append(toolCalls, *tc)
		}
	}

	log.Printf("[llm] usage: prompt=%d completion=%d total=%d",
		usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)

	return &StreamResult{
		Content:   content.String(),
		ToolCalls: toolCalls,
		Usage:     usage,
	}, nil
}
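// A minimal sketch of the intended call pattern, shown as a comment so it
// stays non-normative: runTurn and execTool are hypothetical helpers, not
// part of this package. The loop streams a reply, and whenever the model
// requests tools it feeds the results back and streams again until the
// model answers with plain text.
//
//	func runTurn(ctx context.Context, c *Client, msgs []Message, tools []Tool) ([]Message, error) {
//		for {
//			res, err := c.StreamWithTools(ctx, msgs, tools, func(tok string) { fmt.Print(tok) })
//			if err != nil {
//				return nil, err
//			}
//			if len(res.ToolCalls) == 0 {
//				return append(msgs, NewMsg("assistant", res.Content)), nil
//			}
//			// Record the assistant turn with its tool calls, then answer each
//			// call with a tool-result message before looping.
//			msgs = append(msgs, Message{Role: "assistant", Content: res.Content, ToolCalls: res.ToolCalls})
//			for _, tc := range res.ToolCalls {
//				msgs = append(msgs, NewToolResultMsg(tc.ID, execTool(tc.Function.Name, tc.Function.Arguments)))
//			}
//		}
//	}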