add $inc and $set

This commit is contained in:
kingecg 2025-06-06 22:54:22 +08:00
parent 9c08d68291
commit 8d25d60018
5 changed files with 251 additions and 40 deletions

View File

@ -60,19 +60,19 @@ func (s *Server) handleConnection(conn net.Conn) {
case protocol.OP_UPDATE: case protocol.OP_UPDATE:
updateMsg := message.Body.(*protocol.UpdateMessage) updateMsg := message.Body.(*protocol.UpdateMessage)
// 序列化查询和更新文档为BSON格式 // 序列化查询和更新文档为BSON格式
queryBson, err := protocol.BsonMarshal(updateMsg.Query) queryBson, err := protocol.BsonMarshal(updateMsg.Body.Query)
if err != nil { if err != nil {
log.Error("查询文档序列化失败", "error", err) log.Error("查询文档序列化失败", "error", err)
continue continue
} }
updateBson, err := protocol.BsonMarshal(updateMsg.Update) updateBson, err := protocol.BsonMarshal(updateMsg.Body.UpdateSpec)
if err != nil { if err != nil {
log.Error("更新文档序列化失败", "error", err) log.Error("更新文档序列化失败", "error", err)
continue continue
} }
err = s.storage.Update(updateMsg.DatabaseName, updateMsg.CollName, queryBson, updateBson) err = s.storage.Update(updateMsg.Body.DatabaseName, updateMsg.Body.CollName, queryBson, updateBson)
if err != nil { if err != nil {
log.Error("存储层更新失败", "error", err) log.Error("存储层更新失败", "error", err)
continue continue

View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math" "math"
"strings"
) )
// BSONElement represents a single BSON element // BSONElement represents a single BSON element
@ -67,7 +68,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
pos = keyEnd + 1 pos = keyEnd + 1
// 根据元素类型解析值 // 根据元素类型解析值
value, newPos, err := parseBSONValue(elementType, data[pos:]) value, newPos, err := parseBSONValue(elementType, data[pos:], 0)
if err != nil { if err != nil {
return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err) return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err)
} }
@ -80,7 +81,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
} }
// parseBSONValue 解析特定类型的BSON值 // parseBSONValue 解析特定类型的BSON值
func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) { func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, error) {
switch elementType { switch elementType {
case 0x10: // Int32 case 0x10: // Int32
if len(data) < 4 { if len(data) < 4 {
@ -121,7 +122,7 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
} }
// 读取字符串内容(忽略最后的终止符) // 读取字符串内容(忽略最后的终止符)
value := string(data[4 : 4+strLength]) value := strings.Trim(string(data[4:4+strLength]), "\x00")
return value, 4 + strLength, nil return value, 4 + strLength, nil
case 0x08: // Boolean case 0x08: // Boolean
@ -146,6 +147,26 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
return nil, 0, fmt.Errorf("data too short for EmbeddedDocument") return nil, 0, fmt.Errorf("data too short for EmbeddedDocument")
} }
// 检查是否为更新操作符(以$开头的键名)
if pos > 0 && data[pos-1] == 0x02 { // 前一个字节是字符串类型标记
// 查找前一个键名
for i := pos - 2; i >= 0; i-- {
if data[i] == 0x02 { // 找到字符串类型标记
keyLen := int(binary.LittleEndian.Uint32(data[i+1 : i+5]))
keyStart := i + 5
if keyStart+keyLen <= pos {
key := string(data[keyStart : keyStart+keyLen-1])
if len(key) > 0 && key[0] == '$' {
// 这是一个更新操作符
value := make(map[string]interface{})
value[key] = "operator_placeholder"
return value, docLength, nil
}
}
}
}
}
// 解析嵌入文档内容 // 解析嵌入文档内容
subDoc, _, err := parseBSON(data[0:docLength]) subDoc, _, err := parseBSON(data[0:docLength])
if err != nil { if err != nil {

View File

@ -75,6 +75,65 @@ func TestParseBSON_String(t *testing.T) {
} }
} }
func TestParseBSON_UpdateOperators(t *testing.T) {
// 测试$set操作符解析
setData := []byte{
0x1f, 0, 0, 0, // 文档长度=21字节
0x03, // 嵌入文档类型
'$', 's', 'e', 't', 0x00, // $set操作符键名
0x14, 0, 0, 0, // 子文档长度
0x02, // 字符串类型
'n', 'a', 'm', 'e', 0x00, // 键名"name"
5, 0, 0, 0, // 字符串长度=6字节含终止符
'j', 'o', 'h', 'n', 0x00, // 字符串内容
0x00, // 结束符
0x00, // 主文档结束符
}
result, err := ParseBSON(setData)
if err != nil {
t.Fatalf("ParseBSON failed: %v", err)
}
value, ok := result["$set"]
if !ok {
t.Error("Expected key '$set' not found")
return
}
if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["name"] != "john" {
t.Errorf("Expected $set{name: 'john'}, got %v", value)
}
// 测试$inc操作符解析
incData := []byte{
0x19, 0, 0, 0, // 文档长度=21字节
0x03, // 嵌入文档类型
'$', 'i', 'n', 'c', 0x00, // $inc操作符键名
0x0e, 0, 0, 0, // 子文档长度
0x10, // Int32类型
'a', 'g', 'e', 0x00, // 键名"age"
22, 0, 0, 0, // 值=22
0x00, // 结束符
0x00, // 主文档结束符
}
result, err = ParseBSON(incData)
if err != nil {
t.Fatalf("ParseBSON failed: %v", err)
}
value, ok = result["$inc"]
if !ok {
t.Error("Expected key '$inc' not found")
return
}
if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["age"] != int32(22) {
t.Errorf("Expected $inc{age: 22}, got %v", value)
}
}
func TestParseBSON_ErrorCases(t *testing.T) { func TestParseBSON_ErrorCases(t *testing.T) {
// 测试数据过短的情况 // 测试数据过短的情况
shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档仅包含长度字段 shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档仅包含长度字段

View File

@ -7,13 +7,10 @@ import (
"fmt" "fmt"
) )
// UpdateFlags 更新操作标志位 // UpdateFlags are the flags for OP_UPDATE
const ( const (
// Update操作的标志位常量 Upsert = 1 << iota
Upsert = 1 << 0 MultiUpdate // 标志位用于多文档更新
MultiUpdate = 1 << 1
WriteConcern = 1 << 3 // 3.x驱动已弃用
) )
// Header 消息头 // Header 消息头
@ -99,11 +96,19 @@ func parseUpdate(data []byte) (interface{}, error) {
} }
return &UpdateMessage{ return &UpdateMessage{
Flags: flags, Body: struct {
DatabaseName: dbName, Flags UpdateFlags
CollName: collName, DatabaseName string
Query: queryDoc, CollName string
Update: updateDoc, Query map[string]interface{}
UpdateSpec map[string]interface{}
}{
Flags: flags,
DatabaseName: dbName,
CollName: collName,
Query: queryDoc,
UpdateSpec: updateDoc,
},
}, nil }, nil
} }
@ -274,11 +279,50 @@ type InsertMessage struct {
Documents []map[string]interface{} // 要插入的文档 Documents []map[string]interface{} // 要插入的文档
} }
// UpdateMessage OP_UPDATE消息体结构 // UpdateMessage represents an OP_UPDATE message
type UpdateMessage struct { type UpdateMessage struct {
Flags UpdateFlags // 更新标志 Header Header
DatabaseName string // 数据库名称 Body struct {
CollName string // 集合名称 Flags UpdateFlags
Query map[string]interface{} // 查询条件 DatabaseName string
Update map[string]interface{} // 更新操作 CollName string
Query map[string]interface{}
UpdateSpec map[string]interface{}
}
}
// Update applies the update operation to a document
func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) {
result := make(map[string]interface{})
for key, value := range doc {
result[key] = value
}
// Process $set operator
if setOp, ok := u.Body.UpdateSpec["$set"].(map[string]interface{}); ok {
for key, value := range setOp {
result[key] = value
}
}
// Process $inc operator
if incOp, ok := u.Body.UpdateSpec["$inc"].(map[string]interface{}); ok {
for key, value := range incOp {
if num, ok := value.(int32); ok {
if current, exists := result[key]; exists {
if currentNum, ok := current.(int32); ok {
result[key] = currentNum + num
} else {
return nil, fmt.Errorf("cannot apply $inc to non-integer field")
}
} else {
result[key] = num
}
} else {
return nil, fmt.Errorf("invalid increment value")
}
}
}
return result, nil
} }

View File

@ -8,10 +8,10 @@ func TestParseUpdate(t *testing.T) {
// 构造测试数据最小有效Update消息 // 构造测试数据最小有效Update消息
data := []byte{ data := []byte{
// Header // Header
16, 0, 0, 0, // MessageLength (16 bytes) 16, 0, 0, 0, // MessageLength (16 bytes)
1, 0, 0, 0, // RequestID 1, 0, 0, 0, // RequestID
0, 0, 0, 0, // ResponseTo 0, 0, 0, 0, // ResponseTo
209, 7, 0, 0, // OpCode (OP_UPDATE) 209, 7, 0, 0, // OpCode (OP_UPDATE)
// UpdateFlags (Upsert=1) // UpdateFlags (Upsert=1)
1, 0, 0, 0, 1, 0, 0, 0,
@ -43,24 +43,111 @@ func TestParseUpdate(t *testing.T) {
t.Fatalf("Expected UpdateMessage, got %T", msg.Body) t.Fatalf("Expected UpdateMessage, got %T", msg.Body)
} }
if updateMsg.Flags != Upsert { if updateMsg.Body.Flags != Upsert {
t.Errorf("Expected Upsert flag, got %v", updateMsg.Flags) t.Errorf("Expected Upsert flag, got %v", updateMsg.Body.Flags)
} }
if updateMsg.DatabaseName != "test" { if updateMsg.Body.DatabaseName != "test" {
t.Errorf("Expected database 'test', got '%s'", updateMsg.DatabaseName) t.Errorf("Expected database 'test', got '%s'", updateMsg.Body.DatabaseName)
} }
if updateMsg.CollName != "coll" { if updateMsg.Body.CollName != "coll" {
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.CollName) t.Errorf("Expected collection 'coll', got '%s'", updateMsg.Body.CollName)
} }
} }
func TestParseInvalidUpdate(t *testing.T) { func TestParseInvalidUpdate(t *testing.T) {
// 测试短数据包 // 测试短数据包
shortData := []byte{1, 0, 0, 0} // 长度不足4字节 shortData := []byte{1, 0, 0, 0} // 长度不足4字节
_, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值 _, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值
if err == nil { if err == nil {
t.Error("Expected error for short data") t.Error("Expected error for short data")
} }
} }
func TestUpdateMessage_Functionality(t *testing.T) {
// 创建测试文档
testDoc := map[string]interface{}{
"name": "Alice",
"age": int32(30),
}
// 创建UpdateMessage实例
updateMsg := &UpdateMessage{
Body: struct {
Flags UpdateFlags
DatabaseName string
CollName string
Query map[string]interface{}
UpdateSpec map[string]interface{}
}{
UpdateSpec: map[string]interface{}{
"$set": map[string]interface{}{
"name": "Bob",
},
"$inc": map[string]interface{}{
"age": int32(1),
},
},
},
}
// 执行更新操作
updatedDoc, err := updateMsg.Update(testDoc)
if err != nil {
t.Fatalf("Update failed: %v", err)
}
// 验证结果
if name, ok := updatedDoc["name"].(string); !ok || name != "Bob" {
t.Errorf("Expected name 'Bob', got '%v'", name)
}
if age, ok := updatedDoc["age"].(int32); !ok || age != 31 {
t.Errorf("Expected age 31, got %v", age)
}
}
func TestMultiUpdate(t *testing.T) {
// 创建测试文档集
docs := []map[string]interface{}{
{"name": "Alice", "age": int32(30)},
{"name": "Charlie", "age": int32(25)},
}
// 创建支持多更新的UpdateMessage
updateMsg := &UpdateMessage{
Body: struct {
Flags UpdateFlags
DatabaseName string
CollName string
Query map[string]interface{}
UpdateSpec map[string]interface{}
}{
Flags: MultiUpdate,
UpdateSpec: map[string]interface{}{
"$inc": map[string]interface{}{
"age": int32(2),
},
},
},
}
// 执行多文档更新
for i := range docs {
updatedDoc, err := updateMsg.Update(docs[i])
if err != nil {
t.Fatalf("Update failed for doc %d: %v", i, err)
}
docs[i] = updatedDoc
}
// 验证结果
expectedAges := []int32{32, 27}
for i, doc := range docs {
if age, ok := doc["age"].(int32); !ok || age != expectedAges[i] {
t.Errorf("Doc %d: expected age %d, got %v", i, expectedAges[i], age)
}
}
}