goaidb/protocol/parser.go

344 lines
8.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package protocol implements parsing of the MongoDB protocol.
package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"strings"
)
// UpdateFlags are the flags for OP_UPDATE.
//
// The constants are untyped bit values generated with iota: Upsert is 1 << 0
// and MultiUpdate is 1 << 1.
const (
Upsert = 1 << iota
MultiUpdate // flag bit used for multi-document updates
)
// Header is the message header.
//
// ParseMessage fills the four fields, in declaration order, from the first
// 16 bytes of a message, reading each as a little-endian 32-bit value.
type Header struct {
MessageLength int32 // bytes 0-3; ParseMessage treats it as the size of the whole message, header included
RequestID int32 // bytes 4-7
ResponseTo int32 // bytes 8-11
OpCode OpCode // bytes 12-15; selects which body parser ParseMessage runs
}
// OpCode is the operation code of a message.
type OpCode int32
// Known operation codes. ParseMessage only has dedicated body parsers for
// OP_QUERY, OP_INSERT and OP_UPDATE; every other value keeps its raw body.
//
// NOTE(review): the MongoDB wire protocol documentation lists OP_COMPRESSED
// as 2012 and OP_MSG as 2013. OP_MSG = 2 and OP_ENCRYPTED = 2013 below do not
// match that table — confirm whether these values are intentional.
const (
OP_REPLY OpCode = 1
OP_MSG OpCode = 2
OP_UPDATE OpCode = 2001
OP_INSERT OpCode = 2002
RESERVED OpCode = 2003
OP_QUERY OpCode = 2004
OP_GET_MORE OpCode = 2005
OP_DELETE OpCode = 2006
OP_KILL_CURSORS OpCode = 2007
OP_COMMAND OpCode = 2010
OP_COMMAND_REPLY OpCode = 2011
OP_COMPRESSED OpCode = 2012
OP_ENCRYPTED OpCode = 2013
)
// Message is the structure of a MongoDB protocol message as produced by
// ParseMessage.
type Message struct {
Header Header // decoded 16-byte header
OpCode OpCode // copy of Header.OpCode
OriginalBody []byte // raw message body (before parsing)
Body interface{} // parsed message body; the raw []byte when the opcode has no dedicated parser
}
// parseUpdate decodes the body of an update request.
//
// Fields are read from data in this order:
//   - 4 bytes of flags, little-endian
//   - database name, NUL-terminated
//   - collection name, NUL-terminated
//   - query document (BSON)
//   - update document (BSON), starting where the query document ends
//
// On success it returns a *UpdateMessage with Body populated and Header left
// at its zero value. It returns an error when data is shorter than 5 bytes,
// when either name is not NUL-terminated, or when either document fails to
// parse.
//
// NOTE(review): the upstream OP_UPDATE layout is int32 ZERO, cstring
// fullCollectionName, int32 flags, selector, update. This function reads the
// flags first and two separate names — confirm this is the intended format.
func parseUpdate(data []byte) (interface{}, error) {
	// Smallest possible input: 4 bytes of flags plus the database name terminator.
	if len(data) < 5 {
		return nil, fmt.Errorf("update message data too short")
	}
	flags := UpdateFlags(binary.LittleEndian.Uint32(data[0:4]))

	// IndexByte is relative to data[4:], so shift the result back into data.
	dbEnd := bytes.IndexByte(data[4:], 0)
	if dbEnd == -1 {
		return nil, fmt.Errorf("database name not null terminated")
	}
	dbEnd += 4
	dbName := string(data[4:dbEnd])

	// Everything after the database terminator: collection name and documents.
	remaining := data[dbEnd+1:]
	collEnd := bytes.IndexByte(remaining, 0)
	if collEnd == -1 {
		return nil, fmt.Errorf("collection name not null terminated")
	}
	collName := string(remaining[:collEnd])

	queryDoc, rest, err := parseBSON(remaining[collEnd+1:])
	if err != nil {
		return nil, fmt.Errorf("failed to parse query document: %w", err)
	}
	// The update document follows the query document, so it must be parsed
	// from the bytes the first parse left over, not from the same start.
	// Bytes after the update document are ignored.
	updateDoc, _, err := parseBSON(rest)
	if err != nil {
		return nil, fmt.Errorf("failed to parse update document: %w", err)
	}

	msg := &UpdateMessage{}
	msg.Body.Flags = flags
	msg.Body.DatabaseName = dbName
	msg.Body.CollName = collName
	msg.Body.Query = queryDoc
	msg.Body.UpdateSpec = updateDoc
	return msg, nil
}
// ParseMessage parses one MongoDB protocol message from the front of data.
//
// data must start with a 16-byte header whose first field is the total
// message length (header included). data may contain further bytes after that
// message; they are ignored, and neither OriginalBody nor the parsed Body
// ever covers them.
//
// The body is decoded with parseQuery, parseInsert or parseUpdate according
// to the opcode. For any other opcode Body holds the raw body bytes.
//
// An error is returned when data is shorter than the header, when the
// declared length is smaller than the header or larger than data, or when
// the body parser fails.
func ParseMessage(data []byte) (*Message, error) {
	const headerLen = 16

	if len(data) < headerLen {
		return nil, fmt.Errorf("data too short for message header")
	}

	// The length field is a signed 32-bit value on the wire. Anything smaller
	// than the header itself (including negative values) cannot be a message.
	declared := int32(binary.LittleEndian.Uint32(data[0:4]))
	if declared < headerLen {
		return nil, fmt.Errorf("invalid message length %d", declared)
	}
	messageLength := int(declared)
	if len(data) < messageLength {
		return nil, fmt.Errorf("data too short for complete message")
	}

	// Restrict the view to this message; data may hold several messages.
	actualData := data[:messageLength]

	header := Header{
		MessageLength: declared,
		RequestID:     int32(binary.LittleEndian.Uint32(actualData[4:8])),
		ResponseTo:    int32(binary.LittleEndian.Uint32(actualData[8:12])),
		OpCode:        OpCode(binary.LittleEndian.Uint32(actualData[12:16])),
	}

	// Take the body from the bounded slice so trailing messages are never
	// handed to the body parsers.
	body := actualData[headerLen:]

	var parsedBody interface{}
	switch header.OpCode {
	case OP_QUERY:
		query, err := parseQuery(body)
		if err != nil {
			return nil, err
		}
		parsedBody = query
	case OP_INSERT:
		insert, err := parseInsert(body)
		if err != nil {
			return nil, err
		}
		parsedBody = insert
	case OP_UPDATE:
		update, err := parseUpdate(body)
		if err != nil {
			return nil, err
		}
		parsedBody = update
	// Parsers for further opcodes can be added here.
	default:
		// Unknown opcode: keep the raw bytes.
		parsedBody = body
	}

	return &Message{
		Header:       header,
		OpCode:       header.OpCode,
		OriginalBody: body,
		Body:         parsedBody,
	}, nil
}
// parseQuery decodes the body of a query request.
//
// Fields are read from data in this order:
//   - 4 bytes of flags, little-endian
//   - namespace "<database>.<collection>", NUL-terminated
//   - 4 bytes numberToSkip, little-endian, signed
//   - 4 bytes numberToReturn, little-endian, signed
//   - query document (BSON)
//
// The namespace is split at its first dot: the part before it is the database
// name and everything after it is the collection name, so dotted collection
// names are kept whole.
//
// On success it returns a *QueryMessage. It returns an error when data is
// truncated, when the namespace is not NUL-terminated or contains no dot, or
// when the query document fails to parse.
func parseQuery(data []byte) (interface{}, error) {
	if len(data) < 4 {
		return nil, fmt.Errorf("query data too short")
	}
	flags := binary.LittleEndian.Uint32(data[0:4])

	// nameEnd is relative to data[4:].
	nameEnd := bytes.IndexByte(data[4:], 0)
	if nameEnd == -1 {
		return nil, fmt.Errorf("database name not null terminated")
	}
	namespace := string(data[4 : 4+nameEnd])

	dot := strings.IndexByte(namespace, '.')
	if dot == -1 {
		return nil, fmt.Errorf("namespace %q has no collection name", namespace)
	}
	dbName, collName := namespace[:dot], namespace[dot+1:]

	// Skip the terminator; what follows is skip/return counts and the query.
	remaining := data[4+nameEnd+1:]
	if len(remaining) < 8 {
		return nil, fmt.Errorf("query data too short for skip and return counts")
	}
	numberToSkip := int32(binary.LittleEndian.Uint32(remaining[0:4]))
	numberToReturn := int32(binary.LittleEndian.Uint32(remaining[4:8]))

	queryDoc, _, err := parseBSON(remaining[8:])
	if err != nil {
		return nil, fmt.Errorf("failed to parse query conditions: %w", err)
	}

	return &QueryMessage{
		Flags:          flags,
		DatabaseName:   dbName,
		CollName:       collName,
		NumberToSkip:   numberToSkip,
		NumberToReturn: numberToReturn,
		Query:          queryDoc,
	}, nil
}
// 解析插入请求
func parseInsert(data []byte) (interface{}, error) {
// 实现具体的插入消息解析逻辑
if len(data) < 4 {
return nil, fmt.Errorf("insert data too short")
}
// 示例:读取插入标志
flags := binary.LittleEndian.Uint32(data[0:4])
// 提取数据库名称
dbEnd := bytes.IndexByte(data[4:], 0)
if dbEnd == -1 {
return nil, fmt.Errorf("database name not null terminated")
}
dbName := string(data[4 : dbEnd+4])
// 剩余数据包含集合名和文档数据
remaining := data[dbEnd+5:] // 跳过终止符
// 提取集合名
collEnd := bytes.IndexByte(remaining, 0)
if collEnd == -1 {
return nil, fmt.Errorf("collection name not null terminated")
}
collName := string(remaining[:collEnd])
// 解析文档数据
documents := make([]map[string]interface{}, 0)
rest := remaining[collEnd+1:]
for len(rest) > 0 {
doc, remainingData, err := parseBSON(rest)
if err != nil {
return nil, fmt.Errorf("failed to parse document: %v", err)
}
documents = append(documents, doc)
rest = remainingData
}
return &InsertMessage{
Flags: flags,
DatabaseName: dbName,
CollName: collName,
Documents: documents,
}, nil
}
// QueryMessage is the body structure of an OP_QUERY message, as returned by
// parseQuery.
type QueryMessage struct {
Flags uint32 // query flags
DatabaseName string // database name
CollName string // collection name
NumberToSkip int32 // skip count read from the message
NumberToReturn int32 // return count read from the message
Query map[string]interface{} // query conditions
}
// InsertMessage is the body structure of an OP_INSERT message, as returned by
// parseInsert.
type InsertMessage struct {
Flags uint32 // insert flags
DatabaseName string // database name
CollName string // collection name
Documents []map[string]interface{} // documents to insert
}
// UpdateMessage represents an OP_UPDATE message.
//
// parseUpdate fills Body only; Header is left at its zero value there.
type UpdateMessage struct {
Header Header
Body struct {
Flags UpdateFlags // flag bits read from the first 4 bytes of the body
DatabaseName string // database name
CollName string // collection name
Query map[string]interface{} // query document
UpdateSpec map[string]interface{} // update document; Update reads its "$set" and "$inc" entries
}
}
// Update applies the update operation to a document.
//
// doc is not modified: its top-level entries are copied into a new map and
// the operators in u.Body.UpdateSpec are applied to that copy.
//
//   - "$set": each key/value pair overwrites or adds the field.
//   - "$inc": each value must be an int32. It is added to an existing int32
//     field, or stored as-is when the field is absent.
//
// An operator entry that is missing or is not a map[string]interface{} is
// skipped. "$set" is applied before "$inc".
//
// It returns the new document, or a nil map and an error when an "$inc" value
// is not an int32 or targets an existing field that is not an int32.
func (u *UpdateMessage) Update(doc map[string]interface{}) (map[string]interface{}, error) {
	spec := u.Body.UpdateSpec

	// Start from a shallow copy of the input document.
	updated := make(map[string]interface{})
	for field, v := range doc {
		updated[field] = v
	}

	if assignments, isMap := spec["$set"].(map[string]interface{}); isMap {
		for field, v := range assignments {
			updated[field] = v
		}
	}

	increments, isMap := spec["$inc"].(map[string]interface{})
	if !isMap {
		return updated, nil
	}
	for field, raw := range increments {
		delta, isInt := raw.(int32)
		if !isInt {
			return nil, fmt.Errorf("invalid increment value")
		}

		existing, present := updated[field]
		if !present {
			// Absent field: the increment becomes the initial value.
			updated[field] = delta
			continue
		}

		base, isInt := existing.(int32)
		if !isInt {
			return nil, fmt.Errorf("cannot apply $inc to non-integer field")
		}
		updated[field] = base + delta
	}

	return updated, nil
}