fix OP_Query parse

This commit is contained in:
kingecg 2025-06-07 10:58:28 +08:00
parent c70cfad89c
commit fdd3a2e7c7
4 changed files with 81 additions and 37 deletions

View File

@ -47,6 +47,7 @@ func (s *Server) handleConnection(conn net.Conn) {
return return
} }
log.WriteOpMsgToFile("opbin.log", uint8(2), buffer[:n])
// 解析MongoDB协议消息 // 解析MongoDB协议消息
message, err := protocol.ParseMessage(buffer[:n]) message, err := protocol.ParseMessage(buffer[:n])
if err != nil { if err != nil {

View File

@ -174,6 +174,20 @@ func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, e
} }
return subDoc, docLength, nil return subDoc, docLength, nil
case 0x04: // Array
if len(data) < 4 {
return nil, 0, fmt.Errorf("data too short for Array length")
}
// 读取数组长度
arrayLength := int(binary.LittleEndian.Uint32(data[0:4]))
if len(data) < arrayLength {
return nil, 0, fmt.Errorf("data too short for Array")
}
subDoc, _, err := parseBSON(data[0:arrayLength])
if err != nil {
return nil, 0, fmt.Errorf("failed to parse array: %v", err)
}
return subDoc, arrayLength, nil
default: default:
return nil, 0, fmt.Errorf("unsupported BSON element type: 0x%02X", elementType) return nil, 0, fmt.Errorf("unsupported BSON element type: 0x%02X", elementType)
@ -183,10 +197,10 @@ func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, e
// BsonMarshal 将map转换为BSON格式的字节流 // BsonMarshal 将map转换为BSON格式的字节流
func BsonMarshal(doc map[string]interface{}) ([]byte, error) { func BsonMarshal(doc map[string]interface{}) ([]byte, error) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
// 写入占位符长度4字节 // 写入占位符长度4字节
buf.Write(make([]byte, 4)) buf.Write(make([]byte, 4))
// 遍历文档元素 // 遍历文档元素
for key, value := range doc { for key, value := range doc {
// 写入元素类型和键名 // 写入元素类型和键名
@ -194,24 +208,24 @@ func BsonMarshal(doc map[string]interface{}) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
buf.WriteByte(elementType) buf.WriteByte(elementType)
buf.WriteString(key) buf.WriteString(key)
buf.WriteByte(0x00) // 键名终止符 buf.WriteByte(0x00) // 键名终止符
// 写入值数据 // 写入值数据
if err := writeBSONValue(buf, elementType, value); err != nil { if err := writeBSONValue(buf, elementType, value); err != nil {
return nil, err return nil, err
} }
} }
// 写入文档结束符 // 写入文档结束符
buf.WriteByte(0x00) buf.WriteByte(0x00)
// 回填文档长度 // 回填文档长度
length := uint32(buf.Len()) length := uint32(buf.Len())
binary.LittleEndian.PutUint32(buf.Bytes(), length) binary.LittleEndian.PutUint32(buf.Bytes(), length)
return buf.Bytes(), nil return buf.Bytes(), nil
} }
@ -248,7 +262,7 @@ func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) erro
b := make([]byte, 4) b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, uint32(v)) binary.LittleEndian.PutUint32(b, uint32(v))
buf.Write(b) buf.Write(b)
case 0x12: // Int64 case 0x12: // Int64
v, ok := value.(int64) v, ok := value.(int64)
if !ok { if !ok {
@ -257,7 +271,7 @@ func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) erro
b := make([]byte, 8) b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, uint64(v)) binary.LittleEndian.PutUint64(b, uint64(v))
buf.Write(b) buf.Write(b)
case 0x01: // Double case 0x01: // Double
v, ok := value.(float64) v, ok := value.(float64)
if !ok { if !ok {
@ -266,7 +280,7 @@ func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) erro
b := make([]byte, 8) b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, math.Float64bits(v)) binary.LittleEndian.PutUint64(b, math.Float64bits(v))
buf.Write(b) buf.Write(b)
case 0x02: // String case 0x02: // String
v, ok := value.(string) v, ok := value.(string)
if !ok { if !ok {
@ -281,7 +295,7 @@ func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) erro
// 写入字符串内容和终止符 // 写入字符串内容和终止符
buf.Write(strBytes) buf.Write(strBytes)
buf.WriteByte(0x00) buf.WriteByte(0x00)
case 0x08: // Boolean case 0x08: // Boolean
v, ok := value.(bool) v, ok := value.(bool)
if !ok { if !ok {
@ -292,10 +306,10 @@ func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) erro
} else { } else {
buf.WriteByte(0x00) buf.WriteByte(0x00)
} }
case 0x0A: // Null case 0x0A: // Null
// 不需要写入任何数据 // 不需要写入任何数据
case 0x03: // EmbeddedDocument case 0x03: // EmbeddedDocument
v, ok := value.(map[string]interface{}) v, ok := value.(map[string]interface{})
if !ok { if !ok {
@ -307,7 +321,7 @@ func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) erro
} }
// 直接写入子文档数据(包含完整的长度信息) // 直接写入子文档数据(包含完整的长度信息)
buf.Write(subDoc) buf.Write(subDoc)
default: default:
return fmt.Errorf("unsupported BSON element type: 0x%02X", elementType) return fmt.Errorf("unsupported BSON element type: 0x%02X", elementType)
} }

