实现config加载逻辑
This commit is contained in:
parent
0b3002e395
commit
a0826d5361
|
@ -6,10 +6,12 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"log"
|
||||
|
||||
"encoding/json"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// Config 系统配置结构体
|
||||
|
@ -33,6 +35,7 @@ type StorageConfig struct {
|
|||
var (
|
||||
configInstance *Config
|
||||
once sync.Once
|
||||
watcher *fsnotify.Watcher // 配置文件监视器
|
||||
)
|
||||
|
||||
// NewDefaultConfig 创建默认配置
|
||||
|
@ -82,6 +85,89 @@ func ParseConfig(filePath string) (*Config, error) {
|
|||
return &cfg, nil
|
||||
}
|
||||
|
||||
// ValidateConfig 验证配置有效性
|
||||
// 返回验证结果和第一个发现的错误(如果有)
|
||||
func (c *Config) Validate() (bool, error) {
|
||||
// 验证服务器配置
|
||||
if c.Server.Port < 1024 || c.Server.Port > 65535 {
|
||||
return false, fmt.Errorf("端口必须在1024-65535之间: %d", c.Server.Port)
|
||||
}
|
||||
|
||||
// 验证存储配置
|
||||
if c.Storage.Engine != "memory" && c.Storage.Engine != "rocksdb" {
|
||||
return false, fmt.Errorf("不支持的存储引擎: %s", c.Storage.Engine)
|
||||
}
|
||||
|
||||
// 验证数据路径是否存在
|
||||
if _, err := os.Stat(c.Storage.DataPath); os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("数据目录不存在: %s", c.Storage.DataPath)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// WatchConfig 监听配置文件变化并自动重载
|
||||
func WatchConfig(filePath string) error {
|
||||
// 创建文件监视器
|
||||
w, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建文件监视器失败: %v", err)
|
||||
}
|
||||
|
||||
// 添加文件监视
|
||||
dir := filepath.Dir(filePath)
|
||||
fileName := filepath.Base(filePath)
|
||||
|
||||
err = w.Add(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加文件监视失败: %v", err)
|
||||
}
|
||||
|
||||
// 启动监听协程
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-w.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// 当配置文件被修改或创建时重新加载
|
||||
if (event.Op&fsnotify.Write == fsnotify.Write ||
|
||||
event.Op&fsnotify.Create == fsnotify.Create) &&
|
||||
filepath.Base(event.Name) == fileName {
|
||||
|
||||
// 重新加载配置
|
||||
newCfg, err := ParseConfig(filePath)
|
||||
if err != nil {
|
||||
log.Printf("配置重载失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新全局配置
|
||||
SetGlobalConfig(newCfg)
|
||||
log.Printf("配置已更新")
|
||||
}
|
||||
|
||||
case err, ok := <-w.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("文件监视错误: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
watcher = w
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWatcher 关闭配置文件监视器
|
||||
func CloseWatcher() {
|
||||
if watcher != nil {
|
||||
watcher.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig 获取全局配置实例
|
||||
func GetConfig() *Config {
|
||||
once.Do(func() {
|
||||
|
@ -89,3 +175,10 @@ func GetConfig() *Config {
|
|||
})
|
||||
return configInstance
|
||||
}
|
||||
|
||||
// SetGlobalConfig 设置全局配置实例
|
||||
func SetGlobalConfig(cfg *Config) {
|
||||
once.Do(func() {
|
||||
configInstance = cfg
|
||||
})
|
||||
}
|
||||
|
|
|
@ -102,3 +102,82 @@ func TestParseConfig_FileNotFound(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "配置文件不存在")
|
||||
}
|
||||
|
||||
func TestValidateConfig_InvalidPort(t *testing.T) {
|
||||
// 创建临时YAML配置文件
|
||||
tmpFile, err := os.CreateTemp("", "*.yaml")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// 写入无效端口配置
|
||||
testConfig := `server:
|
||||
host: 127.0.0.1
|
||||
port: 1000
|
||||
storage:
|
||||
engine: rocksdb
|
||||
dataPath: /test/data`
|
||||
|
||||
_, err = tmpFile.WriteString(testConfig)
|
||||
assert.NoError(t, err)
|
||||
tmpFile.Close()
|
||||
|
||||
// 测试解析功能
|
||||
cfg, err := ParseConfig(tmpFile.Name())
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
valid, err := cfg.Validate()
|
||||
assert.False(t, valid)
|
||||
assert.Contains(t, err.Error(), "端口必须在1024-65535之间")
|
||||
}
|
||||
|
||||
func TestValidateConfig_InvalidStorageEngine(t *testing.T) {
|
||||
// 创建临时YAML配置文件
|
||||
tmpFile, err := os.CreateTemp("", "*.yaml")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// 写入无效存储引擎配置
|
||||
testConfig := `server:
|
||||
host: 127.0.0.1
|
||||
port: 8080
|
||||
storage:
|
||||
engine: invalid_engine
|
||||
dataPath: /test/data`
|
||||
|
||||
_, err = tmpFile.WriteString(testConfig)
|
||||
assert.NoError(t, err)
|
||||
tmpFile.Close()
|
||||
|
||||
// 测试解析功能
|
||||
cfg, err := ParseConfig(tmpFile.Name())
|
||||
assert.NoError(t, err)
|
||||
valid, err := cfg.Validate()
|
||||
assert.False(t, valid)
|
||||
assert.Contains(t, err.Error(), "不支持的存储引擎")
|
||||
}
|
||||
|
||||
func TestValidateConfig_DataPathNotExists(t *testing.T) {
|
||||
// 创建临时YAML配置文件
|
||||
tmpFile, err := os.CreateTemp("", "*.yaml")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// 写入无效数据路径配置
|
||||
testConfig := `server:
|
||||
host: 127.0.0.1
|
||||
port: 8080
|
||||
storage:
|
||||
engine: rocksdb
|
||||
dataPath: /invalid/path`
|
||||
|
||||
_, err = tmpFile.WriteString(testConfig)
|
||||
assert.NoError(t, err)
|
||||
tmpFile.Close()
|
||||
|
||||
// 测试解析功能
|
||||
cfg, err := ParseConfig(tmpFile.Name())
|
||||
assert.NoError(t, err)
|
||||
valid, err := cfg.Validate()
|
||||
assert.False(t, valid)
|
||||
assert.Contains(t, err.Error(), "数据目录不存在")
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -8,8 +8,10 @@ go 1.23
|
|||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/testify v1.10.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
4
go.sum
4
go.sum
|
@ -1,9 +1,13 @@
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
|
|
60
main.go
60
main.go
|
@ -3,12 +3,64 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
|
||||
"syscall"
|
||||
|
||||
"github.com/kingecg/goaidb/config"
|
||||
"github.com/kingecg/goaidb/network"
|
||||
"github.com/kingecg/goaidb/storage"
|
||||
)
|
||||
|
||||
func getExeDir() string {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
exeDir := filepath.Dir(exePath)
|
||||
return exeDir
|
||||
}
|
||||
|
||||
// 主程序入口
|
||||
func main() {
|
||||
// 加载配置文件(优先级:命令行参数 > 环境变量 > 默认配置)
|
||||
configPath := getExeDir() + "/config.yaml"
|
||||
if len(os.Args) > 1 {
|
||||
configPath = os.Args[1]
|
||||
} else if envPath := os.Getenv("GOAIDB_CONFIG"); envPath != "" {
|
||||
configPath = envPath
|
||||
}
|
||||
|
||||
// 解析配置文件
|
||||
cfg, err := config.ParseConfig(configPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "配置加载失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 设置全局配置
|
||||
config.SetGlobalConfig(cfg)
|
||||
|
||||
// 启动配置文件监视
|
||||
err = config.WatchConfig(configPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "配置监控启动失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 设置信号处理(优雅关闭)
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("正在关闭...")
|
||||
config.CloseWatcher()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
// 初始化存储引擎(默认使用内存引擎)
|
||||
storageEngine, err := storage.NewMemoryEngine()
|
||||
if err != nil {
|
||||
|
@ -18,14 +70,14 @@ func main() {
|
|||
|
||||
// 创建网络服务器
|
||||
server := network.NewServer(storageEngine)
|
||||
|
||||
conf := config.GetConfig()
|
||||
// 启动服务
|
||||
listener, err := net.Listen("tcp", ":27017")
|
||||
listener, err := net.Listen("tcp", conf.Server.Host+":"+fmt.Sprintf("%d", conf.Server.Port))
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to start server: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
fmt.Println("GoAIDB started on port 27017")
|
||||
server.Serve(listener)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue