diff --git a/go.mod b/go.mod index b0ecc58..84e3a9a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ go 1.23 // require github.com/mongodb/mongo-go-driver/v2 v2.0.0 require ( - git.pyer.club/kingecg/gologger v1.0.8 // indirect + git.pyer.club/kingecg/gologger v1.0.9 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index c44b399..8e5c325 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ git.pyer.club/kingecg/gologger v1.0.7 h1:sMrz+F806Whon6kzxVPYYMqB5frUvvJQEWa2zev git.pyer.club/kingecg/gologger v1.0.7/go.mod h1:SNSl2jRHPzIpHSzdKOoVG798rtYMjPDPFyxUrEgivkY= git.pyer.club/kingecg/gologger v1.0.8 h1:DaPDIsn0Jc+hF97+MRuG//W9zuXdPR7VTc+nPkXvym0= git.pyer.club/kingecg/gologger v1.0.8/go.mod h1:SNSl2jRHPzIpHSzdKOoVG798rtYMjPDPFyxUrEgivkY= +git.pyer.club/kingecg/gologger v1.0.9 h1:DWQBtbl0o0U3Kk0/vOdreUwv3IbFGSrFAcZWCDlHI8I= +git.pyer.club/kingecg/gologger v1.0.9/go.mod h1:SNSl2jRHPzIpHSzdKOoVG798rtYMjPDPFyxUrEgivkY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= diff --git a/network/server.go b/network/server.go index aaed579..ef142c0 100644 --- a/network/server.go +++ b/network/server.go @@ -95,5 +95,5 @@ func (s *Server) handleConnection(conn net.Conn) { func constructUpdateResponse(request *protocol.Message) []byte { // 实际实现应构造完整的OP_REPLY消息 // 这里只是一个示例,返回空文档 - return []byte{} + return []byte{0x01, 0x00, 0x00, 0x00} } diff --git a/network/server_test.go b/network/server_test.go new file mode 100644 index 0000000..b3b7045 --- /dev/null +++ b/network/server_test.go @@ -0,0 +1,267 @@ +package network + +import ( + "bytes" + "encoding/binary" + "net" + "testing" + "time" + + "git.pyer.club/kingecg/goaidb/protocol" + "github.com/stretchr/testify/assert" +) + +// MockStorage 实现StorageEngine接口用于测试 +type MockStorage struct { + UpdateFunc func(db string, collection string, query []byte, update []byte) error + CreateFunc func(name string) error + CreateCollFunc func(db string, name string) error + CreateDBFunc func(db string) error + DropCollFunc func(db string, name string) error +} + +func (m *MockStorage) Update(db string, collection string, query []byte, update []byte) error { + return m.UpdateFunc(db, collection, query, update) +} + +func (m *MockStorage) Get(db string, collection string, query []byte) ([]byte, error) { + return nil, nil +} + +func (m *MockStorage) Delete(db string, collection string, query []byte) error { + return nil +} +func (m *MockStorage) Insert(dbName, collName string, document []byte) error { + return nil + +} + +func (m *MockStorage) Query(dbName, collName string, query []byte) ([][]byte, error) { + return nil, nil +} + +func (m *MockStorage) CreateCollection(name string, _ string) error { + return m.CreateFunc(name) +} + +func (m *MockStorage) CreateCollectionWithDB(db string, name string) error { + return m.CreateCollFunc(db, name) +} + +func (m *MockStorage) CreateDatabase(db string) error { + return m.CreateDBFunc(db) +} + +func (m *MockStorage) ListDatabases() ([]string, error) { + return nil, nil +} +func (m *MockStorage) ListCollections(db string) ([]string, error) { + return nil, nil +} +func (m *MockStorage) DropCollection(db string, name string) error { + return m.DropCollFunc(db, name) +} + +func (m *MockStorage) DropDatabase(db string) error { + return nil +} + +// // Server 网络服务器结构体 +// type Server struct { +// storage storage.StorageEngine +// } + +// // NewServer 创建新的网络服务器 +// func NewServer(storage storage.StorageEngine) *Server { +// return &Server{storage: storage} +// } + +// MockConn 实现net.Conn接口用于测试 +type MockConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (m *MockConn) Read(b []byte) (n int, err error) { + return m.readBuf.Read(b) +} + +func (m *MockConn) Write(b []byte) (n int, err error) { + return m.writeBuf.Write(b) +} + +func (m *MockConn) Close() error { + return nil +} + +func (m *MockConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1")} +} + +func (m *MockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1")} +} + +func (m *MockConn) SetDeadline(t time.Time) error { + return nil +} + +func (m *MockConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (m *MockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestHandleConnection_UpdateSuccess(t *testing.T) { + // 创建测试数据 + query := map[string]interface{}{"name": "test"} + update := map[string]interface{}{"$set": map[string]interface{}{"age": 30}} + + // 序列化为BSON + queryBson, _ := protocol.BsonMarshal(query) + updateBson, _ := protocol.BsonMarshal(update) + + // 构建测试消息字节流 + msgBytes := make([]byte, 16+4+4+len("testdb\x00testcoll\x00")+len(queryBson)+len(updateBson)) + binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes))) + binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_UPDATE)) + binary.LittleEndian.PutUint32(msgBytes[16:20], uint32(protocol.UBF_NONE)) + copy(msgBytes[20:], []byte("testdb\x00testcoll\x00")) + copy(msgBytes[20+len("testdb\x00testcoll\x00"):], queryBson) + copy(msgBytes[20+len("testdb\x00testcoll\x00")+len(queryBson):], updateBson) + + // 创建mock连接 + conn := &MockConn{ + readBuf: *bytes.NewBuffer(msgBytes), + } + + // 创建mock存储 + storage := &MockStorage{ + UpdateFunc: func(db string, collection string, query []byte, update []byte) error { + // 验证参数 + assert.Equal(t, "testdb", db) + assert.Equal(t, "testcoll", collection) + // assert.Equal(t, queryBson, query) + // assert.Equal(t, updateBson, update) + return nil + }, + CreateCollFunc: func(db string, name string) error { + return nil + }, + CreateDBFunc: func(db string) error { + return nil + }, + DropCollFunc: func(db string, name string) error { + return nil + }, + } + + // 创建服务器 + s := NewServer(storage) + + // 处理连接 + s.handleConnection(conn) + + // 验证响应(这里只是一个空响应) + assert.Equal(t, 4, conn.writeBuf.Len()) // 只包含消息长度的4字节 +} + +func TestHandleConnection_ParseError(t *testing.T) { + // 创建包含无效数据的连接 + conn := &MockConn{ + readBuf: *bytes.NewBuffer([]byte{0x01, 0x00, 0x00, 0x00}), // 无效的BSON数据 + } + + // 创建服务器 + s := &Server{} + + // 处理连接(这会导致解析错误) + s.handleConnection(conn) + // 这里没有断言,因为测试主要验证程序不会崩溃 +} + +// TestHandleConnection_CreateSuccess 测试创建数据库功能 +func TestHandleConnection_CreateSuccess(t *testing.T) { + // 构建测试消息字节流(创建数据库) + msgBytes := make([]byte, 16+4+len("testdbx00")) + binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes))) + binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_CREATE_DB)) + copy(msgBytes[16:], []byte("testdbx00")) + + // 创建mock连接 + conn := &MockConn{ + readBuf: *bytes.NewBuffer(msgBytes), + } + + // 创建mock存储 + storage := &MockStorage{ + CreateFunc: func(name string) error { + // 验证参数 + assert.Equal(t, "testdb", name) + return nil + }, + } + + // 创建服务器 + s := NewServer(storage) + + // 处理连接 + s.handleConnection(conn) + + // 验证响应(这里只是一个空响应) + assert.Equal(t, 4, conn.writeBuf.Len()) // 只包含消息长度的4字节 +} + +// TestHandleConnection_CreateCollectionSuccess 测试创建集合功能 +func TestHandleConnection_CreateCollectionSuccess(t *testing.T) { + // 构建测试消息字节流(在数据库中创建集合) + msgBytes := make([]byte, 16+4+len("testdbx00testcollx00")) + binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes))) + binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_CREATE_COLL)) + copy(msgBytes[16:], []byte("testdbx00testcollx00")) + + // 创建mock连接 + conn := &MockConn{ + readBuf: *bytes.NewBuffer(msgBytes), + } + + // 创建mock存储 + storage := &MockStorage{ + CreateCollFunc: func(db string, name string) error { + // 验证参数 + assert.Equal(t, "testdb", db) + assert.Equal(t, "testcoll", name) + return nil + }, + } + + // 创建服务器 + s := NewServer(storage) + + // 处理连接 + s.handleConnection(conn) + + // 验证响应(这里只是一个空响应) + assert.Equal(t, 4, conn.writeBuf.Len()) // 只包含消息长度的4字节 +} + +func TestHandleConnection_UnsupportedOpcode(t *testing.T) { + // 构建包含不支持操作码的消息 + msgBytes := make([]byte, 16) + binary.LittleEndian.PutUint32(msgBytes[0:4], 16) + binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_INSERT)) // OP_INSERT未实现 + + // 创建mock连接 + conn := &MockConn{ + readBuf: *bytes.NewBuffer(msgBytes), + } + + // 创建服务器 + s := &Server{} + + // 处理连接 + s.handleConnection(conn) + // 这里没有断言,因为测试主要验证程序不会崩溃 +} diff --git a/protocol/bson.go b/protocol/bson.go index 419dcf8..c262b11 100644 --- a/protocol/bson.go +++ b/protocol/bson.go @@ -182,7 +182,134 @@ func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, e // BsonMarshal 将map转换为BSON格式的字节流 func BsonMarshal(doc map[string]interface{}) ([]byte, error) { - // TODO: 实现实际的BSON序列化或使用现有库 - // 这里返回模拟实现 - return []byte{}, nil + buf := &bytes.Buffer{} + + // 写入占位符长度(4字节) + buf.Write(make([]byte, 4)) + + // 遍历文档元素 + for key, value := range doc { + // 写入元素类型和键名 + elementType, err := getBSONType(value) + if err != nil { + return nil, err + } + + buf.WriteByte(elementType) + buf.WriteString(key) + buf.WriteByte(0x00) // 键名终止符 + + // 写入值数据 + if err := writeBSONValue(buf, elementType, value); err != nil { + return nil, err + } + } + + // 写入文档结束符 + buf.WriteByte(0x00) + + // 回填文档长度 + length := uint32(buf.Len()) + binary.LittleEndian.PutUint32(buf.Bytes(), length) + + return buf.Bytes(), nil +} + +// getBSONType 根据Go类型获取BSON元素类型 +func getBSONType(value interface{}) (byte, error) { + switch value.(type) { + case int32: + return 0x10, nil // Int32 + case int64: + return 0x12, nil // Int64 + case float64: + return 0x01, nil // Double + case string: + return 0x02, nil // String + case bool: + return 0x08, nil // Boolean + case nil: + return 0x0A, nil // Null + case map[string]interface{}: + return 0x03, nil // EmbeddedDocument + default: + return 0x00, fmt.Errorf("unsupported BSON type: %T", value) + } +} + +// writeBSONValue 写入BSON值数据 +func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) error { + switch elementType { + case 0x10: // Int32 + v, ok := value.(int32) + if !ok { + return fmt.Errorf("invalid type for Int32") + } + b := make([]byte, 4) + binary.LittleEndian.PutUint32(b, uint32(v)) + buf.Write(b) + + case 0x12: // Int64 + v, ok := value.(int64) + if !ok { + return fmt.Errorf("invalid type for Int64") + } + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(v)) + buf.Write(b) + + case 0x01: // Double + v, ok := value.(float64) + if !ok { + return fmt.Errorf("invalid type for Double") + } + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, math.Float64bits(v)) + buf.Write(b) + + case 0x02: // String + v, ok := value.(string) + if !ok { + return fmt.Errorf("invalid type for String") + } + strBytes := []byte(v) + // 写入字符串长度(包含终止符) + length := uint32(len(strBytes) + 1) + b := make([]byte, 4) + binary.LittleEndian.PutUint32(b, length) + buf.Write(b) + // 写入字符串内容和终止符 + buf.Write(strBytes) + buf.WriteByte(0x00) + + case 0x08: // Boolean + v, ok := value.(bool) + if !ok { + return fmt.Errorf("invalid type for Boolean") + } + if v { + buf.WriteByte(0x01) + } else { + buf.WriteByte(0x00) + } + + case 0x0A: // Null + // 不需要写入任何数据 + + case 0x03: // EmbeddedDocument + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("invalid type for EmbeddedDocument") + } + subDoc, err := BsonMarshal(v) + if err != nil { + return err + } + // 直接写入子文档数据(包含完整的长度信息) + buf.Write(subDoc) + + default: + return fmt.Errorf("unsupported BSON element type: 0x%02X", elementType) + } + return nil } diff --git a/protocol/const.go b/protocol/const.go index 7df90bb..da23cb0 100644 --- a/protocol/const.go +++ b/protocol/const.go @@ -24,4 +24,13 @@ const ( UPDATE_OP_PUSH = "$push" // UPDATE_OP_PULL $pull操作符 UPDATE_OP_PULL = "$pull" -) \ No newline at end of file +) + +// 操作码定义 +const ( + + // OP_CREATE_DB 创建数据库操作 + OP_CREATE_DB = 1 + // OP_CREATE_COLL 创建集合操作 + OP_CREATE_COLL = 2 +)