实现config加载逻辑

This commit is contained in:
kingecg 2025-06-05 21:05:36 +08:00
parent 0b3002e395
commit a0826d5361
5 changed files with 234 additions and 4 deletions

View File

@ -6,10 +6,12 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
"log"
"encoding/json" "encoding/json"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/fsnotify/fsnotify"
) )
// Config 系统配置结构体 // Config 系统配置结构体
@ -33,6 +35,7 @@ type StorageConfig struct {
var ( var (
configInstance *Config configInstance *Config
once sync.Once once sync.Once
watcher *fsnotify.Watcher // 配置文件监视器
) )
// NewDefaultConfig 创建默认配置 // NewDefaultConfig 创建默认配置
@ -82,6 +85,89 @@ func ParseConfig(filePath string) (*Config, error) {
return &cfg, nil return &cfg, nil
} }
// ValidateConfig 验证配置有效性
// 返回验证结果和第一个发现的错误(如果有)
func (c *Config) Validate() (bool, error) {
// 验证服务器配置
if c.Server.Port < 1024 || c.Server.Port > 65535 {
return false, fmt.Errorf("端口必须在1024-65535之间: %d", c.Server.Port)
}
// 验证存储配置
if c.Storage.Engine != "memory" && c.Storage.Engine != "rocksdb" {
return false, fmt.Errorf("不支持的存储引擎: %s", c.Storage.Engine)
}
// 验证数据路径是否存在
if _, err := os.Stat(c.Storage.DataPath); os.IsNotExist(err) {
return false, fmt.Errorf("数据目录不存在: %s", c.Storage.DataPath)
}
return true, nil
}
// WatchConfig 监听配置文件变化并自动重载
func WatchConfig(filePath string) error {
// 创建文件监视器
w, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("创建文件监视器失败: %v", err)
}
// 添加文件监视
dir := filepath.Dir(filePath)
fileName := filepath.Base(filePath)
err = w.Add(dir)
if err != nil {
return fmt.Errorf("添加文件监视失败: %v", err)
}
// 启动监听协程
go func() {
for {
select {
case event, ok := <-w.Events:
if !ok {
return
}
// 当配置文件被修改或创建时重新加载
if (event.Op&fsnotify.Write == fsnotify.Write ||
event.Op&fsnotify.Create == fsnotify.Create) &&
filepath.Base(event.Name) == fileName {
// 重新加载配置
newCfg, err := ParseConfig(filePath)
if err != nil {
log.Printf("配置重载失败: %v", err)
continue
}
// 更新全局配置
SetGlobalConfig(newCfg)
log.Printf("配置已更新")
}
case err, ok := <-w.Errors:
if !ok {
return
}
log.Printf("文件监视错误: %v", err)
}
}
}()
watcher = w
return nil
}
// CloseWatcher 关闭配置文件监视器
func CloseWatcher() {
if watcher != nil {
watcher.Close()
}
}
// GetConfig 获取全局配置实例 // GetConfig 获取全局配置实例
func GetConfig() *Config { func GetConfig() *Config {
once.Do(func() { once.Do(func() {
@ -89,3 +175,10 @@ func GetConfig() *Config {
}) })
return configInstance return configInstance
} }
// SetGlobalConfig 设置全局配置实例
func SetGlobalConfig(cfg *Config) {
once.Do(func() {
configInstance = cfg
})
}

View File

@ -102,3 +102,82 @@ func TestParseConfig_FileNotFound(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "配置文件不存在") assert.Contains(t, err.Error(), "配置文件不存在")
} }
func TestValidateConfig_InvalidPort(t *testing.T) {
// 创建临时YAML配置文件
tmpFile, err := os.CreateTemp("", "*.yaml")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
// 写入无效端口配置
testConfig := `server:
host: 127.0.0.1
port: 1000
storage:
engine: rocksdb
dataPath: /test/data`
_, err = tmpFile.WriteString(testConfig)
assert.NoError(t, err)
tmpFile.Close()
// 测试解析功能
cfg, err := ParseConfig(tmpFile.Name())
assert.NoError(t, err)
assert.NotNil(t, cfg)
valid, err := cfg.Validate()
assert.False(t, valid)
assert.Contains(t, err.Error(), "端口必须在1024-65535之间")
}
func TestValidateConfig_InvalidStorageEngine(t *testing.T) {
// 创建临时YAML配置文件
tmpFile, err := os.CreateTemp("", "*.yaml")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
// 写入无效存储引擎配置
testConfig := `server:
host: 127.0.0.1
port: 8080
storage:
engine: invalid_engine
dataPath: /test/data`
_, err = tmpFile.WriteString(testConfig)
assert.NoError(t, err)
tmpFile.Close()
// 测试解析功能
cfg, err := ParseConfig(tmpFile.Name())
assert.NoError(t, err)
valid, err := cfg.Validate()
assert.False(t, valid)
assert.Contains(t, err.Error(), "不支持的存储引擎")
}
func TestValidateConfig_DataPathNotExists(t *testing.T) {
// 创建临时YAML配置文件
tmpFile, err := os.CreateTemp("", "*.yaml")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
// 写入无效数据路径配置
testConfig := `server:
host: 127.0.0.1
port: 8080
storage:
engine: rocksdb
dataPath: /invalid/path`
_, err = tmpFile.WriteString(testConfig)
assert.NoError(t, err)
tmpFile.Close()
// 测试解析功能
cfg, err := ParseConfig(tmpFile.Name())
assert.NoError(t, err)
valid, err := cfg.Validate()
assert.False(t, valid)
assert.Contains(t, err.Error(), "数据目录不存在")
}

