From 9c08d68291fd15ecd34d086f04ce14c2f44be180 Mon Sep 17 00:00:00 2001 From: kingecg Date: Fri, 6 Jun 2025 22:12:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0bson,=E7=BB=86=E5=8C=96parse?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- design/coding_style_guide.md | 23 ++++ network/server.go | 46 +++++-- protocol/bson.go | 167 ++++++++++++++++++++++++ protocol/bson_test.go | 104 +++++++++++++++ protocol/const.go | 27 ++++ protocol/parser.go | 242 ++++++++++++++++++++++++++++++----- protocol/parser_test.go | 66 ++++++++++ storage/engine.go | 130 ++++++++++++++++++- 8 files changed, 766 insertions(+), 39 deletions(-) create mode 100644 protocol/bson.go create mode 100644 protocol/bson_test.go create mode 100644 protocol/const.go create mode 100644 protocol/parser_test.go diff --git a/design/coding_style_guide.md b/design/coding_style_guide.md index 8fd234d..14923ef 100644 --- a/design/coding_style_guide.md +++ b/design/coding_style_guide.md @@ -161,6 +161,29 @@ if err != nil { } ``` +## 日志规范 +### 日志级别使用 +- **Error**:记录错误事件,使用`log.Error()`,适用于不可恢复的错误 +- **Warn**:记录警告事件,使用`log.Warn()`,表示潜在问题但不会中断流程 +- **Info**:记录重要流程事件,使用`log.Info()`,用于关键节点的状态报告 +- **Debug**:记录调试信息,使用`log.Debug()`,用于开发阶段的问题追踪 + +### 日志格式 +- 所有日志必须包含上下文信息(如模块名、操作对象) +- 错误日志必须包含错误详情 +- 关键操作应包含操作结果状态 + +```go +// 正确示例 +log.Error("存储层更新失败", "error", err, "db", dbName) +log.Warn("查询条件未命中索引", "collection", collName) +log.Info("更新操作完成", "matched", matchedCount, "modified", modifiedCount) + +// 错误示例 +log.Debug("错误不应使用debug级别") // 错误类型日志应使用Error级别 +log.Info("error", err) // 错误信息应使用Error级别并包含明确描述 +``` + ## 测试规范 - 所有关键功能必须有单元测试 - 测试用例覆盖主要分支 diff --git a/network/server.go b/network/server.go index 99bcdc3..5446d1b 100644 --- a/network/server.go +++ b/network/server.go @@ -6,8 +6,8 @@ import ( "io" "net" + "git.pyer.club/kingecg/goaidb/log" "git.pyer.club/kingecg/goaidb/protocol" - "git.pyer.club/kingecg/goaidb/query" "git.pyer.club/kingecg/goaidb/storage" ) @@ -42,7 +42,7 @@ func (s *Server) handleConnection(conn net.Conn) { n, err := conn.Read(buffer) if err != nil { if err != io.EOF { - fmt.Printf("Error reading from connection: %v\n", err) + log.Error("连接读取失败", "error", err) } return } @@ -50,22 +50,50 @@ func (s *Server) handleConnection(conn net.Conn) { // 解析MongoDB协议消息 message, err := protocol.ParseMessage(buffer[:n]) if err != nil { - fmt.Printf("Failed to parse message: %v\n", err) + log.Error("消息解析失败", "error", err) continue } - // 处理查询请求 - response, err := query.HandleQuery(message, s.storage) - if err != nil { - fmt.Printf("Query handling error: %v\n", err) - continue + var response []byte + + switch message.OpCode { + case protocol.OP_UPDATE: + updateMsg := message.Body.(*protocol.UpdateMessage) + // 序列化查询和更新文档为BSON格式 + queryBson, err := protocol.BsonMarshal(updateMsg.Query) + if err != nil { + log.Error("查询文档序列化失败", "error", err) + continue + } + + updateBson, err := protocol.BsonMarshal(updateMsg.Update) + if err != nil { + log.Error("更新文档序列化失败", "error", err) + continue + } + + err = s.storage.Update(updateMsg.DatabaseName, updateMsg.CollName, queryBson, updateBson) + if err != nil { + log.Error("存储层更新失败", "error", err) + continue + } + response = constructUpdateResponse(message) + default: + log.Warn("不支持的操作码", "opcode", message.OpCode) } // 发送响应 _, err = conn.Write(response) if err != nil { - fmt.Printf("Failed to send response: %v\n", err) + log.Error("响应发送失败", "error", err) return } } } + +// 构造简单的OP_REPLY响应 +func constructUpdateResponse(request *protocol.Message) []byte { + // 实际实现应构造完整的OP_REPLY消息 + // 这里只是一个示例,返回空文档 + return []byte{} +} diff --git a/protocol/bson.go b/protocol/bson.go new file mode 100644 index 0000000..6905d74 --- /dev/null +++ b/protocol/bson.go @@ -0,0 +1,167 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" +) + +// BSONElement represents a single BSON element +type BSONElement struct { + ElementType byte + KeyName string + Value interface{} +} + +// ParseBSON 解析BSON文档并返回映射 +func ParseBSON(data []byte) (map[string]interface{}, error) { + result, _, err := parseBSON(data) + return result, err +} + +// parseBSON 内部使用的BSON解析函数 +func parseBSON(data []byte) (map[string]interface{}, []byte, error) { + // 检查数据长度(最小需要4字节的文档长度) + if len(data) <= 4 { + return nil, data, fmt.Errorf("data too short for BSON document length") + } + + // 读取BSON文档长度 + length := int(binary.LittleEndian.Uint32(data[0:4])) + if len(data) < length { + return nil, data, fmt.Errorf("data too short for BSON document") + } + + // 创建结果映射 + doc := make(map[string]interface{}) + + // 指向当前解析位置 + pos := 4 + + // 解析BSON元素,直到遇到结束符0x00 + for pos < length { + // 检查是否有足够的数据读取元素类型 + if pos+1 > len(data) { + return nil, data, fmt.Errorf("unexpected end of data reading element type") + } + + // 读取元素类型 + elementType := data[pos] + pos++ + + // 如果是结束元素,跳出循环 + if elementType == 0x00 { + break + } + + // 读取键名(C风格字符串,以0终止) + // keyStart := pos + keyEnd := bytes.IndexByte(data[pos:], 0) + if keyEnd == -1 { + return nil, data, fmt.Errorf("key not null terminated") + } + keyEnd += pos // 调整到正确的位置 + + keyName := string(data[pos:keyEnd]) + pos = keyEnd + 1 + + // 根据元素类型解析值 + value, newPos, err := parseBSONValue(elementType, data[pos:]) + if err != nil { + return nil, data, fmt.Errorf("failed to parse value for key %s: %v", keyName, err) + } + + doc[keyName] = value + pos += newPos + } + + return doc, data[length:], nil +} + +// parseBSONValue 解析特定类型的BSON值 +func parseBSONValue(elementType byte, data []byte) (interface{}, int, error) { + switch elementType { + case 0x10: // Int32 + if len(data) < 4 { + return nil, 0, fmt.Errorf("data too short for Int32") + } + value := int32(binary.LittleEndian.Uint32(data[0:4])) + return value, 4, nil + + case 0x12: // Int64 + if len(data) < 8 { + return nil, 0, fmt.Errorf("data too short for Int64") + } + value := int64(binary.LittleEndian.Uint64(data[0:8])) + return value, 8, nil + + case 0x01: // Double + if len(data) < 8 { + return nil, 0, fmt.Errorf("data too short for Double") + } + bits := binary.LittleEndian.Uint64(data[0:8]) + value := math.Float64frombits(bits) + return value, 8, nil + + case 0x02: // String + if len(data) < 4 { + return nil, 0, fmt.Errorf("data too short for String length") + } + + // 读取字符串长度 + strLength := int(binary.LittleEndian.Uint32(data[0:4])) + if strLength < 1 { + return nil, 0, fmt.Errorf("invalid string length %d", strLength) + } + + // 检查剩余数据是否足够包含整个字符串 + if len(data)-4 < strLength { + return nil, 0, fmt.Errorf("data too short for String content") + } + + // 读取字符串内容(忽略最后的终止符) + value := string(data[4 : 4+strLength]) + return value, 4 + strLength, nil + + case 0x08: // Boolean + if len(data) < 1 { + return nil, 0, fmt.Errorf("data too short for Boolean") + } + value := data[0] != 0 + return value, 1, nil + + case 0x0A: // Null + return nil, 0, nil + + case 0x03: // EmbeddedDocument + // 解析嵌入文档 + if len(data) < 4 { + return nil, 0, fmt.Errorf("data too short for EmbeddedDocument length") + } + + // 读取嵌入文档长度 + docLength := int(binary.LittleEndian.Uint32(data[0:4])) + if len(data) < docLength { + return nil, 0, fmt.Errorf("data too short for EmbeddedDocument") + } + + // 解析嵌入文档内容 + subDoc, _, err := parseBSON(data[0:docLength]) + if err != nil { + return nil, 0, fmt.Errorf("failed to parse embedded document: %v", err) + } + + return subDoc, docLength, nil + + default: + return nil, 0, fmt.Errorf("unsupported BSON element type: 0x%02X", elementType) + } +} + +// BsonMarshal 将map转换为BSON格式的字节流 +func BsonMarshal(doc map[string]interface{}) ([]byte, error) { + // TODO: 实现实际的BSON序列化或使用现有库 + // 这里返回模拟实现 + return []byte{}, nil +} diff --git a/protocol/bson_test.go b/protocol/bson_test.go new file mode 100644 index 0000000..7e39dd9 --- /dev/null +++ b/protocol/bson_test.go @@ -0,0 +1,104 @@ +package protocol + +import ( + "strings" + "testing" +) + +func TestParseBSON_EmptyDocument(t *testing.T) { + // 构造最小有效BSON文档(仅包含长度和结束符) + data := []byte{ + 5, 0, 0, 0, // 文档长度=5字节 + 0x00, // 结束符 + } + + result, err := ParseBSON(data) + if err != nil { + t.Fatalf("ParseBSON failed: %v", err) + } + + if len(result) != 0 { + t.Errorf("Expected empty document, got %d elements", len(result)) + } +} + +func TestParseBSON_Int32(t *testing.T) { + // 构造包含Int32字段的BSON文档 + data := []byte{ + 9, 0, 0, 0, // 文档长度=9字节 + 0x10, // Int32类型 + 't', 'e', 's', 't', 0x00, // 键名"test" + 0x12, 0x34, 0x00, 0x00, // 值=0x00003412 (小端序) + 0x00, // 结束符 + } + + result, err := ParseBSON(data) + if err != nil { + t.Fatalf("ParseBSON failed: %v", err) + } + + value, ok := result["test"] + if !ok { + t.Error("Expected key 'test' not found") + return + } + + if intValue, ok := value.(int32); !ok || intValue != 0x00003412 { + t.Errorf("Expected Int32 0x00003412, got %v (%T)", value, value) + } +} + +func TestParseBSON_String(t *testing.T) { + // 构造包含字符串字段的BSON文档 + data := []byte{ + 16, 0, 0, 0, // 文档长度=22字节 + 0x02, // String类型 + 'n', 'a', 'm', 'e', 0x00, // 键名"name" + 5, 0, 0, 0, // 字符串长度=7字节(含终止符) + 'h', 'e', 'l', 'l', 'o', 0x00, // 字符串内容 + 0x00, // 结束符 + } + + result, err := ParseBSON(data) + if err != nil { + t.Fatalf("ParseBSON failed: %v", err) + } + + value, ok := result["name"] + if !ok { + t.Error("Expected key 'name' not found") + return + } + + if strValue, ok := value.(string); !ok || strValue != "hello" { + t.Errorf("Expected string 'hello', got %q (%T)", value, value) + } +} + +func TestParseBSON_ErrorCases(t *testing.T) { + // 测试数据过短的情况 + shortData := []byte{4, 0, 0, 0} // 长度为4字节的文档(仅包含长度字段) + _, err := ParseBSON(shortData) + if err == nil { + t.Error("Expected error for short data") + } + + // 测试无效元素类型 + invalidTypeData := []byte{ + 6, 0, 0, 0, // 文档长度=6字节 + 0xFF, // 无效元素类型 + 'k', 'e', 'y', 0x00, // 键名"key" + 0x00, // 结束符 + } + _, err = ParseBSON(invalidTypeData) + if err == nil { + t.Error("Expected error for invalid element type") + } else if !containsError(err, "unsupported BSON element") { + t.Errorf("Expected unsupported element type error, got %v", err) + } +} + +// Helper function to check if error message contains expected text +func containsError(err error, expected string) bool { + return err != nil && len(err.Error()) >= len(expected) && strings.Contains(err.Error(), expected) +} diff --git a/protocol/const.go b/protocol/const.go new file mode 100644 index 0000000..7df90bb --- /dev/null +++ b/protocol/const.go @@ -0,0 +1,27 @@ +package protocol + +// Update操作标志 +type UpdateFlags int32 + +const ( + // UBF_NONE 无特殊标志 + UBF_NONE UpdateFlags = 0 + // UBF_UPSERT 如果没有匹配文档则插入新文档 + UBF_UPSERT UpdateFlags = 1 << iota + // UBF_MULTI_UPDATE 更新所有匹配文档 + UBF_MULTI_UPDATE +) + +// Update操作符 +const ( + // UPDATE_OP_SET $set操作符 + UPDATE_OP_SET = "$set" + // UPDATE_OP_INC $inc操作符 + UPDATE_OP_INC = "$inc" + // UPDATE_OP_UNSET $unset操作符 + UPDATE_OP_UNSET = "$unset" + // UPDATE_OP_PUSH $push操作符 + UPDATE_OP_PUSH = "$push" + // UPDATE_OP_PULL $pull操作符 + UPDATE_OP_PULL = "$pull" +) \ No newline at end of file diff --git a/protocol/parser.go b/protocol/parser.go index 7165020..efb0725 100644 --- a/protocol/parser.go +++ b/protocol/parser.go @@ -2,17 +2,19 @@ package protocol import ( + "bytes" "encoding/binary" "fmt" ) -// Message MongoDB协议消息结构 -type Message struct { - Header Header - OpCode OpCode - OriginalBody []byte // 原始消息体(解析前) - Body interface{} // 解析后的消息体 -} +// UpdateFlags 更新操作标志位 + +const ( + // Update操作的标志位常量 + Upsert = 1 << 0 + MultiUpdate = 1 << 1 + WriteConcern = 1 << 3 // 3.x驱动已弃用 +) // Header 消息头 type Header struct { @@ -26,34 +28,110 @@ type Header struct { type OpCode int32 const ( - OP_REPLY OpCode = 1 - OP_MSG OpCode = 2 - OP_UPDATE OpCode = 2001 - OP_INSERT OpCode = 2002 - RESERVED OpCode = 2003 - OP_QUERY OpCode = 2004 - OP_GET_MORE OpCode = 2005 - OP_DELETE OpCode = 2006 - OP_KILL_CURSORS OpCode = 2007 - OP_COMMAND OpCode = 2010 - OP_COMMAND_REPLY OpCode = 2011 - OP_COMPRESSED OpCode = 2012 - OP_ENCRYPTED OpCode = 2013 + OP_REPLY OpCode = 1 + OP_MSG OpCode = 2 + OP_UPDATE OpCode = 2001 + OP_INSERT OpCode = 2002 + RESERVED OpCode = 2003 + OP_QUERY OpCode = 2004 + OP_GET_MORE OpCode = 2005 + OP_DELETE OpCode = 2006 + OP_KILL_CURSORS OpCode = 2007 + OP_COMMAND OpCode = 2010 + OP_COMMAND_REPLY OpCode = 2011 + OP_COMPRESSED OpCode = 2012 + OP_ENCRYPTED OpCode = 2013 ) +// Message MongoDB协议消息结构 +type Message struct { + Header Header + OpCode OpCode + OriginalBody []byte // 原始消息体(解析前) + Body interface{} // 解析后的消息体 +} + +// 解析更新请求 +func parseUpdate(data []byte) (interface{}, error) { + // 检查数据长度(最小需要4字节的flags + 1字节的数据库名终止符) + if len(data) < 5 { + return nil, fmt.Errorf("update message data too short") + } + + // 读取flags + flags := UpdateFlags(binary.LittleEndian.Uint32(data[0:4])) + + // 查找数据库名结束位置(C风格字符串,以0终止) + dbEnd := bytes.IndexByte(data[4:], 0) + if dbEnd == -1 { + return nil, fmt.Errorf("database name not null terminated") + } + dbEnd += 4 // 调整到正确的位置 + + // 提取数据库名称 + dbName := string(data[4:dbEnd]) + + // 剩余数据包含集合名、查询文档和更新文档 + remaining := data[dbEnd+1:] + + // 查找集合名结束位置 + collEnd := bytes.IndexByte(remaining, 0) + if collEnd == -1 { + return nil, fmt.Errorf("collection name not null terminated") + } + + // 提取集合名 + collName := string(remaining[:collEnd]) + + // 剩余数据包含查询文档和更新文档 + bsonData := remaining[collEnd+1:] + + // 解析BSON文档 + queryDoc, _, err := parseBSON(bsonData) + if err != nil { + return nil, fmt.Errorf("failed to parse query document: %v", err) + } + + // 解析更新文档 + updateDoc, _, err := parseBSON(bsonData) + if err != nil { + return nil, fmt.Errorf("failed to parse update document: %v", err) + } + + return &UpdateMessage{ + Flags: flags, + DatabaseName: dbName, + CollName: collName, + Query: queryDoc, + Update: updateDoc, + }, nil +} + // ParseMessage 解析MongoDB协议消息 func ParseMessage(data []byte) (*Message, error) { + // 最小消息长度为16字节(消息头长度) if len(data) < 16 { return nil, fmt.Errorf("data too short for message header") } - header := &Header{ - MessageLength: int32(binary.LittleEndian.Uint32(data[0:4])), - RequestID: int32(binary.LittleEndian.Uint32(data[4:8])), - ResponseTo: int32(binary.LittleEndian.Uint32(data[8:12])), - OpCode: OpCode(binary.LittleEndian.Uint32(data[12:16])), + // 验证消息长度是否完整 + messageLength := int(binary.LittleEndian.Uint32(data[0:4])) + if len(data) < messageLength { + return nil, fmt.Errorf("data too short for complete message") } + // 截取实际的消息数据(可能有多条消息) + actualData := data[:messageLength] + + // 解析消息头 + header := &Header{ + MessageLength: int32(binary.LittleEndian.Uint32(actualData[0:4])), + RequestID: int32(binary.LittleEndian.Uint32(actualData[4:8])), + ResponseTo: int32(binary.LittleEndian.Uint32(actualData[8:12])), + OpCode: OpCode(binary.LittleEndian.Uint32(actualData[12:16])), + } + + // 获取消息体 body := data[16:] // 解析特定操作码的消息体 @@ -71,6 +149,12 @@ func ParseMessage(data []byte) (*Message, error) { return nil, err } parsedBody = insert + case OP_UPDATE: + update, err := parseUpdate(body) + if err != nil { + return nil, err + } + parsedBody = update // 这里可以添加更多操作码的解析逻辑 default: // 未知操作码,保留原始数据 @@ -89,12 +173,112 @@ func ParseMessage(data []byte) (*Message, error) { func parseQuery(data []byte) (interface{}, error) { // 实现具体的查询消息解析逻辑 // 这里返回原始数据作为占位符 - return data, nil + if len(data) < 4 { + return nil, fmt.Errorf("query data too short") + } + + // 示例:读取查询标志 + flags := binary.LittleEndian.Uint32(data[0:4]) + + // 提取数据库名称 + dbEnd := bytes.IndexByte(data[4:], 0) + if dbEnd == -1 { + return nil, fmt.Errorf("database name not null terminated") + } + dbName := string(data[4 : dbEnd+4]) + + // 剩余数据包含集合名和查询条件 + remaining := data[dbEnd+5:] // 跳过终止符 + + // 提取集合名 + collEnd := bytes.IndexByte(remaining, 0) + if collEnd == -1 { + return nil, fmt.Errorf("collection name not null terminated") + } + collName := string(remaining[:collEnd]) + + // 解析查询条件 + queryDoc, _, err := parseBSON(remaining[collEnd+1:]) + if err != nil { + return nil, fmt.Errorf("failed to parse query conditions: %v", err) + } + + return &QueryMessage{ + Flags: flags, + DatabaseName: dbName, + CollName: collName, + Query: queryDoc, + }, nil } // 解析插入请求 func parseInsert(data []byte) (interface{}, error) { // 实现具体的插入消息解析逻辑 - // 这里返回原始数据作为占位符 - return data, nil -} \ No newline at end of file + if len(data) < 4 { + return nil, fmt.Errorf("insert data too short") + } + + // 示例:读取插入标志 + flags := binary.LittleEndian.Uint32(data[0:4]) + + // 提取数据库名称 + dbEnd := bytes.IndexByte(data[4:], 0) + if dbEnd == -1 { + return nil, fmt.Errorf("database name not null terminated") + } + dbName := string(data[4 : dbEnd+4]) + + // 剩余数据包含集合名和文档数据 + remaining := data[dbEnd+5:] // 跳过终止符 + + // 提取集合名 + collEnd := bytes.IndexByte(remaining, 0) + if collEnd == -1 { + return nil, fmt.Errorf("collection name not null terminated") + } + collName := string(remaining[:collEnd]) + + // 解析文档数据 + documents := make([]map[string]interface{}, 0) + rest := remaining[collEnd+1:] + for len(rest) > 0 { + doc, remainingData, err := parseBSON(rest) + if err != nil { + return nil, fmt.Errorf("failed to parse document: %v", err) + } + documents = append(documents, doc) + rest = remainingData + } + + return &InsertMessage{ + Flags: flags, + DatabaseName: dbName, + CollName: collName, + Documents: documents, + }, nil +} + +// QueryMessage OP_QUERY消息体结构 +type QueryMessage struct { + Flags uint32 // 查询标志 + DatabaseName string // 数据库名称 + CollName string // 集合名称 + Query map[string]interface{} // 查询条件 +} + +// InsertMessage OP_INSERT消息体结构 +type InsertMessage struct { + Flags uint32 // 插入标志 + DatabaseName string // 数据库名称 + CollName string // 集合名称 + Documents []map[string]interface{} // 要插入的文档 +} + +// UpdateMessage OP_UPDATE消息体结构 +type UpdateMessage struct { + Flags UpdateFlags // 更新标志 + DatabaseName string // 数据库名称 + CollName string // 集合名称 + Query map[string]interface{} // 查询条件 + Update map[string]interface{} // 更新操作 +} diff --git a/protocol/parser_test.go b/protocol/parser_test.go new file mode 100644 index 0000000..9346ef6 --- /dev/null +++ b/protocol/parser_test.go @@ -0,0 +1,66 @@ +package protocol + +import ( + "testing" +) + +func TestParseUpdate(t *testing.T) { + // 构造测试数据(最小有效Update消息) + data := []byte{ + // Header + 16, 0, 0, 0, // MessageLength (16 bytes) + 1, 0, 0, 0, // RequestID + 0, 0, 0, 0, // ResponseTo + 209, 7, 0, 0, // OpCode (OP_UPDATE) + + // UpdateFlags (Upsert=1) + 1, 0, 0, 0, + + // DatabaseName "test\x00collection\x00" + 't', 'e', 's', 't', 0, + + // CollectionName "coll\x00" + 'c', 'o', 'l', 'l', 0, + + // Query文档(最小BSON文档) + 5, 0, 0, 0, 0, // BSON长度 + 空文档 + + // Update文档 + 5, 0, 0, 0, 0, // BSON长度 + 空文档 + } + + msg, err := ParseMessage(data) + if err != nil { + t.Fatalf("ParseMessage failed: %v", err) + } + + if msg.OpCode != OP_UPDATE { + t.Errorf("Expected OP_UPDATE, got %d", msg.OpCode) + } + + updateMsg, ok := msg.Body.(*UpdateMessage) + if !ok { + t.Fatalf("Expected UpdateMessage, got %T", msg.Body) + } + + if updateMsg.Flags != Upsert { + t.Errorf("Expected Upsert flag, got %v", updateMsg.Flags) + } + + if updateMsg.DatabaseName != "test" { + t.Errorf("Expected database 'test', got '%s'", updateMsg.DatabaseName) + } + + if updateMsg.CollName != "coll" { + t.Errorf("Expected collection 'coll', got '%s'", updateMsg.CollName) + } +} + +func TestParseInvalidUpdate(t *testing.T) { + // 测试短数据包 + shortData := []byte{1, 0, 0, 0} // 长度不足4字节 + _, err := parseUpdate(shortData) // 使用空标识符丢弃未使用的返回值 + if err == nil { + t.Error("Expected error for short data") + } +} \ No newline at end of file diff --git a/storage/engine.go b/storage/engine.go index c3a62d9..dd59e15 100644 --- a/storage/engine.go +++ b/storage/engine.go @@ -3,8 +3,30 @@ package storage import ( "fmt" + "git.pyer.club/kingecg/goaidb/log" + "git.pyer.club/kingecg/goaidb/protocol" + "encoding/binary" ) +// matchesQuery 是一个简单的查询匹配函数(实际应使用更复杂的逻辑) +func matchesQuery(doc, query map[string]interface{}) bool { + // 这里应该实现实际的查询匹配逻辑 + // 当前只是一个简单的存根实现 + if query == nil { + return true // 如果没有查询条件,则匹配所有文档 + } + + // 遍历查询条件的所有字段 + for key, value := range query { + // 检查文档是否包含该字段且值匹配 + if docValue, exists := doc[key]; !exists || docValue != value { + return false + } + } + + return true +} + // StorageEngine 存储引擎接口 type StorageEngine interface { // 数据库操作 @@ -140,10 +162,116 @@ func (e *memoryEngine) Query(dbName, collName string, query []byte) ([][]byte, e } func (e *memoryEngine) Update(dbName, collName string, query, update []byte) error { - // TODO: 实现更新逻辑 + // 记录调试日志 + log.Debug("开始执行更新操作", "db", dbName, "collection", collName) + + // 获取集合 + db, exists := e.databases[dbName] + if !exists { + log.Warn("数据库不存在", "db", dbName) + return fmt.Errorf("database %s does not exist", dbName) + } + + coll, exists := db.collections[collName] + if !exists { + log.Warn("集合不存在", "db", dbName, "collection", collName) + return fmt.Errorf("collection %s does not exist in database %s", collName, dbName) + } + + // 解析查询和更新文档 + queryDoc, err := protocol.ParseBSON(query) + if err != nil { + log.Error("查询文档解析失败", "error", err, "db", dbName, "collection", collName) + return fmt.Errorf("failed to parse query document: %v", err) + } + + updateDoc, err := protocol.ParseBSON(update) + if err != nil { + log.Error("更新文档解析失败", "error", err, "db", dbName, "collection", collName) + return fmt.Errorf("failed to parse update document: %v", err) + } + + // 执行更新操作 + matchedCount := 0 + modifiedCount := 0 + + for i := range coll.data { + // 解析当前文档 + doc, err := protocol.ParseBSON(coll.data[i]) + if err != nil { + log.Warn("文档解析失败", "error", err, "index", i, "db", dbName, "collection", collName) + continue + } + + // 检查是否匹配查询条件 + match := matchesQuery(doc, queryDoc) + if match { + matchedCount++ + + // 应用更新操作 - 简单实现$set操作 + if setOp, ok := updateDoc["$set"].(map[string]interface{}); ok { + // 实际应解析文档并应用更新,这里只是简单示例 + for key, value := range setOp { + // 在实际实现中,需要正确修改文档内容 + applySetOperation(doc, key, value) + log.Debug("应用$set操作", "key", key, "value", value, "db", dbName, "collection", collName) + } + } + + // 将更新后的文档重新序列化 + updatedData, err := bsonMarshal(doc) + if err != nil { + log.Warn("文档序列化失败", "error", err, "index", i, "db", dbName, "collection", collName) + continue + } + + // 替换数据中的文档 + coll.data[i] = updatedData + modifiedCount++ + } + } + + log.Info("更新操作完成", "matched", matchedCount, "modified", modifiedCount, "db", dbName, "collection", collName) return nil } +// applySetOperation 应用$set操作到文档 +func applySetOperation(doc map[string]interface{}, key string, value interface{}) { + // 简单实现单层字段设置 + doc[key] = value +} + +// bsonMarshal 将map转换为BSON格式的字节流 +func bsonMarshal(doc map[string]interface{}) ([]byte, error) { + // 使用协议包中的BSON序列化功能 + return protocol.BsonMarshal(doc) +} + +// parseBSON 解析BSON文档(应移至单独的bson包或使用现有库) +func parseBSON(data []byte) (map[string]interface{}, []byte, error) { + // 实际实现应该解析BSON格式的数据 + // 这里返回一个模拟实现 + if len(data) < 4 { + log.Warn("数据过短,无法读取BSON文档长度") + return nil, data, fmt.Errorf("data too short for BSON document length") + } + + // 读取BSON文档长度 + length := int(binary.LittleEndian.Uint32(data[0:4])) + if len(data) < length { + log.Warn("数据过短,无法读取完整BSON文档", "required", length, "available", len(data)) + return nil, data, fmt.Errorf("data too short for BSON document") + } + + // TODO: 实际解析BSON文档内容 + + // 返回空文档作为占位符,剩余数据和nil错误 + result := make(map[string]interface{}) + log.Debug("成功解析BSON文档", "length", length, "remaining", len(data)-length) + return result, data[length:], nil +} + + func (e *memoryEngine) Delete(dbName, collName string, query []byte) error { // TODO: 实现删除逻辑 return nil