352 lines
7.3 KiB
Go

package tools
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"github.com/sdaduanbilei/agent-team/internal/llm"
)
// Executor runs the agent's tool calls against files inside a single
// workspace directory.
type Executor struct {
	// workspaceDir is the root every file-oriented tool is confined to.
	// It may be relative; safePath resolves it to an absolute path on use.
	workspaceDir string
}

// NewExecutor returns an Executor whose tools operate within workspaceDir.
func NewExecutor(workspaceDir string) *Executor {
	return &Executor{workspaceDir: workspaceDir}
}

// safePath resolves filename relative to the workspace and returns the
// absolute, cleaned result. It returns an error if the resolved path would
// fall outside the workspace.
//
// Containment is checked with filepath.Rel rather than a string-prefix test.
// That way a sibling directory sharing a name prefix ("/ws" vs "/ws-evil") is
// rejected, and a relative workspaceDir is compared in absolute form.
// Symlinks are not resolved, so a link inside the workspace that points
// elsewhere is still accepted.
func (e *Executor) safePath(filename string) (string, error) {
	root, err := filepath.Abs(e.workspaceDir)
	if err != nil {
		return "", err
	}
	abs, err := filepath.Abs(filepath.Join(root, filename))
	if err != nil {
		return "", err
	}
	rel, err := filepath.Rel(root, abs)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("path traversal detected: %s", filename)
	}
	return abs, nil
}
// Execute dispatches a single LLM tool call to the matching tool
// implementation and returns that tool's textual result.
//
// The call's Arguments string is handed unchanged to the tool, which decodes
// it as JSON. An unrecognised tool name is reported as an error rather than
// as an empty successful result, so the caller can surface the mistake to
// the model instead of silently dropping the call.
func (e *Executor) Execute(toolCall llm.ToolCall) (string, error) {
	name := toolCall.Function.Name
	args := toolCall.Function.Arguments

	switch name {
	case "glob":
		return e.glob(args)
	case "grep":
		return e.grep(args)
	case "read_file":
		return e.readFile(args)
	case "edit_file":
		return e.editFile(args)
	case "write_file":
		return e.writeFile(args)
	case "list_workspace":
		return e.listWorkspace()
	case "git_status":
		return e.gitStatus()
	case "git_diff":
		return e.gitDiff(args)
	case "git_commit":
		return e.gitCommit(args)
	default:
		return "", fmt.Errorf("unknown tool: %q", name)
	}
}
// GlobArgs holds the JSON arguments of the "glob" tool.
type GlobArgs struct {
	// Pattern is a filepath.Glob pattern interpreted relative to the workspace.
	Pattern string `json:"pattern"`
}

// glob lists the workspace entries matching a filepath.Glob pattern, one
// workspace-relative path per line.
//
// Matches that resolve outside the workspace are dropped, so a pattern such
// as "../*" cannot be used to enumerate files beyond it. When nothing matches,
// a fixed "no files found" message is returned instead of an empty string.
func (e *Executor) glob(args string) (string, error) {
	var a GlobArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	files, err := filepath.Glob(filepath.Join(e.workspaceDir, a.Pattern))
	if err != nil {
		return "", err
	}

	parentPrefix := ".." + string(filepath.Separator)
	var result []string
	for _, f := range files {
		rel, err := filepath.Rel(e.workspaceDir, f)
		if err != nil || rel == ".." || strings.HasPrefix(rel, parentPrefix) {
			continue // outside the workspace (or not expressible relative to it)
		}
		result = append(result, rel)
	}
	if len(result) == 0 {
		return "未找到匹配的文件", nil
	}
	return strings.Join(result, "\n"), nil
}
// GrepArgs holds the JSON arguments of the "grep" tool.
type GrepArgs struct {
	// Pattern is a regular expression (Go regexp / RE2 syntax) matched
	// against each line of each file.
	Pattern string `json:"pattern"`
	// Path optionally narrows the search to a file or directory relative to
	// the workspace; empty means the whole workspace.
	Path string `json:"path"`
	// Include optionally filters files with a filepath.Match pattern applied
	// to each file's base name (e.g. "*.go").
	Include string `json:"include"`
}

