goaidb/protocol/parser.go

344 lines
8.7 KiB
Go
Raw Permalink Normal View History

2025-06-05 17:59:27 +08:00
// Package protocol implements parsing of MongoDB wire-protocol messages.
package protocol
import (
2025-06-06 22:12:24 +08:00
"bytes"
2025-06-05 17:59:27 +08:00
"encoding/binary"
"fmt"
2025-06-07 10:58:28 +08:00
"strings"
2025-06-05 17:59:27 +08:00
)
2025-06-06 22:54:22 +08:00
// Bit flags carried in the flags field of an OP_UPDATE message.
//
// Upsert occupies bit 0 and MultiUpdate (multi-document update) occupies
// bit 1. Both are untyped integer constants, so they can be OR-ed together
// and compared against any integer flags value, including UpdateFlags.
const (
	Upsert      = 1 << 0 // bit 0: upsert
	MultiUpdate = 1 << 1 // bit 1: update multiple documents
)
2025-06-05 17:59:27 +08:00
// Header is the standard 16-byte header at the start of every MongoDB
// wire-protocol message. ParseMessage decodes its four fields, in this
// order, as little-endian int32 values.
type Header struct {
MessageLength int32 // total length of the message in bytes, including this header
RequestID int32 // sender-assigned message identifier (per the wire protocol spec)
ResponseTo int32 // RequestID of the message being answered (per the wire protocol spec)
OpCode OpCode // operation code that selects how the body is decoded
}
// OpCode identifies the kind of a MongoDB wire-protocol message. It is the
// fourth int32 of the message Header and selects how the body is laid out.
type OpCode int32

