168 lines
3.7 KiB
Go
168 lines
3.7 KiB
Go
package protocol
|
||
|
||
import (
|
||
"testing"
|
||
|
||
"git.pyer.club/kingecg/goaidb/log"
|
||
)
|
||
|
||
func TestParseQuery(t *testing.T) {
|
||
_, data, err := log.ReadOpMsgFromFile("/home/kingecg/code/goaidb/opbin.log")
|
||
if err != nil {
|
||
t.Fatalf("ReadOpMsgFromFile failed: %v", err)
|
||
}
|
||
msg, err := ParseMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("ParseMessage failed: %v", err)
|
||
}
|
||
log.Info(msg.Body)
|
||
}
|
||
|
||
func TestParseUpdate(t *testing.T) {
|
||
// 构造测试数据(最小有效Update消息)
|
||
data := []byte{
|
||
// Header
|
||
16, 0, 0, 0, // MessageLength (16 bytes)
|
||
1, 0, 0, 0, // RequestID
|
||
0, 0, 0, 0, // ResponseTo
|
||
209, 7, 0, 0, // OpCode (OP_UPDATE)
|
||
|
||
// UpdateFlags (Upsert=1)
|
||
1, 0, 0, 0,
|
||
|
||
// DatabaseName "test\x00collection\x00"
|
||
't', 'e', 's', 't', 0,
|
||
|
||
// CollectionName "coll\x00"
|
||
'c', 'o', 'l', 'l', 0,
|
||
|
||
// Query文档(最小BSON文档)
|
||
5, 0, 0, 0, 0, // BSON长度 + 空文档
|
||
|
||
// Update文档
|
||
5, 0, 0, 0, 0, // BSON长度 + 空文档
|
||
}
|
||
|
||
msg, err := ParseMessage(data)
|
||
if err != nil {
|
||
t.Fatalf("ParseMessage failed: %v", err)
|
||
}
|
||
|
||
if msg.OpCode != OP_UPDATE {
|
||
t.Errorf("Expected OP_UPDATE, got %d", msg.OpCode)
|
||
}
|
||
|
||
updateMsg, ok := msg.Body.(*UpdateMessage)
|
||
if !ok {
|
||
t.Fatalf("Expected UpdateMessage, got %T", msg.Body)
|
||
}
|
||
|
||
if updateMsg.Body.Flags != Upsert {
|
||
t.Errorf("Expected Upsert flag, got %v", updateMsg.Body.Flags)
|
||
}
|
||
|
||
if updateMsg.Body.DatabaseName != "test" {
|
||
t.Errorf("Expected database 'test', got '%s'", updateMsg.Body.DatabaseName)
|
||
}
|
||
|
||
if updateMsg.Body.CollName != "coll" {
|
||
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.Body.CollName)
|
||
}
|
||
}
|
||
|
||
func TestParseInvalidUpdate(t *testing.T) {
|
||
// 测试短数据包
|
||
shortData := []byte{1, 0, 0, 0} // 长度不足4字节
|
||
_, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值
|
||
if err == nil {
|
||
t.Error("Expected error for short data")
|
||
}
|
||
}
|
||
|
||
func TestUpdateMessage_Functionality(t *testing.T) {
|
||
// 创建测试文档
|
||
testDoc := map[string]interface{}{
|
||
"name": "Alice",
|
||
"age": int32(30),
|
||
}
|
||
|
||
// 创建UpdateMessage实例
|
||
updateMsg := &UpdateMessage{
|
||
Body: struct {
|
||
Flags UpdateFlags
|
||
DatabaseName string
|
||
CollName string
|
||
Query map[string]interface{}
|
||
UpdateSpec map[string]interface{}
|
||
}{
|
||
UpdateSpec: map[string]interface{}{
|
||
"$set": map[string]interface{}{
|
||
"name": "Bob",
|
||
},
|
||
"$inc": map[string]interface{}{
|
||
"age": int32(1),
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
// 执行更新操作
|
||
updatedDoc, err := updateMsg.Update(testDoc)
|
||
if err != nil {
|
||
t.Fatalf("Update failed: %v", err)
|
||
}
|
||
|
||
// 验证结果
|
||
if name, ok := updatedDoc["name"].(string); !ok || name != "Bob" {
|
||
t.Errorf("Expected name 'Bob', got '%v'", name)
|
||
}
|
||
|
||
if age, ok := updatedDoc["age"].(int32); !ok || age != 31 {
|
||
t.Errorf("Expected age 31, got %v", age)
|
||
}
|
||
}
|
||
|
||
func TestMultiUpdate(t *testing.T) {
|
||
// 创建测试文档集
|
||
docs := []map[string]interface{}{
|
||
{"name": "Alice", "age": int32(30)},
|
||
{"name": "Charlie", "age": int32(25)},
|
||
}
|
||
|
||
// 创建支持多更新的UpdateMessage
|
||
updateMsg := &UpdateMessage{
|
||
Body: struct {
|
||
Flags UpdateFlags
|
||
DatabaseName string
|
||
CollName string
|
||
Query map[string]interface{}
|
||
UpdateSpec map[string]interface{}
|
||
}{
|
||
Flags: MultiUpdate,
|
||
UpdateSpec: map[string]interface{}{
|
||
"$inc": map[string]interface{}{
|
||
"age": int32(2),
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
// 执行多文档更新
|
||
for i := range docs {
|
||
updatedDoc, err := updateMsg.Update(docs[i])
|
||
if err != nil {
|
||
t.Fatalf("Update failed for doc %d: %v", i, err)
|
||
}
|
||
|
||
docs[i] = updatedDoc
|
||
}
|
||
|
||
// 验证结果
|
||
expectedAges := []int32{32, 27}
|
||
for i, doc := range docs {
|
||
if age, ok := doc["age"].(int32); !ok || age != expectedAges[i] {
|
||
t.Errorf("Doc %d: expected age %d, got %v", i, expectedAges[i], age)
|
||
}
|
||
}
|
||
}
|