View File

@ -5,11 +5,12 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"strings"
) )
// UpdateFlags are the flags for OP_UPDATE // UpdateFlags are the flags for OP_UPDATE
const ( const (
Upsert = 1 << iota Upsert = 1 << iota
MultiUpdate // 标志位用于多文档更新 MultiUpdate // 标志位用于多文档更新
) )
@ -97,11 +98,11 @@ func parseUpdate(data []byte) (interface{}, error) {
return &UpdateMessage{ return &UpdateMessage{
Body: struct { Body: struct {
Flags UpdateFlags Flags UpdateFlags
DatabaseName string DatabaseName string
CollName string CollName string
Query map[string]interface{} Query map[string]interface{}
UpdateSpec map[string]interface{} UpdateSpec map[string]interface{}
}{ }{
Flags: flags, Flags: flags,
DatabaseName: dbName, DatabaseName: dbName,
@ -190,29 +191,41 @@ func parseQuery(data []byte) (interface{}, error) {
if dbEnd == -1 { if dbEnd == -1 {
return nil, fmt.Errorf("database name not null terminated") return nil, fmt.Errorf("database name not null terminated")
} }
dbName := string(data[4 : dbEnd+4]) dbcolName := string(data[4 : dbEnd+4])
dcnames := strings.Split(dbcolName, ".")
dbName := dcnames[0]
collName := dcnames[1]
// 剩余数据包含集合名和查询条件 // 剩余数据包含集合名和查询条件
remaining := data[dbEnd+5:] // 跳过终止符 remaining := data[dbEnd+5:] // 跳过终止符
var numberToSkip, numberToReturn int32
binary.Read(bytes.NewReader(remaining[0:4]), binary.LittleEndian, &numberToSkip)
binary.Read(bytes.NewReader(remaining[4:8]), binary.LittleEndian, &numberToReturn)
// 提取集合名 // 提取集合名
collEnd := bytes.IndexByte(remaining, 0) // collEnd := bytes.IndexByte(remaining, 0)
if collEnd == -1 { // if collEnd == -1 {
return nil, fmt.Errorf("collection name not null terminated") // return nil, fmt.Errorf("collection name not null terminated")
} // }
collName := string(remaining[:collEnd]) // if collEnd == 0 {
// collEnd = bytes.IndexFunc(remaining, func(r rune) bool {
// return r != 0
// })
// }
// collName := string(remaining[:collEnd])
// 解析查询条件 // 解析查询条件
queryDoc, _, err := parseBSON(remaining[collEnd+1:]) queryDoc, _, err := parseBSON(remaining[8:])
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse query conditions: %v", err) return nil, fmt.Errorf("failed to parse query conditions: %v", err)
} }
return &QueryMessage{ return &QueryMessage{
Flags: flags, Flags: flags,
DatabaseName: dbName, DatabaseName: dbName,
CollName: collName, CollName: collName,
Query: queryDoc, NumberToSkip: numberToSkip,
NumberToReturn: numberToReturn,
Query: queryDoc,
}, nil }, nil
} }
@ -265,10 +278,12 @@ func parseInsert(data []byte) (interface{}, error) {
// QueryMessage OP_QUERY消息体结构 // QueryMessage OP_QUERY消息体结构
type QueryMessage struct { type QueryMessage struct {
Flags uint32 // 查询标志 Flags uint32 // 查询标志
DatabaseName string // 数据库名称 DatabaseName string // 数据库名称
CollName string // 集合名称 CollName string // 集合名称
Query map[string]interface{} // 查询条件 NumberToSkip int32
NumberToReturn int32
Query map[string]interface{} // 查询条件
} }
// InsertMessage OP_INSERT消息体结构 // InsertMessage OP_INSERT消息体结构
@ -281,8 +296,8 @@ type InsertMessage struct {
// UpdateMessage represents an OP_UPDATE message // UpdateMessage represents an OP_UPDATE message
type UpdateMessage struct { type UpdateMessage struct {
Header Header Header Header
Body struct { Body struct {
Flags UpdateFlags Flags UpdateFlags
DatabaseName string DatabaseName string
CollName string CollName string

View File

@ -2,8 +2,22 @@ package protocol
import ( import (
"testing" "testing"
"git.pyer.club/kingecg/goaidb/log"
) )
func TestParseQuery(t *testing.T) {
_, data, err := log.ReadOpMsgFromFile("/home/kingecg/code/goaidb/opbin.log")
if err != nil {
t.Fatalf("ReadOpMsgFromFile failed: %v", err)
}
msg, err := ParseMessage(data)
if err != nil {
t.Fatalf("ParseMessage failed: %v", err)
}
log.Info(msg.Body)
}
func TestParseUpdate(t *testing.T) { func TestParseUpdate(t *testing.T) {
// 构造测试数据最小有效Update消息 // 构造测试数据最小有效Update消息
data := []byte{ data := []byte{