add $inc and $set
This commit is contained in:
parent
9c08d68291
commit
8d25d60018
|
@ -60,19 +60,19 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||||
case protocol.OP_UPDATE:
|
case protocol.OP_UPDATE:
|
||||||
updateMsg := message.Body.(*protocol.UpdateMessage)
|
updateMsg := message.Body.(*protocol.UpdateMessage)
|
||||||
// 序列化查询和更新文档为BSON格式
|
// 序列化查询和更新文档为BSON格式
|
||||||
queryBson, err := protocol.BsonMarshal(updateMsg.Query)
|
queryBson, err := protocol.BsonMarshal(updateMsg.Body.Query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("查询文档序列化失败", "error", err)
|
log.Error("查询文档序列化失败", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
updateBson, err := protocol.BsonMarshal(updateMsg.Update)
|
updateBson, err := protocol.BsonMarshal(updateMsg.Body.UpdateSpec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("更新文档序列化失败", "error", err)
|
log.Error("更新文档序列化失败", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.storage.Update(updateMsg.DatabaseName, updateMsg.CollName, queryBson, updateBson)
|
err = s.storage.Update(updateMsg.Body.DatabaseName, updateMsg.Body.CollName, queryBson, updateBson)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("存储层更新失败", "error", err)
|
log.Error("存储层更新失败", "error", err)
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BSONElement represents a single BSON element
|
// BSONElement represents a single BSON element
|
||||||
|
@ -67,7 +68,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
|
||||||
pos = keyEnd + 1
|
pos = keyEnd + 1
|
||||||
|
|
||||||
// 根据元素类型解析值
|
// 根据元素类型解析值
|
||||||
value, newPos, err := parseBSONValue(elementType, data[pos:])
|
value, newPos, err := parseBSONValue(elementType, data[pos:], 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err)
|
return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err)
|
||||||
}
|
}
|
||||||
|
@ -80,7 +81,7 @@ func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseBSONValue 解析特定类型的BSON值
|
// parseBSONValue 解析特定类型的BSON值
|
||||||
func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
|
func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, error) {
|
||||||
switch elementType {
|
switch elementType {
|
||||||
case 0x10: // Int32
|
case 0x10: // Int32
|
||||||
if len(data) < 4 {
|
if len(data) < 4 {
|
||||||
|
@ -121,7 +122,7 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取字符串内容(忽略最后的终止符)
|
// 读取字符串内容(忽略最后的终止符)
|
||||||
value := string(data[4 : 4+strLength])
|
value := strings.Trim(string(data[4:4+strLength]), "\x00")
|
||||||
return value, 4 + strLength, nil
|
return value, 4 + strLength, nil
|
||||||
|
|
||||||
case 0x08: // Boolean
|
case 0x08: // Boolean
|
||||||
|
@ -146,6 +147,26 @@ func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) {
|
||||||
return nil, 0, fmt.Errorf("data too short for EmbeddedDocument")
|
return nil, 0, fmt.Errorf("data too short for EmbeddedDocument")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否为更新操作符(以$开头的键名)
|
||||||
|
if pos > 0 && data[pos-1] == 0x02 { // 前一个字节是字符串类型标记
|
||||||
|
// 查找前一个键名
|
||||||
|
for i := pos - 2; i >= 0; i-- {
|
||||||
|
if data[i] == 0x02 { // 找到字符串类型标记
|
||||||
|
keyLen := int(binary.LittleEndian.Uint32(data[i+1 : i+5]))
|
||||||
|
keyStart := i + 5
|
||||||
|
if keyStart+keyLen <= pos {
|
||||||
|
key := string(data[keyStart : keyStart+keyLen-1])
|
||||||
|
if len(key) > 0 && key[0] == '$' {
|
||||||
|
// 这是一个更新操作符
|
||||||
|
value := make(map[string]interface{})
|
||||||
|
value[key] = "operator_placeholder"
|
||||||
|
return value, docLength, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 解析嵌入文档内容
|
// 解析嵌入文档内容
|
||||||
subDoc, _, err := parseBSON(data[0:docLength])
|
subDoc, _, err := parseBSON(data[0:docLength])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -75,6 +75,65 @@ func TestParseBSON_String(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseBSON_UpdateOperators(t *testing.T) {
|
||||||
|
// 测试$set操作符解析
|
||||||
|
setData := []byte{
|
||||||
|
0x1f, 0, 0, 0, // 文档长度=21字节
|
||||||
|
0x03, // 嵌入文档类型
|
||||||
|
'$', 's', 'e', 't', 0x00, // $set操作符键名
|
||||||
|
0x14, 0, 0, 0, // 子文档长度
|
||||||
|
0x02, // 字符串类型
|
||||||
|
'n', 'a', 'm', 'e', 0x00, // 键名"name"
|
||||||
|
5, 0, 0, 0, // 字符串长度=6字节(含终止符)
|
||||||
|
'j', 'o', 'h', 'n', 0x00, // 字符串内容
|
||||||
|
0x00, // 结束符
|
||||||
|
0x00, // 主文档结束符
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := ParseBSON(setData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseBSON failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, ok := result["$set"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected key '$set' not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["name"] != "john" {
|
||||||
|
t.Errorf("Expected $set{name: 'john'}, got %v", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试$inc操作符解析
|
||||||
|
incData := []byte{
|
||||||
|
0x19, 0, 0, 0, // 文档长度=21字节
|
||||||
|
0x03, // 嵌入文档类型
|
||||||
|
'$', 'i', 'n', 'c', 0x00, // $inc操作符键名
|
||||||
|
0x0e, 0, 0, 0, // 子文档长度
|
||||||
|
0x10, // Int32类型
|
||||||
|
'a', 'g', 'e', 0x00, // 键名"age"
|
||||||
|
22, 0, 0, 0, // 值=22
|
||||||
|
0x00, // 结束符
|
||||||
|
0x00, // 主文档结束符
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = ParseBSON(incData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseBSON failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, ok = result["$inc"]
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected key '$inc' not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if subDoc, ok := value.(map[string]interface{}); !ok || subDoc["age"] != int32(22) {
|
||||||
|
t.Errorf("Expected $inc{age: 22}, got %v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseBSON_ErrorCases(t *testing.T) {
|
func TestParseBSON_ErrorCases(t *testing.T) {
|
||||||
// 测试数据过短的情况
|
// 测试数据过短的情况
|
||||||
shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档(仅包含长度字段)
|
shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档(仅包含长度字段)
|
||||||
|
|
|
@ -7,13 +7,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UpdateFlags 更新操作标志位
|
// UpdateFlags are the flags for OP_UPDATE
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Update操作的标志位常量
|
Upsert = 1 << iota
|
||||||
Upsert = 1 << 0
|
MultiUpdate // 标志位用于多文档更新
|
||||||
MultiUpdate = 1 << 1
|
|
||||||
WriteConcern = 1 << 3 // 3.x驱动已弃用
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Header 消息头
|
// Header 消息头
|
||||||
|
@ -99,11 +96,19 @@ func parseUpdate(data []byte) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &UpdateMessage{
|
return &UpdateMessage{
|
||||||
Flags: flags,
|
Body: struct {
|
||||||
DatabaseName: dbName,
|
Flags UpdateFlags
|
||||||
CollName: collName,
|
DatabaseName string
|
||||||
Query: queryDoc,
|
CollName string
|
||||||
Update: updateDoc,
|
Query map[string]interface{}
|
||||||
|
UpdateSpec map[string]interface{}
|
||||||
|
}{
|
||||||
|
Flags: flags,
|
||||||
|
DatabaseName: dbName,
|
||||||
|
CollName: collName,
|
||||||
|
Query: queryDoc,
|
||||||
|
UpdateSpec: updateDoc,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -274,11 +279,50 @@ type InsertMessage struct {
|
||||||
Documents []map[string]interface{} // 要插入的文档
|
Documents []map[string]interface{} // 要插入的文档
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMessage OP_UPDATE消息体结构
|
// UpdateMessage represents an OP_UPDATE message
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Flags UpdateFlags // 更新标志
|
Header Header
|
||||||
DatabaseName string // 数据库名称
|
Body struct {
|
||||||
CollName string // 集合名称
|
Flags UpdateFlags
|
||||||
Query map[string]interface{} // 查询条件
|
DatabaseName string
|
||||||
Update map[string]interface{} // 更新操作
|
CollName string
|
||||||
|
Query map[string]interface{}
|
||||||
|
UpdateSpec map[string]interface{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update applies the update operation to a document
|
||||||
|
func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) {
|
||||||
|
result := make(map[string]interface{})
|
||||||
|
for key, value := range doc {
|
||||||
|
result[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process $set operator
|
||||||
|
if setOp, ok := u.Body.UpdateSpec["$set"].(map[string]interface{}); ok {
|
||||||
|
for key, value := range setOp {
|
||||||
|
result[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process $inc operator
|
||||||
|
if incOp, ok := u.Body.UpdateSpec["$inc"].(map[string]interface{}); ok {
|
||||||
|
for key, value := range incOp {
|
||||||
|
if num, ok := value.(int32); ok {
|
||||||
|
if current, exists := result[key]; exists {
|
||||||
|
if currentNum, ok := current.(int32); ok {
|
||||||
|
result[key] = currentNum + num
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("cannot apply $inc to non-integer field")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result[key] = num
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("invalid increment value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,10 +8,10 @@ func TestParseUpdate(t *testing.T) {
|
||||||
// 构造测试数据(最小有效Update消息)
|
// 构造测试数据(最小有效Update消息)
|
||||||
data := []byte{
|
data := []byte{
|
||||||
// Header
|
// Header
|
||||||
16, 0, 0, 0, // MessageLength (16 bytes)
|
16, 0, 0, 0, // MessageLength (16 bytes)
|
||||||
1, 0, 0, 0, // RequestID
|
1, 0, 0, 0, // RequestID
|
||||||
0, 0, 0, 0, // ResponseTo
|
0, 0, 0, 0, // ResponseTo
|
||||||
209, 7, 0, 0, // OpCode (OP_UPDATE)
|
209, 7, 0, 0, // OpCode (OP_UPDATE)
|
||||||
|
|
||||||
// UpdateFlags (Upsert=1)
|
// UpdateFlags (Upsert=1)
|
||||||
1, 0, 0, 0,
|
1, 0, 0, 0,
|
||||||
|
@ -43,24 +43,111 @@ func TestParseUpdate(t *testing.T) {
|
||||||
t.Fatalf("Expected UpdateMessage, got %T", msg.Body)
|
t.Fatalf("Expected UpdateMessage, got %T", msg.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateMsg.Flags != Upsert {
|
if updateMsg.Body.Flags != Upsert {
|
||||||
t.Errorf("Expected Upsert flag, got %v", updateMsg.Flags)
|
t.Errorf("Expected Upsert flag, got %v", updateMsg.Body.Flags)
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateMsg.DatabaseName != "test" {
|
if updateMsg.Body.DatabaseName != "test" {
|
||||||
t.Errorf("Expected database 'test', got '%s'", updateMsg.DatabaseName)
|
t.Errorf("Expected database 'test', got '%s'", updateMsg.Body.DatabaseName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateMsg.CollName != "coll" {
|
if updateMsg.Body.CollName != "coll" {
|
||||||
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.CollName)
|
t.Errorf("Expected collection 'coll', got '%s'", updateMsg.Body.CollName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseInvalidUpdate(t *testing.T) {
|
func TestParseInvalidUpdate(t *testing.T) {
|
||||||
// 测试短数据包
|
// 测试短数据包
|
||||||
shortData := []byte{1, 0, 0, 0} // 长度不足4字节
|
shortData := []byte{1, 0, 0, 0} // 长度不足4字节
|
||||||
_, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值
|
_, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for short data")
|
t.Error("Expected error for short data")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateMessage_Functionality(t *testing.T) {
|
||||||
|
// 创建测试文档
|
||||||
|
testDoc := map[string]interface{}{
|
||||||
|
"name": "Alice",
|
||||||
|
"age": int32(30),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建UpdateMessage实例
|
||||||
|
updateMsg := &UpdateMessage{
|
||||||
|
Body: struct {
|
||||||
|
Flags UpdateFlags
|
||||||
|
DatabaseName string
|
||||||
|
CollName string
|
||||||
|
Query map[string]interface{}
|
||||||
|
UpdateSpec map[string]interface{}
|
||||||
|
}{
|
||||||
|
UpdateSpec: map[string]interface{}{
|
||||||
|
"$set": map[string]interface{}{
|
||||||
|
"name": "Bob",
|
||||||
|
},
|
||||||
|
"$inc": map[string]interface{}{
|
||||||
|
"age": int32(1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行更新操作
|
||||||
|
updatedDoc, err := updateMsg.Update(testDoc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证结果
|
||||||
|
if name, ok := updatedDoc["name"].(string); !ok || name != "Bob" {
|
||||||
|
t.Errorf("Expected name 'Bob', got '%v'", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if age, ok := updatedDoc["age"].(int32); !ok || age != 31 {
|
||||||
|
t.Errorf("Expected age 31, got %v", age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultiUpdate(t *testing.T) {
|
||||||
|
// 创建测试文档集
|
||||||
|
docs := []map[string]interface{}{
|
||||||
|
{"name": "Alice", "age": int32(30)},
|
||||||
|
{"name": "Charlie", "age": int32(25)},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建支持多更新的UpdateMessage
|
||||||
|
updateMsg := &UpdateMessage{
|
||||||
|
Body: struct {
|
||||||
|
Flags UpdateFlags
|
||||||
|
DatabaseName string
|
||||||
|
CollName string
|
||||||
|
Query map[string]interface{}
|
||||||
|
UpdateSpec map[string]interface{}
|
||||||
|
}{
|
||||||
|
Flags: MultiUpdate,
|
||||||
|
UpdateSpec: map[string]interface{}{
|
||||||
|
"$inc": map[string]interface{}{
|
||||||
|
"age": int32(2),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行多文档更新
|
||||||
|
for i := range docs {
|
||||||
|
updatedDoc, err := updateMsg.Update(docs[i])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update failed for doc %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
docs[i] = updatedDoc
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证结果
|
||||||
|
expectedAges := []int32{32, 27}
|
||||||
|
for i, doc := range docs {
|
||||||
|
if age, ok := doc["age"].(int32); !ok || age != expectedAges[i] {
|
||||||
|
t.Errorf("Doc %d: expected age %d, got %v", i, expectedAges[i], age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue