// goaidb/network/server_test.go
// (Header copied from a repository web view: 268 lines, 7.0 KiB, Go.)

package network
import (
"bytes"
"encoding/binary"
"net"
"testing"
"time"
"git.pyer.club/kingecg/goaidb/protocol"
"github.com/stretchr/testify/assert"
)
// MockStorage is a hand-rolled test double for the storage engine used by
// Server (intended to satisfy the StorageEngine interface, which is declared
// outside this file).
//
// The *Func fields are optional hooks:
//   - UpdateFunc backs Update.
//   - CreateFunc backs CreateCollection and receives only its first argument.
//   - CreateCollFunc backs CreateCollectionWithDB.
//   - CreateDBFunc backs CreateDatabase.
//   - DropCollFunc backs DropCollection.
//
// When a hook is set, the method forwards its arguments to it and returns the
// hook's error. When a hook is nil, the method is a successful no-op, so a
// test only needs to stub the calls it cares about instead of crashing with a
// nil-function panic. Get, Delete, Insert, Query, ListDatabases,
// ListCollections and DropDatabase have no hooks and always return zero
// values with a nil error.
type MockStorage struct {
	UpdateFunc     func(db string, collection string, query []byte, update []byte) error
	CreateFunc     func(name string) error
	CreateCollFunc func(db string, name string) error
	CreateDBFunc   func(db string) error
	DropCollFunc   func(db string, name string) error
}

// Update forwards to UpdateFunc, or succeeds if it is unset.
func (m *MockStorage) Update(db string, collection string, query []byte, update []byte) error {
	if m.UpdateFunc == nil {
		return nil
	}
	return m.UpdateFunc(db, collection, query, update)
}

// Get always returns a nil document and a nil error.
func (m *MockStorage) Get(db string, collection string, query []byte) ([]byte, error) {
	return nil, nil
}

// Delete always succeeds.
func (m *MockStorage) Delete(db string, collection string, query []byte) error {
	return nil
}

// Insert always succeeds.
func (m *MockStorage) Insert(dbName, collName string, document []byte) error {
	return nil
}

// Query always returns no documents and a nil error.
func (m *MockStorage) Query(dbName, collName string, query []byte) ([][]byte, error) {
	return nil, nil
}

// CreateCollection forwards only its first argument to CreateFunc, or
// succeeds if it is unset. The second argument is ignored.
func (m *MockStorage) CreateCollection(name string, _ string) error {
	if m.CreateFunc == nil {
		return nil
	}
	return m.CreateFunc(name)
}

// CreateCollectionWithDB forwards to CreateCollFunc, or succeeds if it is unset.
func (m *MockStorage) CreateCollectionWithDB(db string, name string) error {
	if m.CreateCollFunc == nil {
		return nil
	}
	return m.CreateCollFunc(db, name)
}

// CreateDatabase forwards to CreateDBFunc, or succeeds if it is unset.
func (m *MockStorage) CreateDatabase(db string) error {
	if m.CreateDBFunc == nil {
		return nil
	}
	return m.CreateDBFunc(db)
}

// ListDatabases always returns no names and a nil error.
func (m *MockStorage) ListDatabases() ([]string, error) {
	return nil, nil
}

// ListCollections always returns no names and a nil error.
func (m *MockStorage) ListCollections(db string) ([]string, error) {
	return nil, nil
}

// DropCollection forwards to DropCollFunc, or succeeds if it is unset.
func (m *MockStorage) DropCollection(db string, name string) error {
	if m.DropCollFunc == nil {
		return nil
	}
	return m.DropCollFunc(db, name)
}

// DropDatabase always succeeds.
func (m *MockStorage) DropDatabase(db string) error {
	return nil
}
// // Server 网络服务器结构体
// type Server struct {
// storage storage.StorageEngine
// }
// // NewServer 创建新的网络服务器
// func NewServer(storage storage.StorageEngine) *Server {
// return &Server{storage: storage}
// }
// MockConn is an in-memory net.Conn for driving handleConnection in tests.
//
// Read consumes the scripted client input stored in readBuf. Because it
// delegates to bytes.Buffer.Read, it returns io.EOF once readBuf is drained
// (for a non-empty read slice). Write appends to writeBuf so a test can
// inspect what the server sent back. Close and the deadline setters are
// no-ops that always succeed. Both LocalAddr and RemoteAddr report
// 127.0.0.1 with port 0.
type MockConn struct {
	readBuf  bytes.Buffer // scripted client input consumed by Read
	writeBuf bytes.Buffer // server output captured by Write
}