2
go.mod
View File

@ -8,8 +8,10 @@ go 1.23
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect github.com/stretchr/testify v1.10.0 // indirect
golang.org/x/sys v0.13.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

4
go.sum
View File

@ -1,9 +1,13 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

56
main.go
View File

@ -3,12 +3,64 @@ package main
import ( import (
"fmt" "fmt"
"net" "net"
"os"
"os/signal"
"path/filepath"
"syscall"
"github.com/kingecg/goaidb/config"
"github.com/kingecg/goaidb/network" "github.com/kingecg/goaidb/network"
"github.com/kingecg/goaidb/storage" "github.com/kingecg/goaidb/storage"
) )
func getExeDir() string {
exePath, err := os.Executable()
if err != nil {
panic(err)
}
exeDir := filepath.Dir(exePath)
return exeDir
}
// 主程序入口 // 主程序入口
func main() { func main() {
// 加载配置文件(优先级:命令行参数 > 环境变量 > 默认配置)
configPath := getExeDir() + "/config.yaml"
if len(os.Args) > 1 {
configPath = os.Args[1]
} else if envPath := os.Getenv("GOAIDB_CONFIG"); envPath != "" {
configPath = envPath
}
// 解析配置文件
cfg, err := config.ParseConfig(configPath)
if err != nil {
fmt.Fprintf(os.Stderr, "配置加载失败: %v\n", err)
os.Exit(1)
}
// 设置全局配置
config.SetGlobalConfig(cfg)
// 启动配置文件监视
err = config.WatchConfig(configPath)
if err != nil {
fmt.Fprintf(os.Stderr, "配置监控启动失败: %v\n", err)
os.Exit(1)
}
// 设置信号处理(优雅关闭)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
fmt.Println("正在关闭...")
config.CloseWatcher()
os.Exit(0)
}()
// 初始化存储引擎(默认使用内存引擎) // 初始化存储引擎(默认使用内存引擎)
storageEngine, err := storage.NewMemoryEngine() storageEngine, err := storage.NewMemoryEngine()
if err != nil { if err != nil {
@ -18,9 +70,9 @@ func main() {
// 创建网络服务器 // 创建网络服务器
server := network.NewServer(storageEngine) server := network.NewServer(storageEngine)
conf := config.GetConfig()
// 启动服务 // 启动服务
listener, err := net.Listen("tcp", ":27017") listener, err := net.Listen("tcp", conf.Server.Host+":"+fmt.Sprintf("%d", conf.Server.Port))
if err != nil { if err != nil {
fmt.Printf("Failed to start server: %v\n", err) fmt.Printf("Failed to start server: %v\n", err)
return return