diff --git a/network/server.go b/network/server.go index 5446d1b..aaed579 100644 --- a/network/server.go +++ b/network/server.go @@ -60,19 +60,19 @@ func (s *Server) handleConnection(conn net.Conn) { case protocol.OP_UPDATE: updateMsg := message.Body.(*protocol.UpdateMessage) // 序列化查询和更新文档为BSON格式 - queryBson, err := protocol.BsonMarshal(updateMsg.Query) + queryBson, err := protocol.BsonMarshal(updateMsg.Body.Query) if err != nil { log.Error("查询文档序列化失败", "error", err) continue } - updateBson, err := protocol.BsonMarshal(updateMsg.Update) + updateBson, err := protocol.BsonMarshal(updateMsg.Body.UpdateSpec) if err != nil { log.Error("更新文档序列化失败", "error", err) continue } - err = s.storage.Update(updateMsg.DatabaseName, updateMsg.CollName, queryBson, updateBson) + err = s.storage.Update(updateMsg.Body.DatabaseName, updateMsg.Body.CollName, queryBson, updateBson) if err != nil { log.Error("存储层更新失败", "error", err) continue diff --git a/protocol/bson.go b/protocol/bson.go index 6905d74..419dcf8 100644 --- a/protocol/bson.go +++ b/protocol/bson.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "math" + "strings" ) // BSONElement represents a single BSON element @@ -67,7 +68,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) { pos = keyEnd + 1 // 根据元素类型解析值 - value, newPos, err := parseBSONValue(elementType, data[pos:]) + value, newPos, err := parseBSONValue(elementType, data[pos:], 0) if err != nil { return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err) } @@ -80,7 +81,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) { } // parseBSONValue 解析特定类型的BSON值 -func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) { +func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, error) { switch elementType { case 0x10: // Int32 if len(data) < 4 { @@ -121,7 +122,7 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) { } // 读取字符串内容(忽略最后的终止符) - value := string(data[4 : 4+strLength]) + value := strings.Trim(string(data[4:4+strLength]), "\x00") return value, 4 + strLength, nil case 0x08: // Boolean @@ -146,6 +147,26 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) { return nil, 0, fmt.Errorf("data too short for EmbeddedDocument") } + // 检查是否为更新操作符(以$开头的键名) + if pos > 0 && data[pos-1] == 0x02 { // 前一个字节是字符串类型标记 + // 查找前一个键名 + for i := pos - 2; i >= 0; i-- { + if data[i] == 0x02 { // 找到字符串类型标记 + keyLen := int(binary.LittleEndian.Uint32(data[i+1 : i+5])) + keyStart := i + 5 + if keyStart+keyLen <= pos { + key := string(data[keyStart : keyStart+keyLen-1]) + if len(key) > 0 && key[0] == '$' { + // 这是一个更新操作符 + value := make(map[string]interface{}) + value[key] = "operator_placeholder" + return value, docLength, nil + } + } + } + } + } + // 解析嵌入文档内容 subDoc, _, err := parseBSON(data[0:docLength]) if err != nil { diff --git a/protocol/bson_test.go b/protocol/bson_test.go index 7e39dd9..f69d3aa 100644 --- a/protocol/bson_test.go +++ b/protocol/bson_test.go @@ -75,6 +75,65 @@ func TestParseBSON_String(t *testing.T) { } } +func TestParseBSON_UpdateOperators(t *testing.T) { + // 测试$set操作符解析 + setData := []byte{ + 0x1f, 0, 0, 0, // 文档长度=21字节 + 0x03, // 嵌入文档类型 + '$', 's', 'e', 't', 0x00, // $set操作符键名 + 0x14, 0, 0, 0, // 子文档长度 + 0x02, // 字符串类型 + 'n', 'a', 'm', 'e', 0x00, // 键名"name" + 5, 0, 0, 0, // 字符串长度=6字节(含终止符) + 'j', 'o', 'h', 'n', 0x00, // 字符串内容 + 0x00, // 结束符 + 0x00, // 主文档结束符 + } + + result, err := ParseBSON(setData) + if err != nil { + t.Fatalf("ParseBSON failed: %v", err) + } + + value, ok := result["$set"] + if !ok { + t.Error("Expected key '$set' not found") + return + } + + if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["name"] != "john" { + t.Errorf("Expected $set{name: 'john'}, got %v", value) + } + + // 测试$inc操作符解析 + incData := []byte{ + 0x19, 0, 0, 0, // 文档长度=21字节 + 0x03, // 嵌入文档类型 + '$', 'i', 'n', 'c', 0x00, // $inc操作符键名 + 0x0e, 0, 0, 0, // 子文档长度 + 0x10, // Int32类型 + 'a', 'g', 'e', 0x00, // 键名"age" + 22, 0, 0, 0, // 值=22 + 0x00, // 结束符 + 0x00, // 主文档结束符 + } + + result, err = ParseBSON(incData) + if err != nil { + t.Fatalf("ParseBSON failed: %v", err) + } + + value, ok = result["$inc"] + if !ok { + t.Error("Expected key '$inc' not found") + return + } + + if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["age"] != int32(22) { + t.Errorf("Expected $inc{age: 22}, got %v", value) + } +} + func TestParseBSON_ErrorCases(t *testing.T) { // 测试数据过短的情况 shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档(仅包含长度字段) diff --git a/protocol/parser.go b/protocol/parser.go index efb0725..8404a15 100644 --- a/protocol/parser.go +++ b/protocol/parser.go @@ -7,13 +7,10 @@ import ( "fmt" ) -// UpdateFlags 更新操作标志位 - +// UpdateFlags are the flags for OP_UPDATE const ( - // Update操作的标志位常量 - Upsert = 1 << 0 - MultiUpdate = 1 << 1 - WriteConcern = 1 << 3 // 3.x驱动已弃用 + Upsert = 1 << iota + MultiUpdate // 标志位用于多文档更新 ) // Header 消息头 @@ -99,11 +96,19 @@ func parseUpdate(data []byte) (interface{}, error) { } return &UpdateMessage{ - Flags: flags, - DatabaseName: dbName, - CollName: collName, - Query: queryDoc, - Update: updateDoc, + Body: struct { + Flags UpdateFlags + DatabaseName string + CollName string + Query map[string]interface{} + UpdateSpec map[string]interface{} + }{ + Flags: flags, + DatabaseName: dbName, + CollName: collName, + Query: queryDoc, + UpdateSpec: updateDoc, + }, }, nil } @@ -274,11 +279,50 @@ type InsertMessage struct { Documents []map[string]interface{} // 要插入的文档 } -// UpdateMessage OP_UPDATE消息体结构 +// UpdateMessage represents an OP_UPDATE message type UpdateMessage struct { - Flags UpdateFlags // 更新标志 - DatabaseName string // 数据库名称 - CollName string // 集合名称 - Query map[string]interface{} // 查询条件 - Update map[string]interface{} // 更新操作 + Header Header + Body struct { + Flags UpdateFlags + DatabaseName string + CollName string + Query map[string]interface{} + UpdateSpec map[string]interface{} + } +} + +// Update applies the update operation to a document +func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) { + result := make(map[string]interface{}) + for key, value := range doc { + result[key] = value + } + + // Process $set operator + if setOp, ok := u.Body.UpdateSpec["$set"].(map[string]interface{}); ok { + for key, value := range setOp { + result[key] = value + } + } + + // Process $inc operator + if incOp, ok := u.Body.UpdateSpec["$inc"].(map[string]interface{}); ok { + for key, value := range incOp { + if num, ok := value.(int32); ok { + if current, exists := result[key]; exists { + if currentNum, ok := current.(int32); ok { + result[key] = currentNum + num + } else { + return nil, fmt.Errorf("cannot apply $inc to non-integer field") + } + } else { + result[key] = num + } + } else { + return nil, fmt.Errorf("invalid increment value") + } + } + } + + return result, nil } diff --git a/protocol/parser_test.go b/protocol/parser_test.go index 9346ef6..a5ec82f 100644 --- a/protocol/parser_test.go +++ b/protocol/parser_test.go @@ -8,23 +8,23 @@ func TestParseUpdate(t *testing.T) { // 构造测试数据(最小有效Update消息) data := []byte{ // Header - 16, 0, 0, 0, // MessageLength (16 bytes) - 1, 0, 0, 0, // RequestID - 0, 0, 0, 0, // ResponseTo - 209, 7, 0, 0, // OpCode (OP_UPDATE) - + 16, 0, 0, 0, // MessageLength (16 bytes) + 1, 0, 0, 0, // RequestID + 0, 0, 0, 0, // ResponseTo + 209, 7, 0, 0, // OpCode (OP_UPDATE) + // UpdateFlags (Upsert=1) 1, 0, 0, 0, - + // DatabaseName "test\x00collection\x00" 't', 'e', 's', 't', 0, - + // CollectionName "coll\x00" 'c', 'o', 'l', 'l', 0, - + // Query文档(最小BSON文档) 5, 0, 0, 0, 0, // BSON长度 + 空文档 - + // Update文档 5, 0, 0, 0, 0, // BSON长度 + 空文档 } @@ -43,24 +43,111 @@ func TestParseUpdate(t *testing.T) { t.Fatalf("Expected UpdateMessage, got %T", msg.Body) } - if updateMsg.Flags != Upsert { - t.Errorf("Expected Upsert flag, got %v", updateMsg.Flags) + if updateMsg.Body.Flags != Upsert { + t.Errorf("Expected Upsert flag, got %v", updateMsg.Body.Flags) } - if updateMsg.DatabaseName != "test" { - t.Errorf("Expected database 'test', got '%s'", updateMsg.DatabaseName) + if updateMsg.Body.DatabaseName != "test" { + t.Errorf("Expected database 'test', got '%s'", updateMsg.Body.DatabaseName) } - if updateMsg.CollName != "coll" { - t.Errorf("Expected collection 'coll', got '%s'", updateMsg.CollName) + if updateMsg.Body.CollName != "coll" { + t.Errorf("Expected collection 'coll', got '%s'", updateMsg.Body.CollName) } } func TestParseInvalidUpdate(t *testing.T) { // 测试短数据包 - shortData := []byte{1, 0, 0, 0} // 长度不足4字节 + shortData := []byte{1, 0, 0, 0} // 长度不足4字节 _, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值 if err == nil { t.Error("Expected error for short data") } -} \ No newline at end of file +} + +func TestUpdateMessage_Functionality(t *testing.T) { + // 创建测试文档 + testDoc := map[string]interface{}{ + "name": "Alice", + "age": int32(30), + } + + // 创建UpdateMessage实例 + updateMsg := &UpdateMessage{ + Body: struct { + Flags UpdateFlags + DatabaseName string + CollName string + Query map[string]interface{} + UpdateSpec map[string]interface{} + }{ + UpdateSpec: map[string]interface{}{ + "$set": map[string]interface{}{ + "name": "Bob", + }, + "$inc": map[string]interface{}{ + "age": int32(1), + }, + }, + }, + } + + // 执行更新操作 + updatedDoc, err := updateMsg.Update(testDoc) + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + // 验证结果 + if name, ok := updatedDoc["name"].(string); !ok || name != "Bob" { + t.Errorf("Expected name 'Bob', got '%v'", name) + } + + if age, ok := updatedDoc["age"].(int32); !ok || age != 31 { + t.Errorf("Expected age 31, got %v", age) + } +} + +func TestMultiUpdate(t *testing.T) { + // 创建测试文档集 + docs := []map[string]interface{}{ + {"name": "Alice", "age": int32(30)}, + {"name": "Charlie", "age": int32(25)}, + } + + // 创建支持多更新的UpdateMessage + updateMsg := &UpdateMessage{ + Body: struct { + Flags UpdateFlags + DatabaseName string + CollName string + Query map[string]interface{} + UpdateSpec map[string]interface{} + }{ + Flags: MultiUpdate, + UpdateSpec: map[string]interface{}{ + "$inc": map[string]interface{}{ + "age": int32(2), + }, + }, + }, + } + + // 执行多文档更新 + for i := range docs { + updatedDoc, err := updateMsg.Update(docs[i]) + if err != nil { + t.Fatalf("Update failed for doc %d: %v", i, err) + } + + docs[i] = updatedDoc + } + + // 验证结果 + expectedAges := []int32{32, 27} + for i, doc := range docs { + if age, ok := doc["age"].(int32); !ok || age != expectedAges[i] { + t.Errorf("Doc %d: expected age %d, got %v", i, expectedAges[i], age) + } + } +}