// Compile-time check that MockConn really satisfies net.Conn.
var _ net.Conn = (*MockConn)(nil)

// Read reads from readBuf.
func (m *MockConn) Read(b []byte) (n int, err error) {
	return m.readBuf.Read(b)
}

// Write appends b to writeBuf.
func (m *MockConn) Write(b []byte) (n int, err error) {
	return m.writeBuf.Write(b)
}

// Close does nothing and always succeeds.
func (m *MockConn) Close() error {
	return nil
}

// LocalAddr returns 127.0.0.1 with port 0.
func (m *MockConn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.ParseIP("127.0.0.1")}
}

// RemoteAddr returns 127.0.0.1 with port 0.
func (m *MockConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.ParseIP("127.0.0.1")}
}

// SetDeadline ignores t and always succeeds.
func (m *MockConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores t and always succeeds.
func (m *MockConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores t and always succeeds.
func (m *MockConn) SetWriteDeadline(t time.Time) error {
	return nil
}
// TestHandleConnection_UpdateSuccess drives handleConnection with a single
// OP_UPDATE message. It checks that the storage Update hook is invoked with
// the database and collection names the message carries.
func TestHandleConnection_UpdateSuccess(t *testing.T) {
	// Filter and update documents for the request.
	query := map[string]interface{}{"name": "test"}
	update := map[string]interface{}{"$set": map[string]interface{}{"age": 30}}

	// Encode both as BSON. Fail fast on error instead of silently sending
	// empty documents, which would make the rest of the test meaningless.
	queryBson, err := protocol.BsonMarshal(query)
	if err != nil {
		t.Fatalf("marshal query: %v", err)
	}
	updateBson, err := protocol.BsonMarshal(update)
	if err != nil {
		t.Fatalf("marshal update: %v", err)
	}

	// Message layout written below (little-endian integers):
	//   [0:4]   total message length
	//   [4:12]  left zero
	//   [12:16] opcode OP_UPDATE
	//   [16:20] update flags (UBF_NONE)
	//   [20:]   "testdb\x00testcoll\x00", then query BSON, then update BSON
	// NOTE(review): the buffer is sized 16+4+4+payload, which is 4 bytes more
	// than is written, so the message ends with four zero bytes after the
	// update document. Kept as-is; confirm whether the server's OP_UPDATE
	// parser expects them.
	const names = "testdb\x00testcoll\x00"
	msgBytes := make([]byte, 16+4+4+len(names)+len(queryBson)+len(updateBson))
	binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes)))
	binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_UPDATE))
	binary.LittleEndian.PutUint32(msgBytes[16:20], uint32(protocol.UBF_NONE))
	off := 20
	off += copy(msgBytes[off:], names)
	off += copy(msgBytes[off:], queryBson)
	copy(msgBytes[off:], updateBson)

	conn := &MockConn{readBuf: *bytes.NewBuffer(msgBytes)}

	// Record the call so the test fails if Update is never reached; without
	// this the assertions inside the hook would pass vacuously.
	updateCalled := false
	store := &MockStorage{
		UpdateFunc: func(db string, collection string, query []byte, update []byte) error {
			updateCalled = true
			assert.Equal(t, "testdb", db)
			assert.Equal(t, "testcoll", collection)
			// Byte-for-byte checks of query/update are disabled.
			// NOTE(review): possibly related to the trailing padding noted above.
			// assert.Equal(t, queryBson, query)
			// assert.Equal(t, updateBson, update)
			return nil
		},
		// The remaining hooks are stubbed as successful no-ops.
		CreateCollFunc: func(db string, name string) error { return nil },
		CreateDBFunc:   func(db string) error { return nil },
		DropCollFunc:   func(db string, name string) error { return nil },
	}

	s := NewServer(store)
	s.handleConnection(conn)

	assert.True(t, updateCalled, "handleConnection did not call storage Update")
	// The reply is expected to contain only a 4-byte message-length field.
	assert.Equal(t, 4, conn.writeBuf.Len())
}
// TestHandleConnection_ParseError feeds handleConnection a truncated,
// malformed message and checks that the server does not crash on it.
func TestHandleConnection_ParseError(t *testing.T) {
	// The only input is a 4-byte little-endian length prefix declaring a
	// total message length of 1. That is shorter than the 16-byte header the
	// other tests build, and no further bytes follow.
	conn := &MockConn{
		readBuf: *bytes.NewBuffer([]byte{0x01, 0x00, 0x00, 0x00}),
	}

	// Zero-value Server: no storage is configured.
	s := &Server{}

	// The test's only claim is that bad input does not crash the server.
	// NotPanics turns a panic into an ordinary assertion failure instead of
	// aborting the whole test binary.
	assert.NotPanics(t, func() { s.handleConnection(conn) })
}
// TestHandleConnection_CreateSuccess sends an OP_CREATE_DB message. It checks
// that the database name the message carries reaches the storage layer.
func TestHandleConnection_CreateSuccess(t *testing.T) {
	// NUL-terminated database name. The terminator must be written as
	// "\x00"; a literal "x00" is three ordinary characters and would leave
	// the name unterminated and spelled "testdbx00".
	const dbName = "testdb\x00"

	// Layout: 16-byte header (length at [0:4], opcode at [12:16]) followed
	// directly by the name at offset 16.
	// NOTE(review): the buffer reserves 4 more bytes than are written, so the
	// message ends with four zero bytes. Confirm against the server's
	// OP_CREATE_DB parser.
	msgBytes := make([]byte, 16+4+len(dbName))
	binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes)))
	binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_CREATE_DB))
	copy(msgBytes[16:], dbName)

	conn := &MockConn{readBuf: *bytes.NewBuffer(msgBytes)}

	// The server may route OP_CREATE_DB through either CreateCollection
	// (CreateFunc) or CreateDatabase (CreateDBFunc). Stub both with the same
	// check so neither path hits a nil hook, and record that one of them ran.
	created := false
	checkDB := func(name string) error {
		created = true
		assert.Equal(t, "testdb", name)
		return nil
	}
	store := &MockStorage{
		CreateFunc:   checkDB,
		CreateDBFunc: checkDB,
	}

	s := NewServer(store)
	s.handleConnection(conn)

	assert.True(t, created, "handleConnection did not reach a storage create hook")
	// The reply is expected to contain only a 4-byte message-length field.
	assert.Equal(t, 4, conn.writeBuf.Len())
}
// TestHandleConnection_CreateCollectionSuccess sends an OP_CREATE_COLL
// message. It checks that the database and collection names the message
// carries reach the storage layer.
func TestHandleConnection_CreateCollectionSuccess(t *testing.T) {
	// Two NUL-terminated names. NUL must be written as "\x00"; a literal
	// "x00" is three ordinary characters and leaves the names unterminated.
	const names = "testdb\x00testcoll\x00"

	// Layout: 16-byte header (length at [0:4], opcode at [12:16]) followed
	// directly by the names at offset 16.
	// NOTE(review): the buffer reserves 4 more bytes than are written, so the
	// message ends with four zero bytes. Confirm against the server's
	// OP_CREATE_COLL parser.
	msgBytes := make([]byte, 16+4+len(names))
	binary.LittleEndian.PutUint32(msgBytes[0:4], uint32(len(msgBytes)))
	binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_CREATE_COLL))
	copy(msgBytes[16:], names)

	conn := &MockConn{readBuf: *bytes.NewBuffer(msgBytes)}

	// Record the call so the test fails if the hook is never reached;
	// otherwise the assertions inside it would pass vacuously.
	collCreated := false
	store := &MockStorage{
		CreateCollFunc: func(db string, name string) error {
			collCreated = true
			assert.Equal(t, "testdb", db)
			assert.Equal(t, "testcoll", name)
			return nil
		},
	}

	s := NewServer(store)
	s.handleConnection(conn)

	assert.True(t, collCreated, "handleConnection did not call storage CreateCollectionWithDB")
	// The reply is expected to contain only a 4-byte message-length field.
	assert.Equal(t, 4, conn.writeBuf.Len())
}
// TestHandleConnection_UnsupportedOpcode sends a header-only message whose
// opcode the server does not handle. It checks that the server does not
// crash on it.
func TestHandleConnection_UnsupportedOpcode(t *testing.T) {
	// A bare 16-byte header: length 16, opcode OP_INSERT (not implemented by
	// the server), no body.
	msgBytes := make([]byte, 16)
	binary.LittleEndian.PutUint32(msgBytes[0:4], 16)
	binary.LittleEndian.PutUint32(msgBytes[12:16], uint32(protocol.OP_INSERT))

	conn := &MockConn{
		readBuf: *bytes.NewBuffer(msgBytes),
	}

	// Zero-value Server: no storage is configured.
	s := &Server{}

	// NotPanics turns a panic into an ordinary assertion failure instead of
	// aborting the whole test binary.
	assert.NotPanics(t, func() { s.handleConnection(conn) })
}