add $inc and $set
This commit is contained in:
parent
9c08d68291
commit
8d25d60018
|
@ -60,19 +60,19 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||
case protocol.OP_UPDATE:
|
||||
updateMsg := message.Body.(*protocol.UpdateMessage)
|
||||
// 序列化查询和更新文档为BSON格式
|
||||
queryBson, err := protocol.BsonMarshal(updateMsg.Query)
|
||||
queryBson, err := protocol.BsonMarshal(updateMsg.Body.Query)
|
||||
if err != nil {
|
||||
log.Error("查询文档序列化失败", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
updateBson, err := protocol.BsonMarshal(updateMsg.Update)
|
||||
updateBson, err := protocol.BsonMarshal(updateMsg.Body.UpdateSpec)
|
||||
if err != nil {
|
||||
log.Error("更新文档序列化失败", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = s.storage.Update(updateMsg.DatabaseName, updateMsg.CollName, queryBson, updateBson)
|
||||
err = s.storage.Update(updateMsg.Body.DatabaseName, updateMsg.Body.CollName, queryBson, updateBson)
|
||||
if err != nil {
|
||||
log.Error("存储层更新失败", "error", err)
|
||||
continue
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BSONElement represents a single BSON element
|
||||
|
@ -67,7 +68,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
|
|||
pos = keyEnd + 1
|
||||
|
||||
// 根据元素类型解析值
|
||||
value, newPos, err := parseBSONValue(elementType, data[pos:])
|
||||
value, newPos, err := parseBSONValue(elementType, data[pos:], 0)
|
||||
if err != nil {
|
||||
return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err)
|
||||
}
|
||||
|
@ -80,7 +81,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
|
|||
}
|
||||
|
||||
// parseBSONValue 解析特定类型的BSON值
|
||||
func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
|
||||
func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, error) {
|
||||
switch elementType {
|
||||
case 0x10: // Int32
|
||||
if len(data) < 4 {
|
||||
|
@ -121,7 +122,7 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
|
|||
}
|
||||
|
||||
// 读取字符串内容(忽略最后的终止符)
|
||||
value := string(data[4 : 4+strLength])
|
||||
value := strings.Trim(string(data[4:4+strLength]), "\x00")
|
||||
return value, 4 + strLength, nil
|
||||
|
||||
case 0x08: // Boolean
|
||||
|
@ -146,6 +147,26 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
|
|||
return nil, 0, fmt.Errorf("data too short for EmbeddedDocument")
|
||||
}
|
||||
|
||||
// 检查是否为更新操作符(以$开头的键名)
|
||||
if pos > 0 && data[pos-1] == 0x02 { // 前一个字节是字符串类型标记
|
||||
// 查找前一个键名
|
||||
for i := pos - 2; i >= 0; i-- {
|
||||
if data[i] == 0x02 { // 找到字符串类型标记
|
||||
keyLen := int(binary.LittleEndian.Uint32(data[i+1 : i+5]))
|
||||
keyStart := i + 5
|
||||
if keyStart+keyLen <= pos {
|
||||
key := string(data[keyStart : keyStart+keyLen-1])
|
||||
if len(key) > 0 && key[0] == '$' {
|
||||
// 这是一个更新操作符
|
||||
value := make(map[string]interface{})
|
||||
value[key] = "operator_placeholder"
|
||||
return value, docLength, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析嵌入文档内容
|
||||
subDoc, _, err := parseBSON(data[0:docLength])
|
||||
if err != nil {
|
||||
|
|
|
@ -75,6 +75,65 @@ func TestParseBSON_String(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestParseBSON_UpdateOperators(t *testing.T) {
|
||||
// 测试$set操作符解析
|
||||
setData := []byte{
|
||||
0x1f, 0, 0, 0, // 文档长度=21字节
|
||||
0x03, // 嵌入文档类型
|
||||
'$', 's', 'e', 't', 0x00, // $set操作符键名
|
||||
0x14, 0, 0, 0, // 子文档长度
|
||||
0x02, // 字符串类型
|
||||
'n', 'a', 'm', 'e', 0x00, // 键名"name"
|
||||
5, 0, 0, 0, // 字符串长度=6字节(含终止符)
|
||||
'j', 'o', 'h', 'n', 0x00, // 字符串内容
|
||||
0x00, // 结束符
|
||||
0x00, // 主文档结束符
|
||||
}
|
||||
|
||||
result, err := ParseBSON(setData)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseBSON failed: %v", err)
|
||||
}
|
||||
|
||||
value, ok := result["$set"]
|
||||
if !ok {
|
||||
t.Error("Expected key '$set' not found")
|
||||
return
|
||||
}
|
||||
|
||||
if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["name"] != "john" {
|
||||
t.Errorf("Expected $set{name: 'john'}, got %v", value)
|
||||
}
|
||||
|
||||
// 测试$inc操作符解析
|
||||
incData := []byte{
|
||||
0x19, 0, 0, 0, // 文档长度=21字节
|
||||
0x03, // 嵌入文档类型
|
||||
'$', 'i', 'n', 'c', 0x00, // $inc操作符键名
|
||||
0x0e, 0, 0, 0, // 子文档长度
|
||||
0x10, // Int32类型
|
||||
'a', 'g', 'e', 0x00, // 键名"age"
|
||||
22, 0, 0, 0, // 值=22
|
||||
0x00, // 结束符
|
||||
0x00, // 主文档结束符
|
||||
}
|
||||
|
||||
result, err = ParseBSON(incData)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseBSON failed: %v", err)
|
||||
}
|
||||
|
||||
value, ok = result["$inc"]
|
||||
if !ok {
|
||||
t.Error("Expected key '$inc' not found")
|
||||
return
|
||||
}
|
||||
|
||||
if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["age"] != int32(22) {
|
||||
t.Errorf("Expected $inc{age: 22}, got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBSON_ErrorCases(t *testing.T) {
|
||||
// 测试数据过短的情况
|
||||
shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档(仅包含长度字段)
|
||||
|
|
|
@ -7,13 +7,10 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// UpdateFlags 更新操作标志位
|
||||
|
||||
// UpdateFlags are the flags for OP_UPDATE
|
||||
const (
|
||||
// Update操作的标志位常量
|
||||
Upsert = 1 << 0
|
||||
MultiUpdate = 1 << 1
|
||||
WriteConcern = 1 << 3 // 3.x驱动已弃用
|
||||
Upsert = 1 << iota
|
||||
MultiUpdate // 标志位用于多文档更新
|
||||
)
|
||||
|
||||
// Header 消息头
|
||||
|
@ -99,11 +96,19 @@ func parseUpdate(data []byte) (interface{}, error) {
|
|||
}
|
||||
|
||||
return &UpdateMessage{
|
||||
Flags: flags,
|
||||
DatabaseName: dbName,
|
||||
CollName: collName,
|
||||
Query: queryDoc,
|
||||
Update: updateDoc,
|
||||
Body: struct {
|
||||
Flags UpdateFlags
|
||||
DatabaseName string
|
||||
CollName string
|
||||
Query map[string]interface{}
|
||||
UpdateSpec map[string]interface{}
|
||||
}{
|
||||
Flags: flags,
|
||||
DatabaseName: dbName,
|
||||
CollName: collName,
|
||||
Query: queryDoc,
|
||||
UpdateSpec: updateDoc,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -274,11 +279,50 @@ type InsertMessage struct {
|
|||
Documents []map[string]interface{} // 要插入的文档
|
||||
}
|
||||
|
||||
// UpdateMessage OP_UPDATE消息体结构
|
||||
// UpdateMessage represents an OP_UPDATE message
|
||||
type UpdateMessage struct {
|
||||
Flags UpdateFlags // 更新标志
|
||||
DatabaseName string // 数据库名称
|
||||
CollName string // 集合名称
|
||||
Query map[string]interface{} // 查询条件
|
||||
Update map[string]interface{} // 更新操作
|
||||
Header Header
|
||||
Body struct {
|
||||
Flags UpdateFlags
|
||||
DatabaseName string
|
||||
CollName string
|
||||
Query map[string]interface{}
|
||||
UpdateSpec map[string]interface{}
|
||||
}
|
||||
}
|
||||
|
||||
// Update applies the update operation to a document
|
||||
func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
for key, value := range doc {
|
||||
result[key] = value
|
||||
}
|
||||
|
||||
// Process $set operator
|
||||
if setOp, ok := u.Body.UpdateSpec["$set"].(map[string]interface{}); ok {
|
||||
for key, value := range setOp {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Process $inc operator
|
||||
if incOp, ok := u.Body.UpdateSpec["$inc"].(map[string]interface{}); ok {
|
||||
for key, value := range incOp {
|
||||
if num, ok := value.(int32); ok {
|
||||
if current, exists := result[key]; exists {
|
||||
if currentNum, ok := current.(int32); ok {
|
||||
result[key] = currentNum + num
|
||||
} else {
|
||||
return nil, fmt.Errorf("cannot apply $inc to non-integer field")
|
||||
}
|
||||
} else {
|
||||
result[key] = num
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid increment value")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
@ -8,23 +8,23 @@ func TestParseUpdate(t *testing.T) {
|
|||
// 构造测试数据(最小有效Update消息)
|
||||
data := []byte{
|
||||
// Header
|
||||
16, 0, 0, 0, // MessageLength (16 bytes)
|
||||
1, 0, 0, 0, // RequestID
|
||||
0, 0, 0, 0, // ResponseTo
|
||||
209, 7, 0, 0, // OpCode (OP_UPDATE)
|
||||
|
||||
16, 0, 0, 0, // MessageLength (16 bytes)
|
||||
1, 0, 0, 0, // RequestID
|
||||
0, 0, 0, 0, // ResponseTo
|
||||
209, 7, 0, 0, // OpCode (OP_UPDATE)
|
||||
|
||||
// UpdateFlags (Upsert=1)
|
||||
1, 0, 0, 0,
|
||||
|
||||
|
||||
// DatabaseName "test\x00collection\x00"
|
||||
't', 'e', 's', 't', 0,
|
||||
|
||||
|
||||
// CollectionName "coll\x00"
|
||||
'c', 'o', 'l', 'l', 0,
|
||||
|
||||
|
||||
// Query文档(最小BSON文档)
|
||||
5, 0, 0, 0, 0, // BSON长度 + 空文档
|
||||
|
||||
|
||||
// Update文档
|
||||
5, 0, 0, 0, 0, // BSON长度 + 空文档
|
||||
}
|
||||
|
@ -43,24 +43,111 @@ func TestParseUpdate(t *testing.T) {
|
|||
t.Fatalf("Expected UpdateMessage, got %T", msg.Body)
|
||||
}
|
||||
|
||||
if updateMsg.Flags != Upsert {
|
||||
t.Errorf("Expected Upsert flag, got %v", updateMsg.Flags)
|
||||
if updateMsg.Body.Flags != Upsert {
|
||||
t.Errorf("Expected Upsert flag, got %v", updateMsg.Body.Flags)
|
||||
}
|
||||
|
||||
if updateMsg.DatabaseName != "test" {
|
||||
t.Errorf("Expected database 'test', got '%s'", updateMsg.DatabaseName)
|
||||
if updateMsg.Body.DatabaseName != "test" {
|
||||
t.Errorf("Expected database 'test', got '%s'", updateMsg.Body.DatabaseName)
|
||||
}
|
||||
|
||||
if updateMsg.CollName != "coll" {
|
||||
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.CollName)
|
||||
if updateMsg.Body.CollName != "coll" {
|
||||
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.Body.CollName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInvalidUpdate(t *testing.T) {
|
||||
// 测试短数据包
|
||||
shortData := []byte{1, 0, 0, 0} // 长度不足4字节
|
||||
shortData := []byte{1, 0, 0, 0} // 长度不足4字节
|
||||
_, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值
|
||||
if err == nil {
|
||||
t.Error("Expected error for short data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMessage_Functionality(t *testing.T) {
|
||||
// 创建测试文档
|
||||
testDoc := map[string]interface{}{
|
||||
"name": "Alice",
|
||||
"age": int32(30),
|
||||
}
|
||||
|
||||
// 创建UpdateMessage实例
|
||||
updateMsg := &UpdateMessage{
|
||||
Body: struct {
|
||||
Flags UpdateFlags
|
||||
DatabaseName string
|
||||
CollName string
|
||||
Query map[string]interface{}
|
||||
UpdateSpec map[string]interface{}
|
||||
}{
|
||||
UpdateSpec: map[string]interface{}{
|
||||
"$set": map[string]interface{}{
|
||||
"name": "Bob",
|
||||
},
|
||||
"$inc": map[string]interface{}{
|
||||
"age": int32(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 执行更新操作
|
||||
updatedDoc, err := updateMsg.Update(testDoc)
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
if name, ok := updatedDoc["name"].(string); !ok || name != "Bob" {
|
||||
t.Errorf("Expected name 'Bob', got '%v'", name)
|
||||
}
|
||||
|
||||
if age, ok := updatedDoc["age"].(int32); !ok || age != 31 {
|
||||
t.Errorf("Expected age 31, got %v", age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiUpdate(t *testing.T) {
|
||||
// 创建测试文档集
|
||||
docs := []map[string]interface{}{
|
||||
{"name": "Alice", "age": int32(30)},
|
||||
{"name": "Charlie", "age": int32(25)},
|
||||
}
|
||||
|
||||
// 创建支持多更新的UpdateMessage
|
||||
updateMsg := &UpdateMessage{
|
||||
Body: struct {
|
||||
Flags UpdateFlags
|
||||
DatabaseName string
|
||||
CollName string
|
||||
Query map[string]interface{}
|
||||
UpdateSpec map[string]interface{}
|
||||
}{
|
||||
Flags: MultiUpdate,
|
||||
UpdateSpec: map[string]interface{}{
|
||||
"$inc": map[string]interface{}{
|
||||
"age": int32(2),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 执行多文档更新
|
||||
for i := range docs {
|
||||
updatedDoc, err := updateMsg.Update(docs[i])
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed for doc %d: %v", i, err)
|
||||
}
|
||||
|
||||
docs[i] = updatedDoc
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
expectedAges := []int32{32, 27}
|
||||
for i, doc := range docs {
|
||||
if age, ok := doc["age"].(int32); !ok || age != expectedAges[i] {
|
||||
t.Errorf("Doc %d: expected age %d, got %v", i, expectedAges[i], age)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue