316 lines
7.7 KiB
Go
316 lines
7.7 KiB
Go
package protocol
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/binary"
|
||
"fmt"
|
||
"math"
|
||
"strings"
|
||
)
|
||
|
||
// BSONElement describes a single BSON element as a (type tag, key, value)
// triple.
type BSONElement struct {
	// ElementType is the one-byte BSON type tag (for example 0x10 for Int32
	// or 0x02 for String, the codes this file's parser and encoder use).
	ElementType byte

	// KeyName is the element's key.
	KeyName string

	// Value holds the element's Go value; presumably one of the types the
	// parser in this file produces for ElementType — confirm against callers.
	Value interface{}
}
|
||
|
||
// ParseBSON 解析BSON文档并返回映射
|
||
func ParseBSON(data []byte) (map[string]interface{}, error) {
|
||
result, _, err := parseBSON(data)
|
||
return result, err
|
||
}
|
||
|
||
// parseBSON decodes the BSON document at the start of data.
//
// A document is a little-endian uint32 total length (counting the length
// prefix itself and the final 0x00), a sequence of elements, and a single
// 0x00 terminator. Each element is a type byte, a NUL-terminated key, and a
// type-specific value decoded by parseBSONValue.
//
// It returns the decoded elements, the bytes of data that follow the
// document, and an error. On error the map is nil and the returned slice is
// the original data.
func parseBSON(data []byte) (map[string]interface{}, []byte, error) {
	// The smallest legal document is 5 bytes: length prefix + terminator.
	if len(data) <= 4 {
		return nil, data, fmt.Errorf("data too short for BSON document length")
	}

	// Validate the declared length before converting it to int so that a huge
	// value cannot turn negative on 32-bit platforms.
	declared := binary.LittleEndian.Uint32(data[0:4])
	if declared < 5 {
		return nil, data, fmt.Errorf("invalid BSON document length %d", declared)
	}
	if uint64(declared) > uint64(len(data)) {
		return nil, data, fmt.Errorf("data too short for BSON document")
	}
	length := int(declared)

	// Confine all element parsing to the declared document so a corrupt
	// element can never consume bytes that belong to whatever follows it.
	body := data[:length]
	doc := make(map[string]interface{})

	for pos := 4; pos < length; {
		elementType := body[pos]
		pos++

		if elementType == 0x00 {
			// The terminator must be the document's last byte; anything else
			// means the declared length and the contents disagree.
			if pos != length {
				return nil, data, fmt.Errorf("BSON document terminator at offset %d, expected %d", pos-1, length-1)
			}
			return doc, data[length:], nil
		}

		// Key: C-style string terminated by 0x00.
		keyEnd := bytes.IndexByte(body[pos:], 0)
		if keyEnd == -1 {
			return nil, data, fmt.Errorf("key not null terminated")
		}
		keyName := string(body[pos : pos+keyEnd])
		pos += keyEnd + 1

		value, consumed, err := parseBSONValue(elementType, body[pos:], 0)
		if err != nil {
			return nil, data, fmt.Errorf("failed to parse value for key %s: %w", keyName, err)
		}
		doc[keyName] = value
		pos += consumed
	}

	// Reaching here means the declared length ran out before a terminator.
	return nil, data, fmt.Errorf("BSON document not null terminated")
}