// grep searches workspace files for lines matching a regular expression.
//
// Each hit is rendered as "relpath:lineno: line", where relpath is relative
// to the workspace and lineno is 1-based. At most 100 hits are returned; if
// more exist, a trailing marker line says so. Unreadable files and
// directories are skipped rather than aborting the search. Path is confined
// to the workspace via safePath.
func (e *Executor) grep(args string) (string, error) {
	const maxResults = 100

	var a GrepArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	re, err := regexp.Compile(a.Pattern)
	if err != nil {
		return "", err
	}

	// Walk in absolute form so walked paths and root share a base for Rel.
	root, err := filepath.Abs(e.workspaceDir)
	if err != nil {
		return "", err
	}
	searchDir := root
	if a.Path != "" {
		if searchDir, err = e.safePath(a.Path); err != nil {
			return "", err
		}
	}

	// errLimit aborts the walk once a hit beyond the cap is found. That extra
	// hit is never recorded; it only tells us to emit the truncation marker.
	errLimit := fmt.Errorf("grep: result limit reached")
	var results []string
	err = filepath.Walk(searchDir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil || info.IsDir() {
			return nil // best effort: skip unreadable entries; directories hold no lines
		}
		if a.Include != "" {
			if matched, _ := filepath.Match(a.Include, info.Name()); !matched {
				return nil
			}
		}
		content, err := os.ReadFile(path)
		if err != nil {
			return nil
		}
		rel, err := filepath.Rel(root, path)
		if err != nil {
			rel = path
		}
		for i, line := range strings.Split(string(content), "\n") {
			if !re.MatchString(line) {
				continue
			}
			if len(results) == maxResults {
				return errLimit
			}
			results = append(results, fmt.Sprintf("%s:%d: %s", rel, i+1, line))
		}
		return nil
	})

	truncated := err == errLimit // Walk returns the callback's error unwrapped
	if err != nil && !truncated {
		return "", err
	}
	if len(results) == 0 {
		return "未找到匹配的内容", nil
	}
	if truncated {
		results = append(results, "... (还有更多结果)")
	}
	return strings.Join(results, "\n"), nil
}
// ReadFileArgs holds the JSON arguments of the "read_file" tool.
type ReadFileArgs struct {
	// Filename is the file to read, relative to the workspace.
	Filename string `json:"filename"`
	// Offset is the zero-based index of the first line to return;
	// negative values are treated as 0.
	Offset int `json:"offset"`
	// Limit is the maximum number of lines to return; values <= 0 mean 100.
	Limit int `json:"limit"`
}

// readFile returns a window of lines from a workspace file.
//
// The file is split on "\n", so a trailing newline counts as a final empty
// line. When the window ends before the file does, a footer reports the
// total line count and the 1-based range shown. An Offset at or past the
// end yields a fixed "offset out of range" message rather than an error.
func (e *Executor) readFile(args string) (string, error) {
	const defaultLimit = 100

	var a ReadFileArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	// Negative values would produce an invalid slice expression below and
	// panic, so normalise them first.
	if a.Offset < 0 {
		a.Offset = 0
	}
	if a.Limit <= 0 {
		a.Limit = defaultLimit
	}

	fpath, err := e.safePath(a.Filename)
	if err != nil {
		return "", err
	}
	content, err := os.ReadFile(fpath)
	if err != nil {
		return "", err
	}

	lines := strings.Split(string(content), "\n")
	if a.Offset >= len(lines) {
		return "起始位置超出文件行数", nil
	}
	// Compare against the remaining line count instead of computing
	// Offset+Limit directly, which could overflow for a huge Limit.
	end := len(lines)
	if a.Limit < end-a.Offset {
		end = a.Offset + a.Limit
	}

	result := strings.Join(lines[a.Offset:end], "\n")
	if end < len(lines) {
		result += fmt.Sprintf("\n\n... (共 %d 行,当前显示 %d-%d)", len(lines), a.Offset+1, end)
	}
	return result, nil
}
// EditFileArgs holds the JSON arguments of the "edit_file" tool.
type EditFileArgs struct {
	// Filename is the file to edit, relative to the workspace.
	Filename string `json:"filename"`
	// OldString is the exact text to find; it must be non-empty.
	OldString string `json:"old_string"`
	// NewString replaces the first occurrence of OldString.
	NewString string `json:"new_string"`
}

// editFile replaces the first occurrence of OldString with NewString in a
// workspace file and writes the result back.
//
// An empty OldString is rejected. It would otherwise "match" at offset 0 and
// silently prepend NewString to the file. If OldString does not occur, an
// error suggests a more precise match or write_file.
func (e *Executor) editFile(args string) (string, error) {
	var a EditFileArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	if a.OldString == "" {
		return "", fmt.Errorf("old_string 不能为空,请提供要替换的原文或使用 write_file 完整覆盖文件")
	}

	fpath, err := e.safePath(a.Filename)
	if err != nil {
		return "", err
	}
	original, err := os.ReadFile(fpath)
	if err != nil {
		return "", err
	}

	text := string(original)
	if !strings.Contains(text, a.OldString) {
		return "", fmt.Errorf("文件中未找到要替换的内容,请使用更精确的匹配字符串或使用 write_file 完整覆盖文件")
	}
	updated := strings.Replace(text, a.OldString, a.NewString, 1)
	// 0644 only applies if the file were created; an existing file keeps its mode.
	if err := os.WriteFile(fpath, []byte(updated), 0644); err != nil {
		return "", err
	}
	return fmt.Sprintf("已更新文件: %s", a.Filename), nil
}
// WriteFileArgs holds the JSON arguments of the "write_file" tool.
type WriteFileArgs struct {
	// Filename is the file to write, relative to the workspace.
	Filename string `json:"filename"`
	// Content fully replaces the file's previous contents.
	Content string `json:"content"`
}

