// Source path: goaidb/protocol/parser_test.go

package protocol
import (
"testing"
)
func TestParseUpdate(t *testing.T) {
// 构造测试数据最小有效Update消息
data := []byte{
// Header
2025-06-06 22:54:22 +08:00
16, 0, 0, 0, // MessageLength (16 bytes)
1, 0, 0, 0, // RequestID
0, 0, 0, 0, // ResponseTo
209, 7, 0, 0, // OpCode (OP_UPDATE)
2025-06-06 22:12:24 +08:00
// UpdateFlags (Upsert=1)
1, 0, 0, 0,
2025-06-06 22:54:22 +08:00
2025-06-06 22:12:24 +08:00
// DatabaseName "test\x00collection\x00"
't', 'e', 's', 't', 0,
2025-06-06 22:54:22 +08:00
2025-06-06 22:12:24 +08:00
// CollectionName "coll\x00"
'c', 'o', 'l', 'l', 0,
2025-06-06 22:54:22 +08:00
2025-06-06 22:12:24 +08:00
// Query文档最小BSON文档
5, 0, 0, 0, 0, // BSON长度 + 空文档
2025-06-06 22:54:22 +08:00
2025-06-06 22:12:24 +08:00
// Update文档
5, 0, 0, 0, 0, // BSON长度 + 空文档
}
msg, err := ParseMessage(data)
if err != nil {
t.Fatalf("ParseMessage failed: %v", err)
}
if msg.OpCode != OP_UPDATE {
t.Errorf("Expected OP_UPDATE, got %d", msg.OpCode)
}
updateMsg, ok := msg.Body.(*UpdateMessage)
if !ok {
t.Fatalf("Expected UpdateMessage, got %T", msg.Body)
}
2025-06-06 22:54:22 +08:00
if updateMsg.Body.Flags != Upsert {
t.Errorf("Expected Upsert flag, got %v", updateMsg.Body.Flags)
2025-06-06 22:12:24 +08:00
}
2025-06-06 22:54:22 +08:00
if updateMsg.Body.DatabaseName != "test" {
t.Errorf("Expected database 'test', got '%s'", updateMsg.Body.DatabaseName)
2025-06-06 22:12:24 +08:00
}
2025-06-06 22:54:22 +08:00
if updateMsg.Body.CollName != "coll" {
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.Body.CollName)
2025-06-06 22:12:24 +08:00
}
}
func TestParseInvalidUpdate(t *testing.T) {
// 测试短数据包
2025-06-06 22:54:22 +08:00
shortData := []byte{1, 0, 0, 0} // 长度不足4字节
2025-06-06 22:12:24 +08:00
_, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值
if err == nil {
t.Error("Expected error for short data")
}
2025-06-06 22:54:22 +08:00
}
// TestUpdateMessage_Functionality applies an UpdateMessage whose UpdateSpec
// holds a $set on "name" and a $inc on "age" to a single document, then checks
// the document returned by Update.
func TestUpdateMessage_Functionality(t *testing.T) {
	// Document the update is applied to.
	testDoc := map[string]interface{}{
		"name": "Alice",
		"age":  int32(30),
	}

	// Only UpdateSpec is populated; every other Body field keeps its zero
	// value. Assigning the field directly avoids restating Body's anonymous
	// struct type, which would stop compiling if that type gained a field.
	updateMsg := &UpdateMessage{}
	updateMsg.Body.UpdateSpec = map[string]interface{}{
		"$set": map[string]interface{}{
			"name": "Bob",
		},
		"$inc": map[string]interface{}{
			"age": int32(1),
		},
	}

	// Run the update.
	updatedDoc, err := updateMsg.Update(testDoc)
	if err != nil {
		t.Fatalf("Update failed: %v", err)
	}

	// Verify the result. On failure, report the value stored in the map: the
	// type-asserted variable is the zero value whenever the assertion fails,
	// which would hide what Update actually produced.
	if name, ok := updatedDoc["name"].(string); !ok || name != "Bob" {
		t.Errorf("Expected name 'Bob', got '%v'", updatedDoc["name"])
	}
	if age, ok := updatedDoc["age"].(int32); !ok || age != 31 {
		t.Errorf("Expected age 31, got %v", updatedDoc["age"])
	}
}
// TestMultiUpdate applies one UpdateMessage, flagged MultiUpdate and carrying
// a $inc of 2 on "age", to each document of a small set and checks every
// resulting age.
func TestMultiUpdate(t *testing.T) {
	// Documents the update is applied to.
	docs := []map[string]interface{}{
		{"name": "Alice", "age": int32(30)},
		{"name": "Charlie", "age": int32(25)},
	}

	// Set only the fields this test needs. Assigning them directly avoids
	// restating Body's anonymous struct type, which would stop compiling if
	// that type gained a field.
	updateMsg := &UpdateMessage{}
	updateMsg.Body.Flags = MultiUpdate
	updateMsg.Body.UpdateSpec = map[string]interface{}{
		"$inc": map[string]interface{}{
			"age": int32(2),
		},
	}

	// Update is called once per document; the test drives the iteration
	// itself and stores each returned document back into the slice.
	for i := range docs {
		updatedDoc, err := updateMsg.Update(docs[i])
		if err != nil {
			t.Fatalf("Update failed for doc %d: %v", i, err)
		}
		docs[i] = updatedDoc
	}

	// Verify the result. On failure, report the value stored in the map: the
	// type-asserted variable is the zero value whenever the assertion fails,
	// which would hide what Update actually produced.
	expectedAges := []int32{32, 27}
	for i, doc := range docs {
		if age, ok := doc["age"].(int32); !ok || age != expectedAges[i] {
			t.Errorf("Doc %d: expected age %d, got %v", i, expectedAges[i], doc["age"])
		}
	}
}