// parseBSONValue decodes one BSON value of the given elementType from the
// start of data and returns the decoded value and the number of bytes it
// occupied.
//
// Supported types and their Go results:
//   - 0x01 Double -> float64
//   - 0x02 String -> string
//   - 0x03 EmbeddedDocument -> map[string]interface{}
//   - 0x08 Boolean -> bool (any non-zero byte is true)
//   - 0x0A Null -> nil (consumes no bytes)
//   - 0x10 Int32 -> int32
//   - 0x12 Int64 -> int64
//
// Any other type is reported as an error.
//
// pos is only consulted by the update-operator heuristic in the 0x03 branch;
// parseBSON always passes 0, which disables it.
func parseBSONValue(elementType byte, data []byte, pos int) (interface{}, int, error) {
	switch elementType {
	case 0x10: // Int32
		if len(data) < 4 {
			return nil, 0, fmt.Errorf("data too short for Int32")
		}
		value := int32(binary.LittleEndian.Uint32(data[0:4]))
		return value, 4, nil

	case 0x12: // Int64
		if len(data) < 8 {
			return nil, 0, fmt.Errorf("data too short for Int64")
		}
		value := int64(binary.LittleEndian.Uint64(data[0:8]))
		return value, 8, nil

	case 0x01: // Double
		if len(data) < 8 {
			return nil, 0, fmt.Errorf("data too short for Double")
		}
		value := math.Float64frombits(binary.LittleEndian.Uint64(data[0:8]))
		return value, 8, nil

	case 0x02: // String: int32 byte count (including the NUL), bytes, NUL
		if len(data) < 4 {
			return nil, 0, fmt.Errorf("data too short for String length")
		}
		strLength := int(binary.LittleEndian.Uint32(data[0:4]))
		if strLength < 1 {
			return nil, 0, fmt.Errorf("invalid string length %d", strLength)
		}
		if len(data)-4 < strLength {
			return nil, 0, fmt.Errorf("data too short for String content")
		}
		raw := data[4 : 4+strLength]
		if raw[strLength-1] != 0x00 {
			return nil, 0, fmt.Errorf("string not null terminated")
		}
		// Drop exactly the one terminator; NUL bytes elsewhere in the value
		// are data and must be kept.
		return strings.TrimSuffix(string(raw), "\x00"), 4 + strLength, nil

	case 0x08: // Boolean
		if len(data) < 1 {
			return nil, 0, fmt.Errorf("data too short for Boolean")
		}
		return data[0] != 0, 1, nil

	case 0x0A: // Null
		return nil, 0, nil

	case 0x03: // EmbeddedDocument
		if len(data) < 4 {
			return nil, 0, fmt.Errorf("data too short for EmbeddedDocument length")
		}
		docLength := int(binary.LittleEndian.Uint32(data[0:4]))
		// Also rejects values that wrapped negative on 32-bit platforms.
		if docLength < 5 {
			return nil, 0, fmt.Errorf("invalid EmbeddedDocument length %d", docLength)
		}
		if len(data) < docLength {
			return nil, 0, fmt.Errorf("data too short for EmbeddedDocument")
		}

		// NOTE(review): heuristic for update operators ($-prefixed keys). When
		// the byte before pos is a String type tag, scan backwards for an
		// earlier 0x02 byte, read the following int32 as a key length, and if
		// that key starts with '$' return a placeholder instead of decoding
		// the document. parseBSON always passes pos == 0, so this never runs
		// from that path; confirm whether any other caller depends on it.
		// Bounds are checked so malformed input cannot cause a panic.
		if pos > 0 && pos <= len(data) && data[pos-1] == 0x02 {
			for i := pos - 2; i >= 0; i-- {
				if data[i] != 0x02 || i+5 > len(data) {
					continue
				}
				keyLen := int(binary.LittleEndian.Uint32(data[i+1 : i+5]))
				keyStart := i + 5
				if keyLen < 1 || keyLen > pos-keyStart {
					continue
				}
				key := string(data[keyStart : keyStart+keyLen-1])
				if len(key) > 0 && key[0] == '$' {
					value := make(map[string]interface{})
					value[key] = "operator_placeholder"
					return value, docLength, nil
				}
			}
		}

		subDoc, _, err := parseBSON(data[:docLength])
		if err != nil {
			return nil, 0, fmt.Errorf("failed to parse embedded document: %w", err)
		}
		return subDoc, docLength, nil

	default:
		return nil, 0, fmt.Errorf("unsupported BSON element type: 0x%02X", elementType)
	}
}
|
||
|
||
// BsonMarshal 将map转换为BSON格式的字节流
|
||
func BsonMarshal(doc map[string]interface{}) ([]byte, error) {
|
||
buf := &bytes.Buffer{}
|
||
|
||
// 写入占位符长度(4字节)
|
||
buf.Write(make([]byte, 4))
|
||
|
||
// 遍历文档元素
|
||
for key, value := range doc {
|
||
// 写入元素类型和键名
|
||
elementType, err := getBSONType(value)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
buf.WriteByte(elementType)
|
||
buf.WriteString(key)
|
||
buf.WriteByte(0x00) // 键名终止符
|
||
|
||
// 写入值数据
|
||
if err := writeBSONValue(buf, elementType, value); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// 写入文档结束符
|
||
buf.WriteByte(0x00)
|
||
|
||
// 回填文档长度
|
||
length := uint32(buf.Len())
|
||
binary.LittleEndian.PutUint32(buf.Bytes(), length)
|
||
|
||
return buf.Bytes(), nil
|
||
}
|
||
|
||
// getBSONType returns the BSON type tag used to encode value.
//
// Accepted Go types: nil (0x0A Null), bool (0x08), string (0x02), float64
// (0x01 Double), int32 (0x10), int64 (0x12) and map[string]interface{}
// (0x03 EmbeddedDocument). Any other type yields tag 0x00 and an error.
func getBSONType(value interface{}) (byte, error) {
	var tag byte

	switch value.(type) {
	case nil:
		tag = 0x0A
	case bool:
		tag = 0x08
	case string:
		tag = 0x02
	case float64:
		tag = 0x01
	case int32:
		tag = 0x10
	case int64:
		tag = 0x12
	case map[string]interface{}:
		tag = 0x03
	default:
		return 0x00, fmt.Errorf("unsupported BSON type: %T", value)
	}

	return tag, nil
}
|
||
|
||
// writeBSONValue 写入BSON值数据
|
||
func writeBSONValue(buf *bytes.Buffer, elementType byte, value interface{}) error {
|
||
switch elementType {
|
||
case 0x10: // Int32
|
||
v, ok := value.(int32)
|
||
if !ok {
|
||
return fmt.Errorf("invalid type for Int32")
|
||
}
|
||
b := make([]byte, 4)
|
||
binary.LittleEndian.PutUint32(b, uint32(v))
|
||
buf.Write(b)
|
||
|
||
case 0x12: // Int64
|
||
v, ok := value.(int64)
|
||
if !ok {
|
||
return fmt.Errorf("invalid type for Int64")
|
||
}
|
||
b := make([]byte, 8)
|
||
binary.LittleEndian.PutUint64(b, uint64(v))
|
||
buf.Write(b)
|
||
|
||
case 0x01: // Double
|
||
v, ok := value.(float64)
|
||
if !ok {
|
||
return fmt.Errorf("invalid type for Double")
|
||
}
|
||
b := make([]byte, 8)
|
||
binary.LittleEndian.PutUint64(b, math.Float64bits(v))
|
||
buf.Write(b)
|
||
|
||
case 0x02: // String
|
||
v, ok := value.(string)
|
||
if !ok {
|
||
return fmt.Errorf("invalid type for String")
|
||
}
|
||
strBytes := []byte(v)
|
||
// 写入字符串长度(包含终止符)
|
||
length := uint32(len(strBytes) + 1)
|
||
b := make([]byte, 4)
|
||
binary.LittleEndian.PutUint32(b, length)
|
||
buf.Write(b)
|
||
// 写入字符串内容和终止符
|
||
buf.Write(strBytes)
|
||
buf.WriteByte(0x00)
|
||
|
||
case 0x08: // Boolean
|
||
v, ok := value.(bool)
|
||
if !ok {
|
||
return fmt.Errorf("invalid type for Boolean")
|
||
}
|
||
if v {
|
||
buf.WriteByte(0x01)
|
||
} else {
|
||
buf.WriteByte(0x00)
|
||
}
|
||
|
||
case 0x0A: // Null
|
||
// 不需要写入任何数据
|
||
|
||
case 0x03: // EmbeddedDocument
|
||
v, ok := value.(map[string]interface{})
|
||
if !ok {
|
||
return fmt.Errorf("invalid type for EmbeddedDocument")
|
||
}
|
||
subDoc, err := BsonMarshal(v)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 直接写入子文档数据(包含完整的长度信息)
|
||
buf.Write(subDoc)
|
||
|
||
default:
|
||
return fmt.Errorf("unsupported BSON element type: 0x%02X", elementType)
|
||
}
|
||
return nil
|
||
}
|