package network import ( "bytes" "encoding/binary" "net" "testing" "time" "git.pyer.club/kingecg/goaidb/protocol" "github.com/stretchr/testify/assert" ) // MockStorage 实现StorageEngine接口用于测试 type MockStorage struct { UpdateFunc func(db string, collection string, query []byte, update []byte) error CreateFunc func(name string) error CreateCollFunc func(db string, name string) error CreateDBFunc func(db string) error DropCollFunc func(db string, name string) error } func (m *MockStorage) Update(db string, collection string, query []byte, update []byte) error { return m.UpdateFunc(db, collection, query, update) } func (m *MockStorage) Get(db string, collection string, query []byte) ([]byte, error) { return nil, nil } func (m *MockStorage) Delete(db string, collection string, query []byte) error { return nil } func (m *MockStorage) Insert(dbName, collName string, document []byte) error { return nil } func (m *MockStorage) Query(dbName, collName string, query []byte) ([][]byte, error) { return nil, nil } func (m *MockStorage) CreateCollection(name string, _ string) error { return m.CreateFunc(name) } func (m *MockStorage) CreateCollectionWithDB(db string, name string) error { return m.CreateCollFunc(db, name) } func (m *MockStorage) CreateDatabase(db string) error { return m.CreateDBFunc(db) } func (m *MockStorage) ListDatabases() ([]string, error) { return nil, nil } func (m *MockStorage) ListCollections(db string) ([]string, error) { return nil, nil } func (m *MockStorage) DropCollection(db string, name string) error { return m.DropCollFunc(db, name) } func (m *MockStorage) DropDatabase(db string) error { return nil } // // Server 网络服务器结构体 // type Server struct { // storage storage.StorageEngine // } // // NewServer 创建新的网络服务器 // func NewServer(storage storage.StorageEngine) *Server { // return &Server{storage: storage} // } // MockConn 实现net.Conn接口用于测试 type MockConn struct { readBuf bytes.Buffer writeBuf bytes.Buffer } func (m *MockConn) Read(b []byte) (n int, err error) { return m.readBuf.Read(b) } func (m *MockConn) Write(b []byte) (n int, err error) { return m.writeBuf.Write(b) } func (m *MockConn) Close() error { return nil } func (m *MockConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.ParseIP("127.0.0.1")} } func (m *MockConn) RemoteAddr() net.Addr { return &net.TCPAddr{IP: net.ParseIP("127.0.0.1")} } func (m *MockConn) SetDeadline(t time.Time) error { return nil } func (m *MockConn) SetReadDeadline(t time.Time) error { return nil } func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil } func TestHandleConnection_UpdateSuccess(t *testing.T) { // 创建测试数据 query := map[string]interface{}{"name": "test"} update := map[string]interface{}{"$set": map[string]interface{}{"age": 30}} // 序列化为BSON queryBson, _ := protocol.BsonMarshal(query) updateBson, _ := protocol.BsonMarshal(update) // 构建测试消息字节流 msgBytes := make([]byte, 16+4+4+len("testdb\x00testcoll\x00")+len(queryBson)+len(updateBson)) binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes))) binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_UPDATE)) binary.LittleEndian.PutUint32(msgBytes[16:20], uint32(protocol.UBF_NONE)) copy(msgBytes[20:], []byte("testdb\x00testcoll\x00")) copy(msgBytes[20+len("testdb\x00testcoll\x00"):], queryBson) copy(msgBytes[20+len("testdb\x00testcoll\x00")+len(queryBson):], updateBson) // 创建mock连接 conn := &MockConn{ readBuf: *bytes.NewBuffer(msgBytes), } // 创建mock存储 storage := &MockStorage{ UpdateFunc: func(db string, collection string, query []byte, update []byte) error { // 验证参数 assert.Equal(t, "testdb", db) assert.Equal(t, "testcoll", collection) // assert.Equal(t, queryBson, query) // assert.Equal(t, updateBson, update) return nil }, CreateCollFunc: func(db string, name string) error { return nil }, CreateDBFunc: func(db string) error { return nil }, DropCollFunc: func(db string, name string) error { return nil }, } // 创建服务器 s := NewServer(storage) // 处理连接 s.handleConnection(conn) // 验证响应(这里只是一个空响应) assert.Equal(t, 4, conn.writeBuf.Len()) // 只包含消息长度的4字节 } func TestHandleConnection_ParseError(t *testing.T) { // 创建包含无效数据的连接 conn := &MockConn{ readBuf: *bytes.NewBuffer([]byte{0x01, 0x00, 0x00, 0x00}), // 无效的BSON数据 } // 创建服务器 s := &Server{} // 处理连接(这会导致解析错误) s.handleConnection(conn) // 这里没有断言,因为测试主要验证程序不会崩溃 } // TestHandleConnection_CreateSuccess 测试创建数据库功能 func TestHandleConnection_CreateSuccess(t *testing.T) { // 构建测试消息字节流(创建数据库) msgBytes := make([]byte, 16+4+len("testdbx00")) binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes))) binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_CREATE_DB)) copy(msgBytes[16:], []byte("testdbx00")) // 创建mock连接 conn := &MockConn{ readBuf: *bytes.NewBuffer(msgBytes), } // 创建mock存储 storage := &MockStorage{ CreateFunc: func(name string) error { // 验证参数 assert.Equal(t, "testdb", name) return nil }, } // 创建服务器 s := NewServer(storage) // 处理连接 s.handleConnection(conn) // 验证响应(这里只是一个空响应) assert.Equal(t, 4, conn.writeBuf.Len()) // 只包含消息长度的4字节 } // TestHandleConnection_CreateCollectionSuccess 测试创建集合功能 func TestHandleConnection_CreateCollectionSuccess(t *testing.T) { // 构建测试消息字节流(在数据库中创建集合) msgBytes := make([]byte, 16+4+len("testdbx00testcollx00")) binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes))) binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_CREATE_COLL)) copy(msgBytes[16:], []byte("testdbx00testcollx00")) // 创建mock连接 conn := &MockConn{ readBuf: *bytes.NewBuffer(msgBytes), } // 创建mock存储 storage := &MockStorage{ CreateCollFunc: func(db string, name string) error { // 验证参数 assert.Equal(t, "testdb", db) assert.Equal(t, "testcoll", name) return nil }, } // 创建服务器 s := NewServer(storage) // 处理连接 s.handleConnection(conn) // 验证响应(这里只是一个空响应) assert.Equal(t, 4, conn.writeBuf.Len()) // 只包含消息长度的4字节 } func TestHandleConnection_UnsupportedOpcode(t *testing.T) { // 构建包含不支持操作码的消息 msgBytes := make([]byte, 16) binary.LittleEndian.PutUint32(msgBytes[0:4], 16) binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_INSERT)) // OP_INSERT未实现 // 创建mock连接 conn := &MockConn{ readBuf: *bytes.NewBuffer(msgBytes), } // 创建服务器 s := &Server{} // 处理连接 s.handleConnection(conn) // 这里没有断言,因为测试主要验证程序不会崩溃 }