// Opcodes defined by the MongoDB wire protocol.
const (
	OP_REPLY OpCode = 1 // reply to a legacy OP_QUERY

	// OP_MSG is the extensible message format used by current drivers. The
	// protocol assigns it 2013; 2 is not a valid opcode, and declaring it as
	// such made genuine OP_MSG traffic unrecognizable.
	OP_MSG OpCode = 2013

	OP_UPDATE        OpCode = 2001 // legacy update
	OP_INSERT        OpCode = 2002 // legacy insert
	RESERVED         OpCode = 2003 // reserved; not used
	OP_QUERY         OpCode = 2004 // legacy query
	OP_GET_MORE      OpCode = 2005 // legacy cursor continuation
	OP_DELETE        OpCode = 2006 // legacy delete
	OP_KILL_CURSORS  OpCode = 2007 // legacy cursor cleanup
	OP_COMMAND       OpCode = 2010 // deprecated command request
	OP_COMMAND_REPLY OpCode = 2011 // deprecated command reply
	OP_COMPRESSED    OpCode = 2012 // compressed wrapper around another opcode

	// OP_ENCRYPTED is kept only for source compatibility: the wire protocol
	// defines no such opcode, and 2013 is OP_MSG.
	//
	// Deprecated: use OP_MSG.
	OP_ENCRYPTED OpCode = 2013
)
2025-06-06 22:12:24 +08:00
// Message is a decoded MongoDB wire-protocol message, as built by
// ParseMessage.
type Message struct {
Header Header // decoded 16-byte message header
OpCode OpCode // set by ParseMessage to the same value as Header.OpCode
OriginalBody []byte // raw body bytes that follow the 16-byte header, before parsing
Body interface{} // parsed body (*QueryMessage, *InsertMessage or *UpdateMessage) for supported opcodes; the raw body bytes otherwise
}
// 解析更新请求
func parseUpdate(data []byte) (interface{}, error) {
// 检查数据长度最小需要4字节的flags + 1字节的数据库名终止符
if len(data) < 5 {
return nil, fmt.Errorf("update message data too short")
}
// 读取flags
flags := UpdateFlags(binary.LittleEndian.Uint32(data[0:4]))
// 查找数据库名结束位置C风格字符串以0终止
dbEnd := bytes.IndexByte(data[4:], 0)
if dbEnd == -1 {
return nil, fmt.Errorf("database name not null terminated")
}
dbEnd += 4 // 调整到正确的位置
// 提取数据库名称
dbName := string(data[4:dbEnd])
// 剩余数据包含集合名、查询文档和更新文档
remaining := data[dbEnd+1:]
// 查找集合名结束位置
collEnd := bytes.IndexByte(remaining, 0)
if collEnd == -1 {
return nil, fmt.Errorf("collection name not null terminated")
}
// 提取集合名
collName := string(remaining[:collEnd])
// 剩余数据包含查询文档和更新文档
bsonData := remaining[collEnd+1:]
// 解析BSON文档
queryDoc, _, err := parseBSON(bsonData)
if err != nil {
return nil, fmt.Errorf("failed to parse query document: %v", err)
}
// 解析更新文档
updateDoc, _, err := parseBSON(bsonData)
if err != nil {
return nil, fmt.Errorf("failed to parse update document: %v", err)
}
return &UpdateMessage{
2025-06-06 22:54:22 +08:00
Body: struct {
2025-06-07 10:58:28 +08:00
Flags UpdateFlags
2025-06-06 22:54:22 +08:00
DatabaseName string
2025-06-07 10:58:28 +08:00
CollName string
Query map[string]interface{}
UpdateSpec map[string]interface{}
2025-06-06 22:54:22 +08:00
}{
Flags: flags,
DatabaseName: dbName,
CollName: collName,
Query: queryDoc,
UpdateSpec: updateDoc,
},
2025-06-06 22:12:24 +08:00
}, nil
}
2025-06-05 17:59:27 +08:00
// ParseMessage 解析MongoDB协议消息
func ParseMessage(data []byte) (*Message, error) {
2025-06-06 22:12:24 +08:00
// 最小消息长度为16字节消息头长度
2025-06-05 17:59:27 +08:00
if len(data) < 16 {
return nil, fmt.Errorf("data too short for message header")
}
2025-06-06 22:12:24 +08:00
// 验证消息长度是否完整
messageLength := int(binary.LittleEndian.Uint32(data[0:4]))
if len(data) < messageLength {
return nil, fmt.Errorf("data too short for complete message")
}
// 截取实际的消息数据(可能有多条消息)
actualData := data[:messageLength]
// 解析消息头
2025-06-05 17:59:27 +08:00
header := &Header{
2025-06-06 22:12:24 +08:00
MessageLength: int32(binary.LittleEndian.Uint32(actualData[0:4])),
RequestID: int32(binary.LittleEndian.Uint32(actualData[4:8])),
ResponseTo: int32(binary.LittleEndian.Uint32(actualData[8:12])),
OpCode: OpCode(binary.LittleEndian.Uint32(actualData[12:16])),
2025-06-05 17:59:27 +08:00
}
2025-06-06 22:12:24 +08:00
// 获取消息体
2025-06-05 17:59:27 +08:00
body := data[16:]
// 解析特定操作码的消息体
var parsedBody interface{}
switch header.OpCode {
case OP_QUERY:
query, err := parseQuery(body)
if err != nil {
return nil, err
}
parsedBody = query
case OP_INSERT:
insert, err := parseInsert(body)
if err != nil {
return nil, err
}
parsedBody = insert
2025-06-06 22:12:24 +08:00
case OP_UPDATE:
update, err := parseUpdate(body)
if err != nil {
return nil, err
}
parsedBody = update
2025-06-05 17:59:27 +08:00
// 这里可以添加更多操作码的解析逻辑
default:
// 未知操作码,保留原始数据
parsedBody = body
}
return &Message{
Header: *header,
OpCode: header.OpCode,
OriginalBody: body,
Body: parsedBody,
}, nil
}
// 解析查询请求
func parseQuery(data []byte) (interface{}, error) {
// 实现具体的查询消息解析逻辑
// 这里返回原始数据作为占位符
2025-06-06 22:12:24 +08:00
if len(data) < 4 {
return nil, fmt.Errorf("query data too short")
}
// 示例:读取查询标志
flags := binary.LittleEndian.Uint32(data[0:4])
// 提取数据库名称
dbEnd := bytes.IndexByte(data[4:], 0)
if dbEnd == -1 {
return nil, fmt.Errorf("database name not null terminated")
}
2025-06-07 10:58:28 +08:00
dbcolName := string(data[4 : dbEnd+4])
dcnames := strings.Split(dbcolName, ".")
dbName := dcnames[0]
collName := dcnames[1]
2025-06-06 22:12:24 +08:00
// 剩余数据包含集合名和查询条件
remaining := data[dbEnd+5:] // 跳过终止符
2025-06-07 10:58:28 +08:00
var numberToSkip, numberToReturn int32
binary.Read(bytes.NewReader(remaining[0:4]), binary.LittleEndian, &numberToSkip)
binary.Read(bytes.NewReader(remaining[4:8]), binary.LittleEndian, &numberToReturn)
2025-06-06 22:12:24 +08:00
// 提取集合名
2025-06-07 10:58:28 +08:00
// collEnd := bytes.IndexByte(remaining, 0)
// if collEnd == -1 {
// return nil, fmt.Errorf("collection name not null terminated")
// }
// if collEnd == 0 {
// collEnd = bytes.IndexFunc(remaining, func(r rune) bool {
// return r != 0
// })
// }
// collName := string(remaining[:collEnd])
2025-06-06 22:12:24 +08:00
// 解析查询条件
2025-06-07 10:58:28 +08:00
queryDoc, _, err := parseBSON(remaining[8:])
2025-06-06 22:12:24 +08:00
if err != nil {
return nil, fmt.Errorf("failed to parse query conditions: %v", err)
}
return &QueryMessage{
2025-06-07 10:58:28 +08:00
Flags: flags,
DatabaseName: dbName,
CollName: collName,
NumberToSkip: numberToSkip,
NumberToReturn: numberToReturn,
Query: queryDoc,
2025-06-06 22:12:24 +08:00
}, nil
2025-06-05 17:59:27 +08:00
}
// 解析插入请求
func parseInsert(data []byte) (interface{}, error) {
// 实现具体的插入消息解析逻辑
2025-06-06 22:12:24 +08:00
if len(data) < 4 {
return nil, fmt.Errorf("insert data too short")
}
// 示例:读取插入标志
flags := binary.LittleEndian.Uint32(data[0:4])
// 提取数据库名称
dbEnd := bytes.IndexByte(data[4:], 0)
if dbEnd == -1 {
return nil, fmt.Errorf("database name not null terminated")
}
dbName := string(data[4 : dbEnd+4])
// 剩余数据包含集合名和文档数据
remaining := data[dbEnd+5:] // 跳过终止符
// 提取集合名
collEnd := bytes.IndexByte(remaining, 0)
if collEnd == -1 {
return nil, fmt.Errorf("collection name not null terminated")
}
collName := string(remaining[:collEnd])
// 解析文档数据
documents := make([]map[string]interface{}, 0)
rest := remaining[collEnd+1:]
for len(rest) > 0 {
doc, remainingData, err := parseBSON(rest)
if err != nil {
return nil, fmt.Errorf("failed to parse document: %v", err)
}
documents = append(documents, doc)
rest = remainingData
}
return &InsertMessage{
Flags: flags,
DatabaseName: dbName,
CollName: collName,
Documents: documents,
}, nil
}
// QueryMessage is the decoded body of a legacy OP_QUERY message, as built
// by parseQuery.
type QueryMessage struct {
	// Flags holds the raw query flag bits from the first int32 of the body.
	Flags uint32

	// DatabaseName and CollName are the database and collection halves of
	// the "<db>.<collection>" namespace the query targets.
	DatabaseName string
	CollName     string

	// NumberToSkip and NumberToReturn are copied verbatim from the message,
	// as sent by the client.
	NumberToSkip   int32
	NumberToReturn int32

	// Query is the decoded query document.
	Query map[string]interface{}
}
// InsertMessage is the decoded body of an OP_INSERT message, as built by
// parseInsert.
type InsertMessage struct {
Flags uint32 // insert flags: the first little-endian uint32 of the body
DatabaseName string // database name
CollName string // collection name
Documents []map[string]interface{} // documents to insert, in the order they appear on the wire
}
2025-06-06 22:54:22 +08:00
// UpdateMessage represents an OP_UPDATE message
2025-06-06 22:12:24 +08:00
type UpdateMessage struct {
2025-06-07 10:58:28 +08:00
Header Header
Body struct {
2025-06-06 22:54:22 +08:00
Flags UpdateFlags
DatabaseName string
CollName string
Query map[string]interface{}
UpdateSpec map[string]interface{}
}
}
// Update applies the update operation to a document
func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) {
result := make(map[string]interface{})
for key, value := range doc {
result[key] = value
}
// Process $set operator
if setOp, ok := u.Body.UpdateSpec["$set"].(map[string]interface{}); ok {
for key, value := range setOp {
result[key] = value
}
}
// Process $inc operator
if incOp, ok := u.Body.UpdateSpec["$inc"].(map[string]interface{}); ok {
for key, value := range incOp {
if num, ok := value.(int32); ok {
if current, exists := result[key]; exists {
if currentNum, ok := current.(int32); ok {
result[key] = currentNum + num
} else {
return nil, fmt.Errorf("cannot apply $inc to non-integer field")
}
} else {
result[key] = num
}
} else {
return nil, fmt.Errorf("invalid increment value")
}
}
}
return result, nil
2025-06-06 22:12:24 +08:00
}