diff --git a/config/config.go b/config/config.go index 027abc7..3a24c1b 100644 --- a/config/config.go +++ b/config/config.go @@ -6,10 +6,12 @@ import ( "os" "path/filepath" "sync" + "log" "encoding/json" "gopkg.in/yaml.v2" + "github.com/fsnotify/fsnotify" ) // Config 系统配置结构体 @@ -33,6 +35,7 @@ type StorageConfig struct { var ( configInstance *Config once sync.Once + watcher *fsnotify.Watcher // 配置文件监视器 ) // NewDefaultConfig 创建默认配置 @@ -82,6 +85,89 @@ func ParseConfig(filePath string) (*Config, error) { return &cfg, nil } +// ValidateConfig 验证配置有效性 +// 返回验证结果和第一个发现的错误(如果有) +func (c *Config) Validate() (bool, error) { + // 验证服务器配置 + if c.Server.Port < 1024 || c.Server.Port > 65535 { + return false, fmt.Errorf("端口必须在1024-65535之间: %d", c.Server.Port) + } + + // 验证存储配置 + if c.Storage.Engine != "memory" && c.Storage.Engine != "rocksdb" { + return false, fmt.Errorf("不支持的存储引擎: %s", c.Storage.Engine) + } + + // 验证数据路径是否存在 + if _, err := os.Stat(c.Storage.DataPath); os.IsNotExist(err) { + return false, fmt.Errorf("数据目录不存在: %s", c.Storage.DataPath) + } + + return true, nil +} + +// WatchConfig 监听配置文件变化并自动重载 +func WatchConfig(filePath string) error { + // 创建文件监视器 + w, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("创建文件监视器失败: %v", err) + } + + // 添加文件监视 + dir := filepath.Dir(filePath) + fileName := filepath.Base(filePath) + + err = w.Add(dir) + if err != nil { + return fmt.Errorf("添加文件监视失败: %v", err) + } + + // 启动监听协程 + go func() { + for { + select { + case event, ok := <-w.Events: + if !ok { + return + } + // 当配置文件被修改或创建时重新加载 + if (event.Op&fsnotify.Write == fsnotify.Write || + event.Op&fsnotify.Create == fsnotify.Create) && + filepath.Base(event.Name) == fileName { + + // 重新加载配置 + newCfg, err := ParseConfig(filePath) + if err != nil { + log.Printf("配置重载失败: %v", err) + continue + } + + // 更新全局配置 + SetGlobalConfig(newCfg) + log.Printf("配置已更新") + } + + case err, ok := <-w.Errors: + if !ok { + return + } + log.Printf("文件监视错误: %v", err) + } + } + }() + + watcher = w + return nil +} + +// CloseWatcher 关闭配置文件监视器 +func CloseWatcher() { + if watcher != nil { + watcher.Close() + } +} + // GetConfig 获取全局配置实例 func GetConfig() *Config { once.Do(func() { @@ -89,3 +175,10 @@ func GetConfig() *Config { }) return configInstance } + +// SetGlobalConfig 设置全局配置实例 +func SetGlobalConfig(cfg *Config) { + once.Do(func() { + configInstance = cfg + }) +} diff --git a/config/config_test.go b/config/config_test.go index b04731e..61a7bd5 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -102,3 +102,82 @@ func TestParseConfig_FileNotFound(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "配置文件不存在") } + +func TestValidateConfig_InvalidPort(t *testing.T) { + // 创建临时YAML配置文件 + tmpFile, err := os.CreateTemp("", "*.yaml") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // 写入无效端口配置 + testConfig := `server: + host: 127.0.0.1 + port: 1000 +storage: + engine: rocksdb + dataPath: /test/data` + + _, err = tmpFile.WriteString(testConfig) + assert.NoError(t, err) + tmpFile.Close() + + // 测试解析功能 + cfg, err := ParseConfig(tmpFile.Name()) + assert.NoError(t, err) + assert.NotNil(t, cfg) + valid, err := cfg.Validate() + assert.False(t, valid) + assert.Contains(t, err.Error(), "端口必须在1024-65535之间") +} + +func TestValidateConfig_InvalidStorageEngine(t *testing.T) { + // 创建临时YAML配置文件 + tmpFile, err := os.CreateTemp("", "*.yaml") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // 写入无效存储引擎配置 + testConfig := `server: + host: 127.0.0.1 + port: 8080 +storage: + engine: invalid_engine + dataPath: /test/data` + + _, err = tmpFile.WriteString(testConfig) + assert.NoError(t, err) + tmpFile.Close() + + // 测试解析功能 + cfg, err := ParseConfig(tmpFile.Name()) + assert.NoError(t, err) + valid, err := cfg.Validate() + assert.False(t, valid) + assert.Contains(t, err.Error(), "不支持的存储引擎") +} + +func TestValidateConfig_DataPathNotExists(t *testing.T) { + // 创建临时YAML配置文件 + tmpFile, err := os.CreateTemp("", "*.yaml") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // 写入无效数据路径配置 + testConfig := `server: + host: 127.0.0.1 + port: 8080 +storage: + engine: rocksdb + dataPath: /invalid/path` + + _, err = tmpFile.WriteString(testConfig) + assert.NoError(t, err) + tmpFile.Close() + + // 测试解析功能 + cfg, err := ParseConfig(tmpFile.Name()) + assert.NoError(t, err) + valid, err := cfg.Validate() + assert.False(t, valid) + assert.Contains(t, err.Error(), "数据目录不存在") +} diff --git a/go.mod b/go.mod index 37ad9ec..4bbd95d 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,10 @@ go 1.23 require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.10.0 // indirect + golang.org/x/sys v0.13.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e8d4fcb..bc3c296 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,13 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/main.go b/main.go index 27061f1..5e60a2c 100644 --- a/main.go +++ b/main.go @@ -3,12 +3,64 @@ package main import ( "fmt" "net" + "os" + "os/signal" + "path/filepath" + + "syscall" + + "github.com/kingecg/goaidb/config" "github.com/kingecg/goaidb/network" "github.com/kingecg/goaidb/storage" ) +func getExeDir() string { + exePath, err := os.Executable() + if err != nil { + panic(err) + } + exeDir := filepath.Dir(exePath) + return exeDir +} + // 主程序入口 func main() { + // 加载配置文件(优先级:命令行参数 > 环境变量 > 默认配置) + configPath := getExeDir() + "/config.yaml" + if len(os.Args) > 1 { + configPath = os.Args[1] + } else if envPath := os.Getenv("GOAIDB_CONFIG"); envPath != "" { + configPath = envPath + } + + // 解析配置文件 + cfg, err := config.ParseConfig(configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "配置加载失败: %v\n", err) + os.Exit(1) + } + + // 设置全局配置 + config.SetGlobalConfig(cfg) + + // 启动配置文件监视 + err = config.WatchConfig(configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "配置监控启动失败: %v\n", err) + os.Exit(1) + } + + // 设置信号处理(优雅关闭) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-sigChan + fmt.Println("正在关闭...") + config.CloseWatcher() + os.Exit(0) + }() + // 初始化存储引擎(默认使用内存引擎) storageEngine, err := storage.NewMemoryEngine() if err != nil { @@ -18,14 +70,14 @@ func main() { // 创建网络服务器 server := network.NewServer(storageEngine) - + conf := config.GetConfig() // 启动服务 - listener, err := net.Listen("tcp", ":27017") + listener, err := net.Listen("tcp", conf.Server.Host+":"+fmt.Sprintf("%d", conf.Server.Port)) if err != nil { fmt.Printf("Failed to start server: %v\n", err) return } - + fmt.Println("GoAIDB started on port 27017") server.Serve(listener) -} \ No newline at end of file +}