344 lines
8.7 KiB
Go
344 lines
8.7 KiB
Go
// Package protocol implements parsing of MongoDB wire-protocol messages.
|
||
package protocol
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/binary"
|
||
"fmt"
|
||
"strings"
|
||
)
|
||
|
||
// Flag bits used in the flags field of an OP_UPDATE message.
//
// The constants are untyped, so they can be combined with or compared
// against any integer type.
const (
	// Upsert is bit 0 of the update flags (value 1).
	Upsert = 1 << iota
	// MultiUpdate is bit 1 of the update flags (value 2); it marks a
	// multi-document update.
	MultiUpdate
)
|
||
|
||
// Header 消息头
|
||
type Header struct {
|
||
MessageLength int32
|
||
RequestID int32
|
||
ResponseTo int32
|
||
OpCode OpCode
|
||
}
|
||
|
||
// OpCode identifies the kind of operation a message carries. It is the
// last field of Header.
type OpCode int32

// Operation codes known to this package.
//
// NOTE(review): the MongoDB wire-protocol documentation lists OP_MSG as 2013
// and does not define an OP_ENCRYPTED opcode. The values 2 (OP_MSG) and 2013
// (OP_ENCRYPTED) below are kept unchanged for compatibility; confirm them
// against the protocol reference before relying on them.
const (
	OP_REPLY         OpCode = 1
	OP_MSG           OpCode = 2
	OP_UPDATE        OpCode = 2001
	OP_INSERT        OpCode = 2002
	RESERVED         OpCode = 2003
	OP_QUERY         OpCode = 2004
	OP_GET_MORE      OpCode = 2005
	OP_DELETE        OpCode = 2006
	OP_KILL_CURSORS  OpCode = 2007
	OP_COMMAND       OpCode = 2010
	OP_COMMAND_REPLY OpCode = 2011
	OP_COMPRESSED    OpCode = 2012
	OP_ENCRYPTED     OpCode = 2013
)
|
||
|
||
// Message MongoDB协议消息结构
|
||
type Message struct {
|
||
Header Header
|
||
OpCode OpCode
|
||
OriginalBody []byte // 原始消息体(解析前)
|
||
Body interface{} // 解析后的消息体
|
||
}
|
||
|
||
// 解析更新请求
|
||
func parseUpdate(data []byte) (interface{}, error) {
|
||
// 检查数据长度(最小需要4字节的flags + 1字节的数据库名终止符)
|
||
if len(data) < 5 {
|
||
return nil, fmt.Errorf("update message data too short")
|
||
}
|
||
|
||
// 读取flags
|
||
flags := UpdateFlags(binary.LittleEndian.Uint32(data[0:4]))
|
||
|
||
// 查找数据库名结束位置(C风格字符串,以0终止)
|
||
dbEnd := bytes.IndexByte(data[4:], 0)
|
||
if dbEnd == -1 {
|
||
return nil, fmt.Errorf("database name not null terminated")
|
||
}
|
||
dbEnd += 4 // 调整到正确的位置
|
||
|
||
// 提取数据库名称
|
||
dbName := string(data[4:dbEnd])
|
||
|
||
// 剩余数据包含集合名、查询文档和更新文档
|
||
remaining := data[dbEnd+1:]
|
||
|
||
// 查找集合名结束位置
|
||
collEnd := bytes.IndexByte(remaining, 0)
|
||
if collEnd == -1 {
|
||
return nil, fmt.Errorf("collection name not null terminated")
|
||
}
|
||
|
||
// 提取集合名
|
||
collName := string(remaining[:collEnd])
|
||
|
||
// 剩余数据包含查询文档和更新文档
|
||
bsonData := remaining[collEnd+1:]
|
||
|
||
// 解析BSON文档
|
||
queryDoc, _, err := parseBSON(bsonData)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse query document: %v", err)
|
||
}
|
||
|
||
// 解析更新文档
|
||
updateDoc, _, err := parseBSON(bsonData)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse update document: %v", err)
|
||
}
|
||
|
||
return &UpdateMessage{
|
||
Body: struct {
|
||
Flags UpdateFlags
|
||
DatabaseName string
|
||
CollName string
|
||
Query map[string]interface{}
|
||
UpdateSpec map[string]interface{}
|
||
}{
|
||
Flags: flags,
|
||
DatabaseName: dbName,
|
||
CollName: collName,
|
||
Query: queryDoc,
|
||
UpdateSpec: updateDoc,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// ParseMessage 解析MongoDB协议消息
|
||
func ParseMessage(data []byte) (*Message, error) {
|
||
// 最小消息长度为16字节(消息头长度)
|
||
if len(data) < 16 {
|
||
return nil, fmt.Errorf("data too short for message header")
|
||
}
|
||
|
||
// 验证消息长度是否完整
|
||
messageLength := int(binary.LittleEndian.Uint32(data[0:4]))
|
||
if len(data) < messageLength {
|
||
return nil, fmt.Errorf("data too short for complete message")
|
||
}
|
||
|
||
// 截取实际的消息数据(可能有多条消息)
|
||
actualData := data[:messageLength]
|
||
|
||
// 解析消息头
|
||
header := &Header{
|
||
MessageLength: int32(binary.LittleEndian.Uint32(actualData[0:4])),
|
||
RequestID: int32(binary.LittleEndian.Uint32(actualData[4:8])),
|
||
ResponseTo: int32(binary.LittleEndian.Uint32(actualData[8:12])),
|
||
OpCode: OpCode(binary.LittleEndian.Uint32(actualData[12:16])),
|
||
}
|
||
|
||
// 获取消息体
|
||
body := data[16:]
|
||
|
||
// 解析特定操作码的消息体
|
||
var parsedBody interface{}
|
||
switch header.OpCode {
|
||
case OP_QUERY:
|
||
query, err := parseQuery(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
parsedBody = query
|
||
case OP_INSERT:
|
||
insert, err := parseInsert(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
parsedBody = insert
|
||
case OP_UPDATE:
|
||
update, err := parseUpdate(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
parsedBody = update
|
||
// 这里可以添加更多操作码的解析逻辑
|
||
default:
|
||
// 未知操作码,保留原始数据
|
||
parsedBody = body
|
||
}
|
||
|
||
return &Message{
|
||
Header: *header,
|
||
OpCode: header.OpCode,
|
||
OriginalBody: body,
|
||
Body: parsedBody,
|
||
}, nil
|
||
}
|
||
|
||
// 解析查询请求
|
||
func parseQuery(data []byte) (interface{}, error) {
|
||
// 实现具体的查询消息解析逻辑
|
||
// 这里返回原始数据作为占位符
|
||
if len(data) < 4 {
|
||
return nil, fmt.Errorf("query data too short")
|
||
}
|
||
|
||
// 示例:读取查询标志
|
||
flags := binary.LittleEndian.Uint32(data[0:4])
|
||
|
||
// 提取数据库名称
|
||
dbEnd := bytes.IndexByte(data[4:], 0)
|
||
if dbEnd == -1 {
|
||
return nil, fmt.Errorf("database name not null terminated")
|
||
}
|
||
dbcolName := string(data[4 : dbEnd+4])
|
||
dcnames := strings.Split(dbcolName, ".")
|
||
dbName := dcnames[0]
|
||
collName := dcnames[1]
|
||
|
||
// 剩余数据包含集合名和查询条件
|
||
remaining := data[dbEnd+5:] // 跳过终止符
|
||
var numberToSkip, numberToReturn int32
|
||
binary.Read(bytes.NewReader(remaining[0:4]), binary.LittleEndian, &numberToSkip)
|
||
binary.Read(bytes.NewReader(remaining[4:8]), binary.LittleEndian, &numberToReturn)
|
||
// 提取集合名
|
||
// collEnd := bytes.IndexByte(remaining, 0)
|
||
// if collEnd == -1 {
|
||
// return nil, fmt.Errorf("collection name not null terminated")
|
||
// }
|
||
// if collEnd == 0 {
|
||
// collEnd = bytes.IndexFunc(remaining, func(r rune) bool {
|
||
// return r != 0
|
||
// })
|
||
// }
|
||
// collName := string(remaining[:collEnd])
|
||
|
||
// 解析查询条件
|
||
queryDoc, _, err := parseBSON(remaining[8:])
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse query conditions: %v", err)
|
||
}
|
||
|
||
return &QueryMessage{
|
||
Flags: flags,
|
||
DatabaseName: dbName,
|
||
CollName: collName,
|
||
NumberToSkip: numberToSkip,
|
||
NumberToReturn: numberToReturn,
|
||
Query: queryDoc,
|
||
}, nil
|
||
}
|
||
|
||
// 解析插入请求
|
||
func parseInsert(data []byte) (interface{}, error) {
|
||
// 实现具体的插入消息解析逻辑
|
||
if len(data) < 4 {
|
||
return nil, fmt.Errorf("insert data too short")
|
||
}
|
||
|
||
// 示例:读取插入标志
|
||
flags := binary.LittleEndian.Uint32(data[0:4])
|
||
|
||
// 提取数据库名称
|
||
dbEnd := bytes.IndexByte(data[4:], 0)
|
||
if dbEnd == -1 {
|
||
return nil, fmt.Errorf("database name not null terminated")
|
||
}
|
||
dbName := string(data[4 : dbEnd+4])
|
||
|
||
// 剩余数据包含集合名和文档数据
|
||
remaining := data[dbEnd+5:] // 跳过终止符
|
||
|
||
// 提取集合名
|
||
collEnd := bytes.IndexByte(remaining, 0)
|
||
if collEnd == -1 {
|
||
return nil, fmt.Errorf("collection name not null terminated")
|
||
}
|
||
collName := string(remaining[:collEnd])
|
||
|
||
// 解析文档数据
|
||
documents := make([]map[string]interface{}, 0)
|
||
rest := remaining[collEnd+1:]
|
||
for len(rest) > 0 {
|
||
doc, remainingData, err := parseBSON(rest)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse document: %v", err)
|
||
}
|
||
documents = append(documents, doc)
|
||
rest = remainingData
|
||
}
|
||
|
||
return &InsertMessage{
|
||
Flags: flags,
|
||
DatabaseName: dbName,
|
||
CollName: collName,
|
||
Documents: documents,
|
||
}, nil
|
||
}
|
||
|
||
// QueryMessage is the parsed body of an OP_QUERY message.
type QueryMessage struct {
	Flags          uint32                 // query flag bits, first 32-bit word of the body
	DatabaseName   string                 // database part of the namespace
	CollName       string                 // collection part of the namespace
	NumberToSkip   int32                  // 32-bit value that follows the namespace
	NumberToReturn int32                  // 32-bit value that follows NumberToSkip
	Query          map[string]interface{} // decoded query document
}
|
||
|
||
// InsertMessage is the parsed body of an OP_INSERT message.
type InsertMessage struct {
	Flags        uint32                   // insert flag bits, first 32-bit word of the body
	DatabaseName string                   // database name
	CollName     string                   // collection name
	Documents    []map[string]interface{} // documents to insert, in message order
}
|
||
|
||
// UpdateMessage represents an OP_UPDATE message
|
||
type UpdateMessage struct {
|
||
Header Header
|
||
Body struct {
|
||
Flags UpdateFlags
|
||
DatabaseName string
|
||
CollName string
|
||
Query map[string]interface{}
|
||
UpdateSpec map[string]interface{}
|
||
}
|
||
}
|
||
|
||
// Update applies the update operation to a document
|
||
func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) {
|
||
result := make(map[string]interface{})
|
||
for key, value := range doc {
|
||
result[key] = value
|
||
}
|
||
|
||
// Process $set operator
|
||
if setOp, ok := u.Body.UpdateSpec["$set"].(map[string]interface{}); ok {
|
||
for key, value := range setOp {
|
||
result[key] = value
|
||
}
|
||
}
|
||
|
||
// Process $inc operator
|
||
if incOp, ok := u.Body.UpdateSpec["$inc"].(map[string]interface{}); ok {
|
||
for key, value := range incOp {
|
||
if num, ok := value.(int32); ok {
|
||
if current, exists := result[key]; exists {
|
||
if currentNum, ok := current.(int32); ok {
|
||
result[key] = currentNum + num
|
||
} else {
|
||
return nil, fmt.Errorf("cannot apply $inc to non-integer field")
|
||
}
|
||
} else {
|
||
result[key] = num
|
||
}
|
||
} else {
|
||
return nil, fmt.Errorf("invalid increment value")
|
||
}
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|