diff --git a/README.md b/README.md index cf4ae71..a18962d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,48 @@ # gocmdDaemon +这是一个基于 Unix Socket 的守护进程,提供命令注册和执行的功能。 + +## 安装 + +```bash +go get github.com/kingecg/gocmdDaemon +``` + +## 使用示例 + +```go +package main + +import ( + "github.com/kingecg/gocmdDaemon" +) + +func main() { + // 创建守护进程实例 + daemon := &gocmdDaemon.CmdDaemon{ + SocketPath: "/tmp/my.sock", + } + + // 注册命令处理程序 + daemon.RegisterCmd("test", &MyCmdHandler{}) + + // 启动守护进程 + daemon.Listen() +} + +// MyCmdHandler 实现 CmdHandler 接口 +type MyCmdHandler struct{} + +func (h *MyCmdHandler) Handle(conn *gocmdDaemon.CmdConn, req *gocmdDaemon.CmdRequest) error { + return conn.End("Command executed successfully") +} + +func (h *MyCmdHandler) Description() string { + return "A test command handler" +} + +func (h *MyCmdHandler) Usage() string { + return "usage: test" +} +``` + diff --git a/cmd_daemon_test.go b/cmd_daemon_test.go new file mode 100644 index 0000000..ea25e59 --- /dev/null +++ b/cmd_daemon_test.go @@ -0,0 +1,205 @@ +package main + +import ( + "encoding/json" + "fmt" + "net" + "os" + "testing" + "time" +) + +func TestWrite(t *testing.T) { + // 创建一个配对的连接 + conn1, conn2 := net.Pipe() + defer conn1.Close() + defer conn2.Close() + + // 要写入的数据 + data := map[string]string{"test": "data"} + + // 在一个goroutine中执行写操作 + go func() { + err := Write(conn1, data) + if err != nil { + t.Errorf("Write failed: %v", err) + } + }() + + // 读取数据并验证 + var received map[string]string + err := json.NewDecoder(conn2).Decode(&received) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + for k, v := range data { + if received[k] != v { + t.Errorf("Received data does not match. Expected %v:%v, Got %v:%v", k, v, received[k], received[k]) + } + } +} + +func TestRead(t *testing.T) { + // 创建一个配对的连接 + conn1, conn2 := net.Pipe() + defer conn1.Close() + defer conn2.Close() + + // 准备要读取的数据 + data := map[string]string{"test": "data"} + + // 在一个goroutine中执行写操作 + go func() { + encoder := json.NewEncoder(conn1) + err := encoder.Encode(data) + if err != nil { + t.Errorf("Write failed: %v", err) + } + }() + + // 读取数据并验证 + received, err := Read[map[string]string](conn2) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + for k, v := range data { + if (*received)[k] != v { + t.Errorf("Received data does not match. Expected %v:%v, Got %v:%v", k, v, (*received)[k], (*received)[k]) + } + } +} + +func TestCmdDaemonListen(t *testing.T) { + socketPath := "/tmp/test.sock" + daemon := &CmdDaemon{ + SocketPath: socketPath, + cmds: make(map[string]CmdHandler), + } + + // 注册一个简单的命令处理程序 + daemon.RegisterCmd("test", &SimpleCmdHandler{}) + + // 启动守护进程 + go func() { + err := daemon.Listen() + if err != nil { + t.Errorf("Listen failed: %v", err) + } + }() + + // 等待服务器启动 + time.Sleep(1 * time.Second) + + // 模拟客户端连接 + conn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatalf("Client connect failed: %v", err) + } + defer conn.Close() + + // 发送请求 + req := CmdRequest{ + Id: "1", + Cmd: "test", + Args: "arg1", + IsDebug: false, + } + err = Write(conn, req) + if err != nil { + t.Errorf("Write failed: %v", err) + } + + // 读取响应 + resp, err := Read[CmdResponse](conn) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if resp.Data != "Handled test with args: arg1" || resp.Error != "" { + t.Errorf("Unexpected response. Data: %v, Error: %v", resp.Data, resp.Error) + } +} + +func TestCmdDaemonRun(t *testing.T) { + // 这里我们假设有一个正在运行的服务器来处理请求 + // 因此我们需要首先启动一个简单的服务器来测试Run方法 + socketPath := "/tmp/test_run.sock" + daemon := &CmdDaemon{ + SocketPath: socketPath, + cmds: make(map[string]CmdHandler), + } + + // 注册一个简单的命令处理程序 + daemon.RegisterCmd("test", &SimpleCmdHandler{}) + + // 创建一个同步通道来等待守护进程开始监听 + listening := make(chan struct{}) + + // 启动守护进程 + go startTestDaemon(daemon, listening) + + // 等待直到监听开始 + // for循环中检查socket文件是否存在,如果存在退出循环,否则sleepeep 1秒 + // 使用一个错误channel来传递错误 + errChan := make(chan error) + + go func() { + listened := false + for { + + if listened { + break + } + if _, err := os.Stat(socketPath); err == nil { + listened = true + } + time.Sleep(1 * time.Second) + } + + // 设置命令行参数 + os.Args = []string{"cmd", "--debug", "test", "arg1"} + ndaemon := &CmdDaemon{ + SocketPath: socketPath, + } + // 执行Run方法 + err := ndaemon.Run() + if err != nil { + errChan <- err + } else { + t.Log("Run completed successfully") + close(errChan) + } + }() + err := <-errChan + if err != nil { + t.Fatal(err) + } + +} + +// startTestDaemon 启动守护进程并在后台运行 +func startTestDaemon(daemon *CmdDaemon, ready chan struct{}) { + err := daemon.Listen() + if err != nil { + panic(fmt.Sprintf("Listen failed: %v", err)) + } + +} + +// SimpleCmdHandler 是一个简单的CmdHandler实现用于测试 +type SimpleCmdHandler struct{} + +func (h *SimpleCmdHandler) Handle(conn *CmdConn, req *CmdRequest) error { + err := conn.End(fmt.Sprintf("Handled %s with args: %s", req.Cmd, req.Args)) + return err +} + +func (h *SimpleCmdHandler) Description() string { + return "Simple command handler for testing" +} + +func (h *SimpleCmdHandler) Usage() string { + return "Usage: test [args]" +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1667ba2 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.pyer.club/kingecg/gocmdDaemon + +go 1.23.1 + +require github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7790d7c --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/io.go b/io.go new file mode 100644 index 0000000..1f1e16b --- /dev/null +++ b/io.go @@ -0,0 +1,22 @@ +package main + +import ( + "encoding/json" + "net" +) + +func Write[T any](conn net.Conn, v T) error { + return json.NewEncoder(conn).Encode(v) +} + +func Read[T any](conn net.Conn) (*T, error) { + + // 再读报文内容 + var zero T + err := json.NewDecoder(conn).Decode(&zero) + + if err != nil { + return nil, err + } + return &zero, nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..4681b44 --- /dev/null +++ b/main.go @@ -0,0 +1,192 @@ +package main + +import ( + "encoding/json" + "fmt" + "net" + "os" + "strings" + + "github.com/google/uuid" +) + +type CmdDaemon struct { + SocketPath string // unix socket path + cmds map[string]CmdHandler +} +type CmdRequest struct { + Id string `json:"id"` + Cmd string `json:"cmd"` + Args string `json:"args"` + IsDebug bool `json:"debug"` +} +type CmdResponse struct { + Id string `json:"id"` + Data string `json:"data"` + Error string `json:"error"` + Continue bool `json:"continue"` +} + +type CmdConn struct { + net.Conn + Id string +} + +func (c *CmdConn) Write(d string) error { + resp := CmdResponse{ + Id: c.Id, + Data: d, + Continue: true, + } + return Write(c.Conn, resp) +} +func (c *CmdConn) WriteError(err error, isContinue bool) error { + resp := CmdResponse{ + Id: c.Id, + Error: err.Error(), + Continue: isContinue, + } + return Write(c.Conn, resp) +} +func (c *CmdConn) End(d string) error { + resp := CmdResponse{ + Id: c.Id, + Data: d, + Continue: false, + } + return Write(c.Conn, resp) +} + +type CmdHandler interface { + Handle(conn *CmdConn, req *CmdRequest) error + Description() string + Usage() string +} + +func (c *CmdDaemon) Listen() error { + // 删除已存在的 socket 文件(如果存在) + if err := os.Remove(c.SocketPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove existing socket file: %v", err) + } + + // 监听 unix socket + listener, err := net.Listen("unix", c.SocketPath) + if err != nil { + return fmt.Errorf("failed to listen on socket: %v", err) + } + defer listener.Close() + + // 设置 socket 文件权限 + if err := os.Chmod(c.SocketPath, 0777); err != nil { + return fmt.Errorf("failed to set socket file permissions: %v", err) + } + + for { + conn, err := listener.Accept() + if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + // 处理每个连接 + go func(conn net.Conn) { + defer conn.Close() + + req, err := Read[CmdRequest](conn) + if err != nil { + _ = Write(conn, CmdResponse{ + Error: "failed to read request: " + err.Error(), + Continue: false, + }) + return + } + + cmdHandler, ok := c.cmds[req.Cmd] + if !ok { + _ = Write(conn, CmdResponse{ + Error: "unknown command: " + req.Cmd, + Continue: false, + }) + return + } + + // 执行命令处理程序 + cmdConn := &CmdConn{Conn: conn, Id: req.Id} + err = cmdHandler.Handle(cmdConn, req) + if err != nil { + _ = cmdConn.WriteError(err, false) + } + }(conn) + } +} + +func (c *CmdDaemon) RegisterCmd(cmd string, handler CmdHandler) { + c.cmds[cmd] = handler +} + +func (c *CmdDaemon) isDebug(cmd string) bool { + + return cmd == "--debug" || cmd == "-d" +} +func (c *CmdDaemon) Run() error { + // 从命令参数中解析出是否debug 子命令和剩余参数字符串 + args := os.Args[1:] + isDebug := c.isDebug(args[0]) + var remainingArgs []string + cmd := "" + if isDebug { + if len(args) > 1 { + cmd = args[1] + remainingArgs = args[2:] + } else { + cmd = "help" + isDebug = false + remainingArgs = []string{} + } + } else { + if len(args) > 0 { + cmd = args[0] + remainingArgs = args[1:] + } else { + cmd = "help" + isDebug = false + remainingArgs = []string{} + } + } + cmdReq := CmdRequest{ + Args: strings.Join(remainingArgs, " "), + Cmd: cmd, + Id: uuid.New().String(), + IsDebug: isDebug, + } + // dial unix socket + conn, err := net.Dial("unix", c.SocketPath) + if err != nil { + return err + } + defer conn.Close() + // send cmd request + cmdReqJson, err := json.Marshal(cmdReq) + if err != nil { + return err + } + _, err = conn.Write(cmdReqJson) + if err != nil { + return err + } + + for { + + resp, err := Read[CmdResponse](conn) + if err != nil { + return err + } + if resp.Error != "" { + fmt.Println(resp.Error) + } + if !resp.Continue { + break + } + } + return nil + +}