// writeFile writes Content to a workspace file and creates any missing
// parent directories (mode 0755).
//
// The confirmation message notes when an existing file was overwritten.
// A failure to create the parent directories is returned as an error; it
// would otherwise only surface later as a less helpful WriteFile error.
func (e *Executor) writeFile(args string) (string, error) {
	var a WriteFileArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	fpath, err := e.safePath(a.Filename)
	if err != nil {
		return "", err
	}
	if err := os.MkdirAll(filepath.Dir(fpath), 0755); err != nil {
		return "", err
	}

	exists := ""
	if _, err := os.Stat(fpath); err == nil {
		exists = " (已存在,已覆盖)"
	}
	if err := os.WriteFile(fpath, []byte(a.Content), 0644); err != nil {
		return "", err
	}
	return fmt.Sprintf("已写入文件: %s%s", a.Filename, exists), nil
}
// listWorkspace returns every non-directory entry under the workspace, one
// workspace-relative path per line.
//
// Entries whose relative path starts with "." are omitted. This covers
// top-level dot-files and everything inside top-level dot-directories such
// as .git. Dot-files nested deeper (e.g. "src/.env") are still listed.
// Unreadable entries are skipped. An empty result yields a fixed
// "workspace is empty" message.
func (e *Executor) listWorkspace() (string, error) {
	var files []string
	err := filepath.Walk(e.workspaceDir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return nil // best effort: skip entries we cannot stat or read
		}
		rel, err := filepath.Rel(e.workspaceDir, path)
		if err != nil {
			return nil
		}
		hidden := strings.HasPrefix(rel, ".")
		if info.IsDir() {
			// Prune hidden top-level directories instead of walking their whole
			// tree (e.g. all of .git) only to discard every file. The root
			// itself has rel "." and must not be pruned.
			if hidden && rel != "." {
				return filepath.SkipDir
			}
			return nil
		}
		if !hidden {
			files = append(files, rel)
		}
		return nil
	})
	if err != nil {
		return "", err
	}
	if len(files) == 0 {
		return "工作区为空", nil
	}
	return strings.Join(files, "\n"), nil
}
// runGit runs git with the given arguments inside the workspace directory and
// returns its standard output.
//
// On failure, the returned error wraps the exec error and appends git's
// stderr, or its stdout when stderr is empty. A bare "exit status N" says
// nothing about what went wrong.
func (e *Executor) runGit(args ...string) ([]byte, error) {
	cmd := exec.Command("git", args...)
	cmd.Dir = e.workspaceDir
	out, err := cmd.Output()
	if err == nil {
		return out, nil
	}

	// cmd.Stderr was left nil, so Output captures stderr into the returned
	// *exec.ExitError (which it returns unwrapped).
	detail := ""
	if exitErr, ok := err.(*exec.ExitError); ok {
		detail = strings.TrimSpace(string(exitErr.Stderr))
	}
	if detail == "" {
		detail = strings.TrimSpace(string(out))
	}
	command := strings.Join(args, " ")
	if detail == "" {
		return nil, fmt.Errorf("git %s: %w", command, err)
	}
	return nil, fmt.Errorf("git %s: %w: %s", command, err, detail)
}

// gitStatus reports `git status --porcelain` for the workspace. When there
// are no changes, it returns a fixed "working tree clean" message.
func (e *Executor) gitStatus() (string, error) {
	out, err := e.runGit("status", "--porcelain")
	if err != nil {
		return "", err
	}
	if len(out) == 0 {
		return "工作区干净,无待提交更改", nil
	}
	return string(out), nil
}

// GitDiffArgs holds the JSON arguments of the "git_diff" tool.
type GitDiffArgs struct {
	// Filename optionally restricts the diff to one path; empty diffs everything.
	Filename string `json:"filename"`
}

// gitDiff returns the unstaged `git diff` for the workspace, optionally
// limited to a single path. Returns a fixed "no changes" message when the
// diff is empty.
func (e *Executor) gitDiff(args string) (string, error) {
	var a GitDiffArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	gitArgs := []string{"diff"}
	if a.Filename != "" {
		// "--" forces git to treat the value strictly as a pathspec, so input
		// such as "--output=/tmp/x" cannot be smuggled in as an option.
		gitArgs = append(gitArgs, "--", a.Filename)
	}
	out, err := e.runGit(gitArgs...)
	if err != nil {
		return "", err
	}
	if len(out) == 0 {
		return "无更改", nil
	}
	return string(out), nil
}

// GitCommitArgs holds the JSON arguments of the "git_commit" tool.
type GitCommitArgs struct {
	// Message is passed to `git commit -m`.
	Message string `json:"message"`
}

// gitCommit stages every change in the workspace (`git add -A`) and commits
// it with the given message, returning git's output.
//
// Both commands run inside the workspace directory. The commit step must not
// fall back to the process's own working directory, which may be a different
// repository.
func (e *Executor) gitCommit(args string) (string, error) {
	var a GitCommitArgs
	if err := json.Unmarshal([]byte(args), &a); err != nil {
		return "", err
	}
	if _, err := e.runGit("add", "-A"); err != nil {
		return "", err
	}
	out, err := e.runGit("commit", "-m", a.Message)
	if err != nil {
		return "", err
	}
	return string(out), nil
}