<?xml version="1.0" encoding="UTF-8"?><?xml-stylesheet href="/scripts/pretty-feed-v3.xsl" type="text/xsl"?><rss version="2.0" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:h="http://www.w3.org/TR/html4/"><channel><title>Soup&apos;s Blog</title><description>Stay hungry, stay foolish</description><link>https://astro-pure.js.org</link><item><title>大模型学习（十二）vLLM的实现原理</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-12</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-12</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 16:25:00 GMT</pubDate><content:encoded>&lt;p&gt;在大语言模型的推理过程中，生成文本并非一蹴而就，而是分为多个阶段并涉及多种优化策略。&lt;/p&gt;
&lt;p&gt;其中，Prefill（预填充） 和 Decode（解码） 构成了推理的基本流程：前者高效处理用户输入的完整提示（prompt），后者则逐个生成输出 token。而在需要生成多个候选结果时（如提高回复质量或多样性），系统会采用 Parallel Sampling（并行采样） 或 Beam Search（束搜索） 等策略。为了进一步提升效率、避免重复计算，现代推理引擎还引入了 Shared Prefix（共享前缀） 机制，允许多个生成序列复用相同的上下文缓存。&lt;/p&gt;
&lt;p&gt;这四个概念共同构成了高效、高吞吐大模型推理的核心基础，也是 vLLM、TGI、TensorRT-LLM 等先进推理系统的关键优化点。&lt;/p&gt;
&lt;h3&gt;1. &lt;strong&gt;Prefill（预填充阶段）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;定义&lt;/strong&gt;：处理用户输入的完整 prompt（提示词）的阶段。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输入&lt;/strong&gt;：完整的 prompt tokens，例如 &lt;code&gt;[&quot;Hello&quot;, &quot;how&quot;, &quot;are&quot;, &quot;you&quot;, &quot;?&quot;]&lt;/code&gt;。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;计算方式&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;所有 prompt tokens &lt;strong&gt;一次性并行计算&lt;/strong&gt;（因为彼此已知，无依赖）；&lt;/li&gt;
&lt;li&gt;生成每个 token 对应的 &lt;strong&gt;Key（K）和 Value（V）&lt;/strong&gt;，并缓存到 &lt;strong&gt;KV Cache&lt;/strong&gt; 中；&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;不生成新 token&lt;/strong&gt;，只“准备上下文”。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;特点&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;计算密集（FLOPs 高），但可高度并行；&lt;/li&gt;
&lt;li&gt;内存带宽压力大（需加载整个模型权重）；&lt;/li&gt;
&lt;li&gt;只发生一次（每个请求开始时）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;✅ &lt;strong&gt;类比&lt;/strong&gt;：就像阅读一篇文章后记住内容，为后续回答做准备。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;2. &lt;strong&gt;Decode（解码阶段）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;定义&lt;/strong&gt;：基于 prompt 和已生成的 token，&lt;strong&gt;逐个生成新 token&lt;/strong&gt; 的阶段。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输入&lt;/strong&gt;：当前所有 token（prompt + 已生成部分）；&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;计算方式&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;每次只生成 &lt;strong&gt;1 个新 token&lt;/strong&gt;（自回归）；&lt;/li&gt;
&lt;li&gt;利用 KV Cache 中已有的 K/V，&lt;strong&gt;只计算新 token 的 Q&lt;/strong&gt;，并与历史 K/V 做 attention；&lt;/li&gt;
&lt;li&gt;新 token 的 K/V 被追加到 KV Cache 中，供下一步使用。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;特点&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;无法并行&lt;/strong&gt;（每个 token 依赖前一个）；&lt;/li&gt;
&lt;li&gt;内存带宽受限（频繁读写 KV Cache）；&lt;/li&gt;
&lt;li&gt;延迟敏感（每步都要等 GPU 计算完）；&lt;/li&gt;
&lt;li&gt;占整个生成过程的大部分时间（尤其长文本）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;✅ &lt;strong&gt;类比&lt;/strong&gt;：根据记忆逐字回答问题，每说一个词都要回想之前说过什么。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;3. &lt;strong&gt;Parallel Sampling 与 Beam Search&lt;/strong&gt;&lt;/h3&gt;
&lt;p&gt;这是两种常见的&lt;strong&gt;多候选生成策略&lt;/strong&gt;：&lt;/p&gt;
&lt;h4&gt;(a) &lt;strong&gt;Parallel Sampling（并行采样）&lt;/strong&gt;&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;目的&lt;/strong&gt;：生成多个&lt;strong&gt;风格或内容不同&lt;/strong&gt;的回复，强调&lt;strong&gt;多样性&lt;/strong&gt;（例如聊天机器人提供几种可能的回答）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;工作方式&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;从&lt;strong&gt;完全相同的 prompt&lt;/strong&gt; 出发；&lt;/li&gt;
&lt;li&gt;同时启动多个&lt;strong&gt;彼此独立的生成过程&lt;/strong&gt;（比如 4 个 sample）；&lt;/li&gt;
&lt;li&gt;每个 sample 在每一步都&lt;strong&gt;独立采样下一个 token&lt;/strong&gt;（通常基于概率分布，如 top-p 或 temperature 控制）；&lt;/li&gt;
&lt;li&gt;因为采样具有随机性，各路径很快就会&lt;strong&gt;分叉&lt;/strong&gt;，生成完全不同后续。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;KV Cache 管理&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;虽然起始 prompt 相同，但传统实现中&lt;strong&gt;每个 sample 维护自己完整的 KV Cache&lt;/strong&gt;（包括 prompt 部分）；&lt;/li&gt;
&lt;li&gt;→ &lt;strong&gt;即使前缀相同，也不共享&lt;/strong&gt;，造成内存冗余；&lt;/li&gt;
&lt;li&gt;（现代引擎可通过 &lt;em&gt;Shared Prefix&lt;/em&gt; 优化这一点，见下文）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;关键特点&lt;/strong&gt;：&lt;strong&gt;多条“平行宇宙”式的生成轨迹，互不影响&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;🌰 举例：输入 “讲个笑话”，并行采样可能输出：&lt;br&gt;
① “为什么程序员…？”&lt;br&gt;
② “有一天，一只鸭子走进酒吧…”&lt;br&gt;
③ “我不会讲笑话，但我知道一个 bug…”&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h4&gt;(b) &lt;strong&gt;Beam Search（束搜索）&lt;/strong&gt;&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;目的&lt;/strong&gt;：寻找&lt;strong&gt;得分最高、最合理的一条输出序列&lt;/strong&gt;，强调&lt;strong&gt;质量与确定性&lt;/strong&gt;（常用于机器翻译、摘要等任务）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;工作方式&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;维护一个固定大小的候选集，称为 &lt;strong&gt;beam&lt;/strong&gt;（例如 beam width = 4，即始终保留 4 条最佳路径）；&lt;/li&gt;
&lt;li&gt;每一步：对当前所有候选序列&lt;strong&gt;分别扩展所有可能的下一个 token&lt;/strong&gt;，计算整体得分（通常是 log-probability 累加）；&lt;/li&gt;
&lt;li&gt;然后从所有扩展结果中&lt;strong&gt;选出总分最高的 4 个&lt;/strong&gt;，作为下一步的候选；&lt;/li&gt;
&lt;li&gt;这些候选序列&lt;strong&gt;可能共享很长的前缀&lt;/strong&gt;（比如前 5 个词完全一样，第 6 个才分叉）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;KV Cache 管理&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;理论上，所有候选共享相同的 prompt 前缀；&lt;/li&gt;
&lt;li&gt;如果不优化，会重复存储相同前缀的 K/V，浪费内存；&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;理想实现应支持 Shared Prefix&lt;/strong&gt;：只存一份公共前缀，各分支仅存差异部分。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;关键特点&lt;/strong&gt;：&lt;strong&gt;不是完全独立的路径，而是协同竞争、动态剪枝的“精英小队”&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;🌰 举例：翻译 “I love you” →&lt;br&gt;
Beam 中可能暂时保留：&lt;br&gt;
① “Je t’aime”（高分）&lt;br&gt;
② “Je vous aime”（稍低分）&lt;br&gt;
③ “J’adore toi”（语法错误，后续会被淘汰）&lt;br&gt;
最终只输出得分最高的完整句子。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;4. &lt;strong&gt;Shared Prefix（共享前缀）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;问题背景&lt;/strong&gt;：在 Parallel Sampling 或 Beam Search 中，多个生成序列往往&lt;strong&gt;共享相同的 prompt 前缀&lt;/strong&gt;（甚至部分生成前缀）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;解决方案&lt;/strong&gt;：&lt;strong&gt;只存储一份公共前缀的 KV Cache&lt;/strong&gt;，各分支仅存储自己独有的后缀部分。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;技术实现&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;使用 &lt;strong&gt;树状或图状 KV Cache 结构&lt;/strong&gt;（如 Radix Tree、Block Manager in vLLM）；&lt;/li&gt;
&lt;li&gt;多个序列通过指针引用同一段 prefix cache；&lt;/li&gt;
&lt;li&gt;显著降低内存占用（尤其当 beam width 大或 sample 数多时）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;关于vLLM的详细原理讲解，可以参考&lt;a href=&quot;https://zhuanlan.zhihu.com/p/691038809&quot;&gt;博客&lt;/a&gt;。&lt;/p&gt;
&lt;p&gt;下面从 &lt;strong&gt;优点&lt;/strong&gt; 和 &lt;strong&gt;缺点/局限性&lt;/strong&gt; 两个方面总结 vLLM 的核心特性：&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;✅ 优点&lt;/h3&gt;
&lt;p&gt;| 优势 | 说明 |
|------|------|
| &lt;strong&gt;1. 极高的吞吐量和低延迟&lt;/strong&gt; | 得益于 &lt;strong&gt;PagedAttention&lt;/strong&gt; 技术，vLLM 能高效管理 KV Cache 内存，减少内存碎片，吞吐量可达 Hugging Face Transformers 的 &lt;strong&gt;10–24 倍&lt;/strong&gt;。 |
| &lt;strong&gt;2. 支持 Shared Prefix（前缀共享）&lt;/strong&gt; | 多个请求若具有相同 prompt（如 system prompt），可共享 KV Cache，大幅节省显存，特别适合聊天机器人等场景。 |
| &lt;strong&gt;3. 高效支持连续批处理&lt;/strong&gt;（Continuous Batching） | 动态合并不同长度的请求进行批处理，避免 Decode 阶段因短请求等待长请求而空转，显著提升 GPU 利用率。 |
| &lt;strong&gt;4. 兼容主流开源模型&lt;/strong&gt; | 原生支持 Llama、Llama2/3、Mistral、Qwen、Baichuan、Gemma、Phi 等数十种架构，开箱即用。 |
| &lt;strong&gt;5. 支持多种生成策略&lt;/strong&gt; | 包括 greedy search、sampling、beam search、parallel sampling，并支持 logprobs、stop tokens、repetition penalty 等高级功能。 |
| &lt;strong&gt;6. 易于部署和集成&lt;/strong&gt; | 提供 OpenAI 兼容的 API 接口，可无缝替换 OpenAI 调用；也支持与 LangChain、LlamaIndex 等框架集成。 |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;❌ 缺点 / 局限性&lt;/h3&gt;
&lt;p&gt;| 局限 | 说明 |
|------|------|
| &lt;strong&gt;1. 不支持训练或微调&lt;/strong&gt; | vLLM 仅用于&lt;strong&gt;推理&lt;/strong&gt;，不能用于 LoRA 微调、QLoRA 或全参数训练（需配合其他框架如 Hugging Face PEFT + DeepSpeed）。 |
| &lt;strong&gt;2. 对超长上下文支持有限&lt;/strong&gt; | 虽然支持 RoPE 插值等扩展上下文方法，但 PagedAttention 在极端长序列（如 &gt; 100K tokens）下仍可能面临显存或性能瓶颈。 |
| &lt;strong&gt;3. 内存节省依赖请求模式&lt;/strong&gt; | Shared Prefix 和 PagedAttention 的优势在&lt;strong&gt;高并发、重复 prompt&lt;/strong&gt; 场景下最明显；单请求或完全不同的 prompt 下收益较小。 |
| &lt;strong&gt;4. 自定义模型适配需开发工作&lt;/strong&gt; | 若使用非主流或自研模型，需手动实现模型加载器和 attention 逻辑，有一定工程门槛。 |
| &lt;strong&gt;5. 暂不支持多模态模型&lt;/strong&gt; | 当前主要面向纯文本 LLM，对 LLaVA、Qwen-VL 等多模态模型支持有限（社区正在推进中）。 |
| &lt;strong&gt;6. 资源占用仍高于量化方案&lt;/strong&gt; | 虽比原生 HF 高效，但相比 &lt;strong&gt;GGUF（llama.cpp）&lt;/strong&gt; 或 &lt;strong&gt;AWQ/GPTQ 量化模型&lt;/strong&gt;，vLLM 默认运行 FP16/BF16 模型，显存占用更高。 |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;📌 总结一句话：&lt;/h3&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;vLLM 是目前最高效的开源 LLM 推理引擎之一，通过 PagedAttention 和连续批处理等创新，在保持模型精度的同时极大提升了吞吐与资源利用率，特别适合高并发在线服务；但它专注于推理，不替代训练框架，且对硬件资源仍有较高要求。&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（十一）LoRA与QLoRA</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-11/llm_blogs-11</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-11/llm_blogs-11</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 16:20:00 GMT</pubDate><content:encoded>&lt;h1&gt;LoRA&lt;/h1&gt;
&lt;p&gt;LoRA（Low-Rank Adaptation，低秩适配）是一种高效微调大语言模型（LLM）的技术。其核心思想是：&lt;strong&gt;冻结原始预训练模型的所有参数，仅在关键模块（如注意力层的权重矩阵）旁路添加可训练的低秩分解矩阵&lt;/strong&gt;。通过这种方式，LoRA 仅需训练极少量参数，就能在下游任务上达到接近全参数微调的性能，同时显著节省显存、计算资源和存储成本。因其高效、灵活且易于集成，LoRA 已成为大模型适配（如指令微调、领域迁移）的主流方法之一。
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
如上图，我们在进行模型全量微调时，是对左边的（&lt;code&gt;1024*512&lt;/code&gt;）的矩阵进行微调，微调过程是：通过反向传播会得到与左边矩阵同样大小的梯度矩阵，然后对原矩阵进行更新。当模型非常大时，更新的参数量也非常大。所以LoRA通过将这个矩阵进行分解成两个小矩阵进行训练，并冻结原来的矩阵，最后将训练好的两个小矩阵进行矩阵乘法运算后与原矩阵相加，从而实现了更新原矩阵的效果。&lt;/p&gt;
&lt;p&gt;通过这种方式，原来需要训练的参数量为&lt;code&gt;1024*512=524288&lt;/code&gt;降低到了只需要训练&lt;code&gt;1024*32+32*512=16384&lt;/code&gt;个参数，显著降低训练的参数量！当然在分解矩阵时中间的维度不一定要是&lt;code&gt;32&lt;/code&gt;，也可以是其他数，可以通过参数&lt;code&gt;r&lt;/code&gt;指定。&lt;/p&gt;
&lt;p&gt;推广之后如下：
&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
注意这里&lt;code&gt;r&lt;/code&gt;是远小于&lt;code&gt;N&lt;/code&gt;和&lt;code&gt;D&lt;/code&gt;的。&lt;/p&gt;
&lt;p&gt;接下来结合LoRA的公式来看：
&lt;img src=&quot;3.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
注意，在分解得到的两个矩阵中，有一个矩阵需要全零初始化，另一个就是随机初始化。在公式中除了控制分解矩阵大小的参数&lt;code&gt;r&lt;/code&gt;，还有一个超参数&lt;code&gt;α&lt;/code&gt;，称为缩放因子，可以通过调整&lt;code&gt;α&lt;/code&gt;来控制权重更新强度。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么 A 矩阵可以初始化为零？&lt;/strong&gt;
LoRA 的低秩更新项为：$\Delta W = B A$&lt;/p&gt;
&lt;p&gt;其中： $B \in \mathbb{R}^{N \times r}$，$A \in \mathbb{R}^{r \times D}$，因此 $\Delta W \in \mathbb{R}^{N \times D}$，与原始权重维度一致。&lt;/p&gt;
&lt;p&gt;如果将 $A$ 初始化为全零矩阵（即 $A = \mathbf{0}$），而 $B$ 随机初始化（非零），则初始时：
$$\Delta W = B A = B \cdot \mathbf{0} = \mathbf{0}$$
此时模型的有效权重为：
$$
W_{\text{effective}} = W_{\text{original}} + \Delta W = W_{\text{original}}
$$
✅ 这意味着模型的初始行为与原始预训练模型完全一致，不会因适配器引入扰动。&lt;/p&gt;
&lt;p&gt;随着训练进行，$B$ 的梯度可通过反向传播更新。具体来看：&lt;/p&gt;
&lt;p&gt;设损失函数为 $\mathcal{L}$，其对 $\Delta W$ 的梯度为 $G = \frac{\partial \mathcal{L}}{\partial (\Delta W)} \in \mathbb{R}^{N \times D}$。
根据链式法则：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;对 $B$ 的梯度为：  $\frac{\partial \mathcal{L}}{\partial B} = A^\top G$&lt;/li&gt;
&lt;li&gt;对 $A$ 的梯度为：$\frac{\partial \mathcal{L}}{\partial A} = G B^\top$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;在初始时刻，虽然 $B = \mathbf{0}$，但只要 $A \neq \mathbf{0}$ 且 $G \neq \mathbf{0}$，就有：$\frac{\partial \mathcal{L}}{\partial B} = A^\top G \neq \mathbf{0}$
✅ 因此，$B$ 的梯度是非零的，可以通过优化器进行更新，逐步从零变为有意义的值。&lt;/p&gt;
&lt;p&gt;有两个图可以直观展示LoRA微调前后的对比：
&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
使用LoRA微调时的attention层QKV权重演示：
&lt;img src=&quot;5.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;img src=&quot;6.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h1&gt;QLoRA&lt;/h1&gt;
&lt;p&gt;QLoRA（Quantized Low-Rank Adaptation）是一种结合量化（Quantization）与LoRA微调的高效大模型适配方法。&lt;/p&gt;
&lt;p&gt;其核心思想是：首先将预训练大模型的权重离线量化为整数，大幅降低显存占用；然后在反向传播时，仅对LoRA适配器（即低秩矩阵A和B）进行训练，他们一个float16或者bfloat16精度保存和更新。&lt;/p&gt;
&lt;p&gt;通过这种方式，QLoRA 能在消费级 GPU（如 24GB 显存的 RTX 4090）上微调数十亿甚至上百亿参数的大语言模型，且性能几乎与全量微调或标准 LoRA 相当，显著降低了大模型微调的硬件门槛和成本。&lt;/p&gt;
&lt;p&gt;如图：
&lt;img src=&quot;7.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h1&gt;使用场景&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;8.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（十）注意力机制之MHA、MQA、GQA</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-10/llm_blogs-10</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-10/llm_blogs-10</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 16:15:00 GMT</pubDate><content:encoded>&lt;p&gt;自 Transformer 架构提出以来，多头注意力（&lt;code&gt;Multi-Head Attention, MHA&lt;/code&gt;）凭借其强大的并行建模能力成为标准配置。然而，随着模型规模不断扩大、上下文长度持续增长，MHA 在推理阶段暴露出显著的内存与计算瓶颈——尤其是其庞大的 &lt;code&gt;KV Cache&lt;/code&gt; 开销，严重制约了部署效率。&lt;/p&gt;
&lt;p&gt;为平衡模型性能与推理成本，研究者相继提出了多种注意力变体：多查询注意力（&lt;code&gt;Multi-Query Attention, MQA&lt;/code&gt;）通过共享单一 &lt;code&gt;Key&lt;/code&gt; 和 &lt;code&gt;Value&lt;/code&gt; 头大幅压缩缓存，显著提升生成速度；而分组查询注意力（&lt;code&gt;Grouped-Query Attention, GQA&lt;/code&gt;）则在 &lt;code&gt;MHA&lt;/code&gt; 与 &lt;code&gt;MQA&lt;/code&gt; 之间取得折中，在几乎不损失模型质量的前提下，有效降低内存占用并加速推理。&lt;/p&gt;
&lt;p&gt;下面简单介绍一下这三种注意力机制：&lt;/p&gt;
&lt;h1&gt;多头注意力机制（MHA）&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
多头注意力机制（&lt;code&gt;Multi-Head Attention, MHA&lt;/code&gt;）并不是简单地将 &lt;code&gt;Q、K、V&lt;/code&gt; 复制 &lt;code&gt;n_heads&lt;/code&gt; 次，而是通过不同的可学习线性投影，将输入映射到 &lt;code&gt;n_heads&lt;/code&gt; 个独立的子空间中，每个头在自己的子空间里并行计算注意力。&lt;/p&gt;
&lt;p&gt;由于每一层&lt;code&gt;MHA&lt;/code&gt;需要缓存&lt;code&gt;n_heads&lt;/code&gt;个&lt;code&gt;Key&lt;/code&gt;向量和&lt;code&gt;n_heads&lt;/code&gt;个&lt;code&gt;Value&lt;/code&gt;向量，总共是&lt;code&gt;2*n_heads&lt;/code&gt;个/层。&lt;/p&gt;
&lt;h3&gt;✅ 优点&lt;/h3&gt;
&lt;p&gt;| 优点 | 说明 |
|------|------|
| &lt;strong&gt;1. 捕获多样化的依赖关系&lt;/strong&gt; | 不同头可以关注不同类型的模式（例如：一个头关注语法结构，另一个头关注语义角色，再一个头关注指代关系） |
| &lt;strong&gt;2. 增强模型表达能力&lt;/strong&gt; | 相当于集成多个“弱注意力模型”，提升整体泛化能力 |
| &lt;strong&gt;3. 并行计算高效&lt;/strong&gt; | 所有头可同时计算，充分利用 GPU 并行能力 |
| &lt;strong&gt;4. 提高训练稳定性&lt;/strong&gt; | 多视角学习有助于缓解梯度消失/爆炸问题 |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;❌ 缺点&lt;/h3&gt;
&lt;p&gt;| 缺点 | 说明 |
|------|------|
| &lt;strong&gt;1. 推理时 KV Cache 内存开销大&lt;/strong&gt; | 每个头都需要缓存自己的 K 和 V → KV Cache 大小与 &lt;code&gt;n_heads&lt;/code&gt; 成正比（例如 Llama-7B：32 heads → 缓存是 MQA 的 32 倍） |
| &lt;strong&gt;2. 计算量增加&lt;/strong&gt; | 虽然可并行，但总 FLOPs 高于单头 |
| &lt;strong&gt;3. 对长上下文不友好&lt;/strong&gt; | KV Cache 随序列长度线性增长，多头进一步放大内存压力，限制最大上下文长度 |
| &lt;strong&gt;4. 可能存在冗余头&lt;/strong&gt; | 研究发现部分头学习到相似模式，存在参数浪费 |&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;多查询注意力机制（MQA）&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;code&gt;MQA&lt;/code&gt;（&lt;code&gt;Multi-Query Attention&lt;/code&gt;，多查询注意力） 是对标准多头注意力（&lt;code&gt;MHA&lt;/code&gt;）的一种高效改进，主要用于降低大语言模型在推理阶段的内存和计算开销。它的核心思想是，保留多个&lt;code&gt;Query&lt;/code&gt;头（每个头都有自己的查询表示），但所有头共享同一个&lt;code&gt;Key&lt;/code&gt;和同一个&lt;code&gt;Value&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;这里我认为图中的KV Cache计算错了，应该是&lt;code&gt;2&lt;/code&gt;个/层（一个&lt;code&gt;Key&lt;/code&gt;和一个&lt;code&gt;Value&lt;/code&gt;）。&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;✅ 优点&lt;/h3&gt;
&lt;p&gt;| 优点 | 说明 |
|------|------|
| &lt;strong&gt;1. 显著降低 KV Cache 内存占用&lt;/strong&gt; | 所有注意力头共享同一个 Key 和 Value，每层仅需缓存 1 个 K 和 1 个 V，KV Cache 大小与 &lt;code&gt;n_heads&lt;/code&gt; 无关（例如 Llama-7B 使用 MQA 可将缓存减少至 MHA 的 1/32） |
| &lt;strong&gt;2. 提升推理速度和吞吐量&lt;/strong&gt; | 减少显存读写和带宽压力，生成新 token 时延迟更低，尤其在长序列场景下优势明显 |
| &lt;strong&gt;3. 更适合长上下文部署&lt;/strong&gt; | 因内存开销小，更容易支持数万 token 的上下文窗口，提升实际应用可行性 |
| &lt;strong&gt;4. 训练开销几乎不变&lt;/strong&gt; | 仅减少了 K/V 投影参数数量，训练时计算量和收敛性基本不受影响 |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;❌ 缺点&lt;/h3&gt;
&lt;p&gt;| 缺点 | 说明 |
|------|------|
| &lt;strong&gt;1. 表达能力略有下降&lt;/strong&gt; | 所有头被迫使用相同的 Key 和 Value 视角，无法像 MHA 那样从多个语义子空间独立建模依赖关系 |
| &lt;strong&gt;2. 在复杂任务中可能性能受损&lt;/strong&gt; | 对于高度依赖细粒度注意力模式的任务（如机器翻译、结构化推理），可能出现轻微准确率下降 |
| &lt;strong&gt;3. 不适用于编码器-heavy 架构&lt;/strong&gt; | MQA 主要优化解码器自回归生成，在需要双向上下文建模的编码器中收益有限甚至有害 |
| &lt;strong&gt;4. 注意力多样性受限&lt;/strong&gt; | 由于共享 K/V，不同头之间的注意力分布趋于相似，削弱了“多视角”学习的优势 |&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;分组查询注意力机制（GQA）&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;3.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;code&gt;GQA&lt;/code&gt;（&lt;code&gt;Grouped-Query Attention&lt;/code&gt;，分组查询注意力）是一种在多头注意力（&lt;code&gt;MHA&lt;/code&gt;）之间的折中方案。它将多个查询头（&lt;code&gt;Query heads&lt;/code&gt;）分成若干组，每组共享同一个 &lt;code&gt;Key&lt;/code&gt; 和 &lt;code&gt;Value&lt;/code&gt;，而不是像 &lt;code&gt;MHA&lt;/code&gt; 那样每个头都拥有独立的 &lt;code&gt;K/V&lt;/code&gt;，也不像 &lt;code&gt;MQA&lt;/code&gt; 那样所有头共享同一对 &lt;code&gt;K/V&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;对于每层&lt;code&gt;GQA&lt;/code&gt;，KV Cache取决于&lt;code&gt;n_heads&lt;/code&gt;和&lt;code&gt;g（分组数）&lt;/code&gt;，&lt;code&gt;2*n_heads/g&lt;/code&gt;个/层。例如，在 32 个头的模型中，可将其分为 8 组，每组 4 个 Query 头共享 1 个 K 和 1 个 V，从而将 KV Cache 的大小从 MHA 的 64 个张量减少到 16 个。&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;✅ 优点&lt;/h3&gt;
&lt;p&gt;| 优点 | 说明 |
|------|------|
| &lt;strong&gt;1. 在性能与效率之间取得良好平衡&lt;/strong&gt; | 相比 MQA 保留了更强的多头表达能力，相比 MHA 大幅降低推理开销，实测在多数任务上性能几乎无损 |
| &lt;strong&gt;2. 显著减少 KV Cache 内存占用&lt;/strong&gt; | 将 &lt;code&gt;n_heads&lt;/code&gt; 个 Query 头分组，每组共享一个 Key 和 Value，KV Cache 大小从 &lt;code&gt;2 × n_heads&lt;/code&gt; 降至 &lt;code&gt;2 × n_groups&lt;/code&gt;（例如 Llama-3-8B：32 heads → 8 groups，缓存减少至 1/4） |
| &lt;strong&gt;3. 提升长上下文推理可行性&lt;/strong&gt; | 更低的显存占用使模型更容易支持数万 token 的上下文长度，适合实际部署 |
| &lt;strong&gt;4. 兼容现有训练流程&lt;/strong&gt; | 无需修改训练目标或架构设计，仅调整注意力头的分组方式，训练稳定性与 MHA 相当 |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;❌ 缺点&lt;/h3&gt;
&lt;p&gt;| 缺点 | 说明 |
|------|------|
| &lt;strong&gt;1. 仍存在一定的 KV Cache 开销&lt;/strong&gt; | 虽优于 MHA，但缓存大小仍高于 MQA（例如 8 组 vs 1 组），在极端内存受限场景下不如 MQA 轻量 |
| &lt;strong&gt;2. 分组策略需人工设计&lt;/strong&gt; | &lt;code&gt;n_groups&lt;/code&gt; 是超参数，需在模型设计阶段确定，不同任务可能需要调优，缺乏完全自适应机制 |
| &lt;strong&gt;3. 组内头多样性受限&lt;/strong&gt; | 同一组内的多个 Query 头共享相同的 K/V，导致组内注意力模式趋同，略微削弱多视角建模能力 |
| &lt;strong&gt;4. 实现复杂度略高&lt;/strong&gt; | 相比 MHA 和 MQA，需要额外处理分组逻辑，在底层推理引擎中需专门优化（如 vLLM、TensorRT-LLM） |&lt;/p&gt;
&lt;h1&gt;对比&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
可以看到单从精度效果来看&lt;code&gt;MHA-XXL&lt;/code&gt;是最好的，单从推理速度来看&lt;code&gt;MQA-XXL&lt;/code&gt;是最好的，综合来看&lt;code&gt;GQA-XXL&lt;/code&gt;是最好的。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（九）KV Cache</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-9/llm_blogs-9</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-9/llm_blogs-9</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 16:10:00 GMT</pubDate><content:encoded>&lt;p&gt;本节内容参考&lt;a href=&quot;https://www.bilibili.com/video/BV17CPkeEEzk?spm_id_from=333.788.recommend_more_video.-1&amp;#x26;trackid=web_related_0.router-related-2206419-sw9lg.1762931643875.169&amp;#x26;vd_source=52455a50a39ab9ee183496a6de048a09&quot;&gt;UP主&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;&lt;code&gt;KV Cache&lt;/code&gt;（&lt;code&gt;Key-Value Cache&lt;/code&gt;，键值缓存）是大语言模型（&lt;code&gt;LLM&lt;/code&gt;）在自回归生成过程中用于加速推理、避免重复计算的一种关键技术。在 &lt;code&gt;Transformer&lt;/code&gt; 的解码阶段，模型每次生成一个新 &lt;code&gt;token&lt;/code&gt; 时，都需要计算从第一个 &lt;code&gt;token&lt;/code&gt; 到当前所有 &lt;code&gt;token&lt;/code&gt; 的自注意力（&lt;code&gt;self-attention&lt;/code&gt;）。如果不做优化，每生成一个新词，就要重新计算之前所有 &lt;code&gt;token&lt;/code&gt; 的 &lt;code&gt;Key（K）&lt;/code&gt;和 &lt;code&gt;Value（V）&lt;/code&gt;向量，导致时间和内存开销随生成长度平方级增长。&lt;code&gt;KV Cache&lt;/code&gt; 的核心思想是：将已生成 &lt;code&gt;token&lt;/code&gt; 对应的 &lt;code&gt;K&lt;/code&gt; 和 &lt;code&gt;V&lt;/code&gt; 向量缓存起来，在后续生成中直接复用，只需计算新 &lt;code&gt;token&lt;/code&gt; 的 &lt;code&gt;K、V&lt;/code&gt; 并拼接到缓存中。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;KV Cache&lt;/code&gt;应用于推理阶段（也就是&lt;code&gt;K、V&lt;/code&gt;的值是不变的）&lt;/li&gt;
&lt;li&gt;&lt;code&gt;KV Cache&lt;/code&gt;只存在于&lt;code&gt;Decoder&lt;/code&gt;解码器中&lt;/li&gt;
&lt;li&gt;它的目的是为了加速&lt;code&gt;Q@K@V&lt;/code&gt;的两次矩阵相乘时的速度&lt;/li&gt;
&lt;li&gt;&lt;code&gt;KV Cache&lt;/code&gt;会加大内存占用&lt;/li&gt;
&lt;/ul&gt;
&lt;h1&gt;自注意力机制&lt;/h1&gt;
&lt;p&gt;首先回顾一下&lt;code&gt;Attention&lt;/code&gt;注意力机制：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;code&gt;Q（Query）、K（Key）、V（Value）&lt;/code&gt; 并不是凭空产生的，而是通过对输入序列中的每个 &lt;code&gt;token&lt;/code&gt; 的嵌入表示（&lt;code&gt;embedding&lt;/code&gt;）分别进行线性变换（即矩阵乘法）得到的。&lt;/p&gt;
&lt;p&gt;对于&lt;code&gt;Q&lt;/code&gt;的维度可以理解为，&lt;code&gt;seq_length&lt;/code&gt;代表输入&lt;code&gt;token&lt;/code&gt;序列的长度，&lt;code&gt;d_model&lt;/code&gt;代表每个&lt;code&gt;token&lt;/code&gt;的维度：

经过$Q@K^T$计算后，会得到一个(&lt;code&gt;seq_length*seq_length&lt;/code&gt;)大小的矩阵，参数量取决于输入&lt;code&gt;token&lt;/code&gt;序列的长度。该矩阵的每个元素表示序列中某一个 &lt;code&gt;token&lt;/code&gt; 的 &lt;code&gt;Query&lt;/code&gt; 与另一个 &lt;code&gt;token&lt;/code&gt; 的 &lt;code&gt;Key&lt;/code&gt; 之间的相似度（或相关性得分）。这个矩阵被称为&lt;strong&gt;注意力分数矩阵&lt;/strong&gt;（&lt;code&gt;Attention Scores&lt;/code&gt;）。&lt;/p&gt;
&lt;p&gt;由于 &lt;code&gt;Q&lt;/code&gt; 和 &lt;code&gt;K&lt;/code&gt; 的维度通常较高，它们的点积结果方差会较大，导致 &lt;code&gt;softmax&lt;/code&gt; 函数进入梯度极小的饱和区。  因此，&lt;code&gt;Transformer&lt;/code&gt; 引入了&lt;strong&gt;缩放因子&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;在语言模型生成任务中，为&lt;strong&gt;防止当前位置“看到未来信息”&lt;/strong&gt;，会对注意力分数矩阵应用&lt;strong&gt;因果掩码（causal mask）&lt;/strong&gt;：将上三角部分设为负无穷（-infty），这样在 &lt;code&gt;softmax&lt;/code&gt; 后对应位置权重为 &lt;code&gt;0&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;接下来，对每一行应用 &lt;code&gt;softmax&lt;/code&gt;，将注意力分数转换为&lt;strong&gt;概率分布&lt;/strong&gt;（即每个 token 对其他所有 &lt;code&gt;token&lt;/code&gt; 的注意力权重）：&lt;/p&gt;
&lt;p&gt;最后，用&lt;strong&gt;注意力分数矩阵&lt;/strong&gt;对 &lt;code&gt;V&lt;/code&gt; 矩阵进行加权求和（即矩阵乘法）：得到每个 &lt;code&gt;token&lt;/code&gt; 融合了上下文信息的新表示。&lt;/p&gt;
&lt;h1&gt;模型推理过程&lt;/h1&gt;
&lt;p&gt;模型推理过程如图：模型根据用户的输入然后不断地预测下一个字，如此循环往复，直到遇到推理结束符。&lt;/p&gt;
&lt;p&gt;再进一步看推理过程：
&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
当我们输入&lt;strong&gt;中华人民&lt;/strong&gt;时，首先通过$Q@K^T$得到&lt;strong&gt;注意力分数矩阵&lt;/strong&gt;，再经过缩放和掩码操作后，与&lt;code&gt;V&lt;/code&gt;相乘，最终得到预测的&lt;code&gt;token&lt;/code&gt;序列。&lt;/p&gt;
&lt;p&gt;但是会存在一个问题:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;5.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
就是使用&lt;strong&gt;中华人&lt;/strong&gt;推出下一个词是&lt;strong&gt;民&lt;/strong&gt;，而下一次推理要使用&lt;strong&gt;中华人民&lt;/strong&gt;去推理出&lt;strong&gt;共&lt;/strong&gt;，这样会存在重复计算。每次都需要计算一次&lt;strong&gt;中华人&lt;/strong&gt;（$Q$）与&lt;strong&gt;中华人&lt;/strong&gt;（$K^T$），并且它们的计算结果都一样。不一样的是，每次计算的$Q@K^T$都会多一个维度（因为每次推理多一个&lt;code&gt;token&lt;/code&gt;）。&lt;/p&gt;
&lt;p&gt;由此，我们可以选择使用空间换时间的方式，保存这些不变的信息（$K$、$V$以及上一轮计算的$Q@K^T$作为缓存，供下一次推力使用），从而实现推理加速，因此引入&lt;code&gt;KV Cache&lt;/code&gt;。
&lt;em&gt;&lt;strong&gt;tips&lt;/strong&gt;&lt;/em&gt;：&lt;code&gt;Q&lt;/code&gt;是每次需要结合上一次预测的&lt;code&gt;token&lt;/code&gt;作为输入，是动态的，所以不能保存。&lt;/p&gt;
&lt;h1&gt;KV Cache&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;6.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
由图可知，每开始新的一轮推理时，不再需要前面的&lt;code&gt;Q&lt;/code&gt;与预测的&lt;code&gt;token&lt;/code&gt;值结合作为新的&lt;code&gt;Q&lt;/code&gt;进行推理，而只需要将上一轮预测的&lt;code&gt;token&lt;/code&gt;作为&lt;code&gt;Q&lt;/code&gt;即可，大大减少了计算量。&lt;/p&gt;
&lt;h1&gt;KV缓存示例&lt;/h1&gt;
&lt;p&gt;以多轮对话为例子：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;7.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
当输入第一个问题&lt;strong&gt;Turn 1(Q)&lt;/strong&gt;，模型给出答案&lt;strong&gt;Turn 1(A)&lt;/strong&gt;，然后这轮对话（&lt;strong&gt;Q+A&lt;/strong&gt;）会与下一轮问题&lt;strong&gt;Turn 2(Q)&lt;/strong&gt;，喂入至模型中得到答案&lt;strong&gt;Turn 2(A)&lt;/strong&gt;...&lt;/p&gt;
&lt;p&gt;在多轮对话（chat）场景中，每轮对话都可以复用前几轮的 KV Cache，从而实现显著加速和内存优化。这正是现代大语言模型（如 ChatGPT、Qwen、Llama 等）在实际推理中使用 KV Cache 的核心价值所在。&lt;/p&gt;
&lt;h1&gt;测试速度&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;8.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（八）大模型微调之DPO训练</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-8/llm_blogs-8</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-8/llm_blogs-8</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 16:05:00 GMT</pubDate><content:encoded>&lt;p&gt;DPO（直接偏好优化）​ 是一种无需训练奖励模型的强化学习算法，专门用于对齐大语言模型与人类偏好。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;核心思想：直接比较模型对&quot;好回答&quot;和&quot;坏回答&quot;的偏好，通过优化损失函数让模型更喜欢生成人类偏好的回答。&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;首先定义一个创建qwen模型方法：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def create_qwen_model():
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=&quot;auto&quot;,
        device_map=&quot;auto&quot;
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model,tokenizer
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后定义两个qwen模型，一个用于训练，一个用于参照：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# DPO训练的模型
model_pi,tokenizer=create_qwen_model()
# DPO参照的模型
model_ref,_=create_qwen_model()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中学习模型是被训练优化的模型，学习生成更好的回答。参考模型是保持初始行为，防止模型&quot;遗忘&quot;或&quot;跑偏&quot;&lt;/p&gt;
&lt;p&gt;再定义一个&lt;code&gt;chat&lt;/code&gt;方法，跟之前的一样，只是多了&lt;code&gt;model&lt;/code&gt;和&lt;code&gt;tokenizer&lt;/code&gt;两个参数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 模型测试方法
def chat(prompt,tokenizer,model):
    messages = [
        {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
        {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    #print(text)

    model_inputs = tokenizer([text], return_tensors=&quot;pt&quot;).to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;定义DPO训练数据集：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;dpo_train_data=[
    {&apos;prompt&apos;:&apos;你是谁?&apos;,&apos;chosen&apos;:&apos;通义千问&apos;,&apos;reject&apos;:&apos;我是阿里云开发的超大规模语言模型，我叫通义千问。&apos;},
    {&apos;prompt&apos;:&apos;你是谁发明的?&apos;,&apos;chosen&apos;:&apos;扣你jio哇&apos;,&apos;reject&apos;:&apos;阿里巴巴&apos;},
]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后将DPO偏好数据转换为对话格式：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 偏好数据集 -&gt; 模型输入
def dpo_to_messages(dpo_pairs):
    chosen_messages=[]
    reject_messages=[]
    for pair in dpo_pairs:
        chosen_messages.append([
                {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
                {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: pair[&apos;prompt&apos;]},
                {&quot;role&quot;: &quot;assistant&quot;, &quot;content&quot;: pair[&apos;chosen&apos;]},
            ]
        )
        reject_messages.append([
                {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
                {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: pair[&apos;prompt&apos;]},
                {&quot;role&quot;: &quot;assistant&quot;, &quot;content&quot;: pair[&apos;reject&apos;]},
            ]
        )
    return chosen_messages,reject_messages
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;训练数据预处理与之前一样的：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 训练数据预处理
def preprocess(tokenizer,batch_messages):
    input_list=[]
    target_list=[]
    
    im_start=tokenizer(&apos;&amp;#x3C;|im_start|&gt;&apos;).input_ids
    im_end=tokenizer(&apos;&amp;#x3C;|im_end|&gt;&apos;).input_ids
    newline=tokenizer(&apos;\n&apos;).input_ids
    pad=tokenizer(&apos;&amp;#x3C;|endoftext|&gt;&apos;).input_ids
    ignore=[-100]
    
    for group in batch_messages:
        input_ids=[]
        target_ids=[]
        for msg in group:
            role=tokenizer(msg[&apos;role&apos;]).input_ids
            content=tokenizer(msg[&apos;content&apos;]).input_ids
            if msg[&apos;role&apos;] in [&apos;system&apos;,&apos;user&apos;]:
                ignore_parts=role+newline+content
                input_ids+=im_start+ignore_parts+im_end+newline
                target_ids+=im_start+ignore*len(ignore_parts)+im_end+newline
            else:
                ignore_parts=role+newline
                input_ids+=im_start+ignore_parts+content+im_end+newline
                target_ids+=im_start+ignore*len(ignore_parts)+content+im_end+newline
        input_list.append(input_ids)
        target_list.append(target_ids)
    
    # padding
    max_len=max([len(ids) for ids in input_list])
    for input_ids,target_ids in zip(input_list,target_list):
        input_ids+=pad*(max_len-len(input_ids))
        target_ids+=ignore*(max_len-len(target_ids))
    batch_input_ids=torch.tensor(input_list,dtype=torch.long)
    batch_target_ids=torch.tensor(target_list,dtype=torch.long)
    batch_mask=batch_input_ids.ne(pad[0]).type(torch.long)
    return batch_input_ids,batch_target_ids,batch_mask
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;模型设置为&lt;code&gt;train&lt;/code&gt;模式：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model_pi.train()
model_ref.train()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;看下结构：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;优化器设置：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 优化器，只训练pi模型
optimizer=torch.optim.SGD(model_pi.parameters(),lr=1e-3)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来就是DPO的重点，损失函数的设计：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# DPO损失计算-辅助函数
def dpo_prob_calc(target_ids,pi_logits,ref_logits):
    pi_probs=torch.log_softmax(pi_logits,dim=-1)      # softmax概率+log对数
    ref_probs=torch.log_softmax(ref_logits,dim=-1)
    
    ignore_mask=target_ids!=-100 # ignore token掩码
    indexes=target_ids*ignore_mask # 将-100变成0，以便后面gather可以运行
    
    pi_probs_of_target=torch.gather(pi_probs,dim=-1,index=indexes.unsqueeze(-1)).squeeze(-1) * ignore_mask # 取目标target token的概率，忽略-100 token
    ref_probs_of_target=torch.gather(ref_probs,dim=-1,index=indexes.unsqueeze(-1)).squeeze(-1) * ignore_mask    
    
    pi_final_prob=pi_probs_of_target.sum(-1)/ignore_mask.sum(-1)     # 求每一个样本的token prob均值
    ref_final_prob=ref_probs_of_target.sum(-1)/ignore_mask.sum(-1)
    return pi_final_prob,ref_final_prob
    
# DPO损失函数 https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py
def dpo_loss(params):
    ## 两个模型的chosen输出
    chosen_target_ids=params[&apos;chosen_target_ids&apos;][:,1:]
    pi_chosen_logits=params[&apos;pi_chosen_logits&apos;][:,:-1,:]
    ref_chosen_logits=params[&apos;ref_chosen_logits&apos;][:,:-1,:]
    pi_chosen_prob,ref_chosen_prob=dpo_prob_calc(chosen_target_ids,pi_chosen_logits,ref_chosen_logits)
    
    ## 两个模型的reject输出
    reject_target_ids=params[&apos;reject_target_ids&apos;][:,1:]
    pi_reject_logits=params[&apos;pi_reject_logits&apos;][:,:-1,:]
    ref_reject_logits=params[&apos;ref_reject_logits&apos;][:,:-1,:]
    pi_reject_prob,ref_reject_prob=dpo_prob_calc(reject_target_ids,pi_reject_logits,ref_reject_logits)
    
    # 计算DPO Loss
    pi_prob_diff=pi_chosen_prob-pi_reject_prob 
    ref_prob_diff=ref_chosen_prob-ref_reject_prob
    beta=0.1
    loss=-torch.nn.functional.logsigmoid(beta*(pi_prob_diff-ref_prob_diff))
    return loss.mean()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;DPO的损失函数的公式长这样：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
接下来开始训练：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;iterators=20

vocab=tokenizer.get_vocab()
# print(vocab)
for i in range(iterators):
    # 一批模拟数据
    chosen_messages,reject_messages=dpo_to_messages(dpo_train_data)
    # model输入和输出
    chosen_input_ids,chosen_target_ids,chosen_mask=preprocess(tokenizer,chosen_messages)
    reject_input_ids,reject_target_ids,reject_mask=preprocess(tokenizer,reject_messages)
    # model_pi预测
    pi_chosen_logits=model_pi(input_ids=chosen_input_ids.to(device),attention_mask=chosen_mask.to(device)).logits
    pi_reject_logits=model_pi(input_ids=reject_input_ids.to(device),attention_mask=reject_mask.to(device)).logits
    # model_ref预测
    ref_chosen_logits=model_ref(chosen_input_ids.to(device),chosen_mask.to(device)).logits
    ref_reject_logits=model_ref(reject_input_ids.to(device),reject_mask.to(device)).logits
    # 求DPO损失
    loss=dpo_loss({
        &apos;chosen_target_ids&apos;:chosen_target_ids.to(device),
        &apos;reject_target_ids&apos;:reject_target_ids.to(device),
        &apos;pi_chosen_logits&apos;:pi_chosen_logits.to(device),
        &apos;pi_reject_logits&apos;:pi_reject_logits.to(device),
        &apos;ref_chosen_logits&apos;:ref_chosen_logits.to(device),
        &apos;ref_reject_logits&apos;:ref_reject_logits.to(device),
    })
    print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;训练的过程中只更新pi模型的权重。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model_pi.eval()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将pi模型设置为&lt;code&gt;eval&lt;/code&gt;模式。&lt;/p&gt;
&lt;p&gt;训练20个epoch后的效果：
&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
感觉还没有拟合，第二个问题回答不对。继续训练30个epoch:
&lt;img src=&quot;3.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
现在OK了，经过DPO微调之后模型原有的知识也不会发生大的变动：
&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（七）大模型微调之SFT训练</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-7/llm_blogs-7</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-7/llm_blogs-7</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 16:00:00 GMT</pubDate><content:encoded>&lt;p&gt;之前的[大模型学习（一）通义千问1.8B大模型微调]是基于通义千问模型提供的微调代码进行微调的，现在是手写一个微调的代码，相对来说更底层，感受一下大模型微调。&lt;/p&gt;
&lt;p&gt;下载大模型到本地：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;from transformers import AutoModelForCausalLM, AutoTokenizer
from modelscope import snapshot_download
import torch

device = &quot;cuda&quot; # the device to load the model onto

model_dir = snapshot_download(&apos;Qwen/Qwen2-0.5B-Instruct&apos;, cache_dir=&quot;./Models&quot;)

model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=&quot;auto&quot;,
    device_map=&quot;auto&quot;
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;打印分词器tokenizer：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;print(tokenizer)
# Qwen2TokenizerFast(name_or_path=&apos;./Models/Qwen/Qwen2-0___5B-Instruct&apos;, vocab_size=151643, model_max_length=32768, is_fast=True, padding_side=&apos;right&apos;, truncation_side=&apos;right&apos;, special_tokens={&apos;eos_token&apos;: &apos;&amp;#x3C;|im_end|&gt;&apos;, &apos;pad_token&apos;: &apos;&amp;#x3C;|endoftext|&gt;&apos;, &apos;additional_special_tokens&apos;: [&apos;&amp;#x3C;|im_start|&gt;&apos;, &apos;&amp;#x3C;|im_end|&gt;&apos;]}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
#	151643: AddedToken(&quot;&amp;#x3C;|endoftext|&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
#	151644: AddedToken(&quot;&amp;#x3C;|im_start|&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
#	151645: AddedToken(&quot;&amp;#x3C;|im_end|&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
# }
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中&lt;code&gt;vocab_size=151643&lt;/code&gt;表示所有可能的 token（包括字、子词、符号、特殊标记等）都被映射到 0 ~ 151642 的整数 ID。&lt;code&gt;model_max_length=32768&lt;/code&gt;表示模型支持的最大上下文长度是 32,768 个 tokens，超长会被截断。还定义了一些特殊的token。&lt;/p&gt;
&lt;p&gt;接下来定义一个对话大模型的chat方法：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def chat(prompt):
    messages = [
        {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
        {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    print(text)

    model_inputs = tokenizer([text], return_tensors=&quot;pt&quot;).to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;messages = [
        {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
        {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码是将对话格式转换为模型期望的ChatML文本格式。&lt;code&gt;tokenize=False&lt;/code&gt;表示返回字符串而不是token IDs，&lt;code&gt;add_generation_prompt=True&lt;/code&gt;表示添加生成提示符，告诉模型开始生成回复。当&lt;code&gt;add_generation_prompt=True&lt;/code&gt;时：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
你是谁？&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;assistant
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;当&lt;code&gt;add_generation_prompt=False&lt;/code&gt;时：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
你是谁？&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来就是将&lt;code&gt;prompt&lt;/code&gt;进行&lt;code&gt;tokenizer&lt;/code&gt;操作：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model_inputs = tokenizer([text], return_tensors=&quot;pt&quot;).to(device)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将文本转换为模型可处理的张量格式。返回的&lt;code&gt;model_inputs&lt;/code&gt;内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;{&apos;input_ids&apos;: tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
         151645,    198, 151644,    872,    198, 105043, 100165,  11319, 151645,
            198, 151644,  77091,    198]], device=&apos;cuda:0&apos;), &apos;attention_mask&apos;: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device=&apos;cuda:0&apos;)}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它包括&lt;code&gt;input_ids&lt;/code&gt;和&lt;code&gt;attention_mask&lt;/code&gt;两个内容。注意力掩码（Attention Mask）​ 是Transformer模型中非常重要的组件，用于控制模型在处理序列时应该关注哪些位置。它是一个二进制张量，告诉模型：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;1：需要关注这个位置（真实token）&lt;/li&gt;
&lt;li&gt;0：忽略这个位置（填充token）&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;举个例子：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 两个原始句子
sentences = [
    &quot;Hello world&quot;,           # 长度2
    &quot;This is a test sentence&quot; # 长度5
]

# 经过tokenization和padding后：
input_ids = [
    [101, 102, 0, 0, 0],    # &quot;Hello world&quot; + 3个填充
    [201, 202, 203, 204, 205] # &quot;This is a test sentence&quot;
]

# 对应的注意力掩码：
attention_mask = [
    [1, 1, 0, 0, 0],        # 前2个是真实token，后3个是填充
    [1, 1, 1, 1, 1]         # 全部是真实token
]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来调用模型生成回答：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;生成的&lt;code&gt;generated_ids&lt;/code&gt; 也是&lt;code&gt;token&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;tensor([[151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
         151645,    198, 151644,    872,    198, 105043, 100165,  11319, 151645,
            198, 151644,  77091,    198, 104198, 101919, 102661,  99718, 104197,
         100176, 102064, 104949,   3837,  35946,  99882,  31935,  64559,  99320,
          56007,   1773, 151645]], device=&apos;cuda:0&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再提取大模型新生成的内容：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;从完整输出中提取新生成的部分（去除输入部分）。&lt;/p&gt;
&lt;p&gt;最后解码回复：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这样就构建好了一个&lt;code&gt;Chat&lt;/code&gt;方法，从输入&lt;code&gt;Prompt&lt;/code&gt;到大模型的&lt;code&gt;response&lt;/code&gt;。可以直接调用：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
接下来是训练数据预处理方法，这里是&lt;strong&gt;重中之重&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 训练数据预处理方法
def preprocess(tokenizer,batch_messages):
    input_list=[]
    target_list=[]
    
    im_start=tokenizer(&apos;&amp;#x3C;|im_start|&gt;&apos;).input_ids
    im_end=tokenizer(&apos;&amp;#x3C;|im_end|&gt;&apos;).input_ids
    newline=tokenizer(&apos;\n&apos;).input_ids
    pad=tokenizer(&apos;&amp;#x3C;|endoftext|&gt;&apos;).input_ids
    ignore=[-100]
    
    for group in batch_messages:
        input_ids=[]
        target_ids=[]
        for msg in group:
            role=tokenizer(msg[&apos;role&apos;]).input_ids
            content=tokenizer(msg[&apos;content&apos;]).input_ids
            if msg[&apos;role&apos;] in [&apos;system&apos;,&apos;user&apos;]:
                ignore_parts=role+newline+content
                input_ids+=im_start+ignore_parts+im_end+newline
                target_ids+=im_start+ignore*len(ignore_parts)+im_end+newline
            else:
                ignore_parts=role+newline
                input_ids+=im_start+ignore_parts+content+im_end+newline
                target_ids+=im_start+ignore*len(ignore_parts)+content+im_end+newline
        input_list.append(input_ids)
        target_list.append(target_ids)
    
    # padding
    max_len=max([len(ids) for ids in input_list])
    for input_ids,target_ids in zip(input_list,target_list):
        input_ids+=pad*(max_len-len(input_ids))
        target_ids+=ignore*(max_len-len(target_ids))
    batch_input_ids=torch.tensor(input_list,dtype=torch.long)
    batch_target_ids=torch.tensor(target_list,dtype=torch.long)
    batch_mask=batch_input_ids.ne(pad[0]).type(torch.long)
    return batch_input_ids,batch_target_ids,batch_mask
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码是专门服务于Qwen模型的SFT（监督微调）训练的它的输入是&lt;code&gt;batch_messages&lt;/code&gt;- 批量对话数据，输出是&lt;code&gt;(input_ids, target_ids, attention_mask)&lt;/code&gt;- 训练三元组。它的核心思想就是只预测大模型的输出内容，这个效果是通过&lt;code&gt;ignore=[-100]&lt;/code&gt;实现的，具体处理规则是：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if role in [&apos;system&apos;, &apos;user&apos;]:
    # 输入：完整消息，目标：全部忽略
    input_ids += [完整token序列]
    target_ids += [-100, -100, ...]  # 忽略loss计算
    
else:  # assistant
    # 输入：角色标记+内容，目标：忽略角色标记，预测内容
    input_ids += [角色标记 + 内容]
    target_ids += [-100, -100, ... + 内容token]  # 只计算内容部分的loss
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;因为在&lt;code&gt;CrossEntropyLoss&lt;/code&gt;函数中，定义了&lt;code&gt;ignore_index=-100&lt;/code&gt;部分不会参与损失的计算：
&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;strong&gt;举个例子&lt;/strong&gt;：&lt;/p&gt;
&lt;p&gt;输入对话数据：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;batch_messages = [
    [  # 第一个对话
        {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;你是有用助手&quot;},
        {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;你好&quot;}, 
        {&quot;role&quot;: &quot;assistant&quot;, &quot;content&quot;: &quot;你好！需要什么帮助？&quot;}
    ]
]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;处理后的掩码效果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;输入Tokens:  [START, system, \n, 你, 有, 用, 助手, END, \n, START, user, \n, 你, 好, END, \n, START, assistant, \n, 你, 好, ！, 需, 要, 什, 么, 帮, 助, ？, END, \n]
目标Labels:  [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 你, 好, ！, 需, 要, 什, 么, 帮, 助, ？, -100, -100]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;也就实现了只有大模型输出的内容参与了计算，其他&lt;code&gt;-100&lt;/code&gt;的区域都不参与计算。&lt;/p&gt;
&lt;p&gt;接下来训练数据进行测试：&lt;/p&gt;
&lt;p&gt;构建一组训练样本数据：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;prompt = &quot;2+2等于几&quot;
messages = [
    [
        {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
        {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: prompt},
        {&quot;role&quot;: &quot;assistant&quot;, &quot;content&quot;: &apos;2+2等于5。&apos;},
    ],
    [
        {&quot;role&quot;: &quot;system&quot;, &quot;content&quot;: &quot;You are a helpful assistant.&quot;},
        {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: prompt},
        {&quot;role&quot;: &quot;assistant&quot;, &quot;content&quot;: &apos;2+2等于5。&apos;},
    ]
]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后设置模型为train模式，再将&lt;code&gt;message batch&lt;/code&gt;经过&lt;code&gt;tokenizer&lt;/code&gt;后输入到模型进行推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model.train()

batch_input_ids,batch_target_ids,batch_mask=preprocess(tokenizer,messages)
model_outputs=model(batch_input_ids.to(device))
output_tokens=model_outputs.logits.argmax(dim=-1)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里&lt;code&gt;model_outputs.logits&lt;/code&gt;的&lt;code&gt;shape&lt;/code&gt;为&lt;code&gt;[batch_size, seq_length, vocab_size]&lt;/code&gt;，其中&lt;code&gt;vocab_size&lt;/code&gt;与&lt;code&gt;tokenizer&lt;/code&gt;中的&lt;code&gt;vocab_size=151643&lt;/code&gt;是一样的。&lt;code&gt;argmax(dim=-1)&lt;/code&gt; 在最后一个维度（词汇表维度）取最大值索引，表示取在词汇表中最大可能出现的词。&lt;/p&gt;
&lt;p&gt;Qwen 系列基于 Transformer 架构 ，主要采用 解码器-only（Decoder-only） 的因果语言模型，因此需要进行错位对齐，举个例子：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;输入序列: [A, B, C, D, E]
预测目标: [B, C, D, E, F]

模型任务：给定前N个token，预测第N+1个token
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;代码实现：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 原始输出
logits_full = model_outputs.logits        # [batch, seq_len, vocab_size]
targets_full = batch_target_ids          # [batch, seq_len]

# 错位对齐
logits = logits_full[:, :-1, :]     # 去掉最后一个位置的预测，最后一个位置是&amp;#x3C;end&gt;
targets = targets_full[:, 1:]       # 去掉第一个位置的标签
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;最后进行损失计算：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;from torch.nn import CrossEntropyLoss

# 损失
loss_fn=CrossEntropyLoss()
loss=loss_fn(logits.reshape(-1,logits.size(2)),targets.reshape(-1))
print(&apos;loss:&apos;,loss)

# 优化器
optimizer=torch.optim.SGD(model.parameters())
optimizer.zero_grad()

# 求梯度
loss.backward()

# 梯度下降
optimizer.step()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中&lt;code&gt;logits.reshape(-1,logits.size(2))&lt;/code&gt;处理后的为&lt;code&gt;[batch*(seq_len-1), vocab_size]&lt;/code&gt;，表示&lt;code&gt;batch*(seq_len-1)&lt;/code&gt;个位置需要预测，每个位置&lt;code&gt;vocab_size&lt;/code&gt;种可能，&lt;code&gt;targets.reshape(-1)&lt;/code&gt;操作处理后为&lt;code&gt;batch*(seq_len-1)&lt;/code&gt;，表示&lt;code&gt;batch*(seq_len-1)&lt;/code&gt;对应的真实标签。&lt;/p&gt;
&lt;p&gt;测试一下新的模型效果：
&lt;img src=&quot;3.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
发现微调效果可以。&lt;/p&gt;
&lt;p&gt;需要注意的是，训练的epoch数不能太多，否则会过拟合，就会出现下面的效果：
&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;img src=&quot;5.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;img src=&quot;6.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（六）大模型微调之理论篇</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-6</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-6</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 15:55:00 GMT</pubDate><content:encoded>&lt;p&gt;一般来说，大模型的训练分为三个部分：（1）预训练。（2）监督式微调。（3）强化学习。&lt;/p&gt;
&lt;h1&gt;预训练&lt;/h1&gt;
&lt;p&gt;预训练使用的是海量的&lt;strong&gt;无标注&lt;/strong&gt;的语料，主要是训练模型的&lt;strong&gt;词语接龙能力&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;训练的语料比如：扣你jio哇是一个CSDN的小博主。&lt;/p&gt;
&lt;p&gt;当给模型&lt;code&gt;扣你jio哇是一个CSDN的小&lt;/code&gt;这段文字后，它会自己接上&lt;code&gt;博主&lt;/code&gt;两个词。&lt;/p&gt;
&lt;h1&gt;监督式微调SFT&lt;/h1&gt;
&lt;p&gt;SFT主要是锻炼模型的一个&lt;strong&gt;问答能力&lt;/strong&gt;。它的训练语料是&lt;strong&gt;有标注&lt;/strong&gt;的优质语料。模型在这个过程中主要是&lt;strong&gt;学习对格式的理解&lt;/strong&gt;。比如&lt;code&gt;&amp;#x3C;|im_start|&gt;user\n****&amp;#x3C;|im_end|&gt;\n&lt;/code&gt;这段内容是用户说的。模型推理的阶段是从&lt;code&gt;&amp;#x3C;|im_start|&gt;assistant\n&lt;/code&gt;开始，然后到&lt;code&gt;&amp;#x3C;|im_end|&gt;&lt;/code&gt;结束。&lt;/p&gt;
&lt;p&gt;训练语料的格式：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;user\n扣你jio哇是谁？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n扣你jio哇是一个CSDN的小博主。&amp;#x3C;|im_end|&gt;\n
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;推理阶段：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;user\n扣你jio哇是谁？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;大模型返回：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;扣你jio哇是一个CSDN的小博主。&amp;#x3C;|im_end|&gt;\n
&lt;/code&gt;&lt;/pre&gt;
&lt;h1&gt;强化学习RL&lt;/h1&gt;
&lt;p&gt;算法包括&lt;code&gt;PPO&lt;/code&gt;、&lt;code&gt;RLHF&lt;/code&gt;和&lt;code&gt;DPO&lt;/code&gt;等，其中&lt;code&gt;DPO&lt;/code&gt;实现起来最简单。通过&lt;code&gt;DPO&lt;/code&gt;学习算法，可以让大模型在回答的时候进行偏好对齐。&lt;/p&gt;
&lt;p&gt;训练语料的格式：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;chose好的回答：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;user\n扣你jio哇是谁？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n扣你jio哇是一个CSDN的小博主。&amp;#x3C;|im_end|&gt;\n
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;reject坏的回答：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;user\n扣你jio哇是谁？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n我不知道扣你jio哇是什么东西。&amp;#x3C;|im_end|&gt;\n
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;通过强化学习算法，可以让大模型回答的时候，倾向于选择好的回答，尽量不用坏的回答。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（五）asyncio+uvicorn+fastapi+threadpool</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-5</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-5</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 15:50:00 GMT</pubDate><content:encoded>&lt;p&gt;简单介绍一下各个工具的作用：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;asyncio（异步I/O库）：作用: 提供异步编程基础设施，包括事件循环、协程、任务等。特点: 单线程并发，适用于 I/O 密集型任务&lt;/li&gt;
&lt;li&gt;FastAPI（Web框架）：作用: 提供 REST API 构建功能，原生支持异步视图函数。特点: 自动 API 文档生成，类型提示支持&lt;/li&gt;
&lt;li&gt;Uvicorn（ASGI服务器）：作用: 运行 FastAPI 应用，处理 HTTP 请求。特点: 异步、快速，生产环境常用&lt;/li&gt;
&lt;li&gt;ThreadPoolExecutor（线程池）：作用: 管理线程池，执行阻塞任务。特点: 适用于 CPU 密集型或阻塞 I/O 任务&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;工作流程如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Uvicorn 启动并监听端口，接收 HTTP 请求
请求被传递给 FastAPI 应用进行路由处理
FastAPI 在 AsyncIO 事件循环中执行异步视图函数
当遇到阻塞操作时，使用 ThreadPoolExecutor 在线程中执行
结果通过回调机制返回给事件循环
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;创建fastapi和线程池：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 创建web
app=FastAPI()

#   创建线程池
threadpool=ThreadPoolExecutor(max_workers=200)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;处理对ver1的get请求：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;#   第一个版本
@app.get(&apos;/ver1&apos;)
async def ver1(request: Request):
    # 获取参数
    msg=request.query_params.get(&apos;msg&apos;)

    # 在主线程中获取async io event loop
    loop=asyncio.get_event_loop()
    
    # 准备计算任务
    task={
        &apos;msg&apos;: msg,
        &apos;event&apos;: asyncio.Event(),
        &apos;result&apos;: None,
    }
    
    # 计算函数
    def handle_task():
    		 # 在工作线程中
        print(&apos;task received:&apos;,task[&apos;msg&apos;])
        task[&apos;result&apos;]=task[&apos;msg&apos;].lower()
        time.sleep(2) # 模拟线程阻塞
        def async_callback():
            print(&apos;task ends notified:&apos;,task[&apos;result&apos;],asyncio.get_event_loop())
            # 这将在主线程中执行
            task[&apos;event&apos;].set()
        # 安全地将回调调度到主线程的事件循环
        loop.call_soon_threadsafe(async_callback)
    
    # 提交并等待结果
    threadpool.submit(handle_task)
    await task[&apos;event&apos;].wait()
    
    return Response(task[&apos;result&apos;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中获取当前asyncio事件循环&lt;code&gt;loop=asyncio.get_event_loop()&lt;/code&gt;的作用是允许工作线程安全地与主线程的事件循环交互，确保回调函数在正确的线程（事件循环所在的线程）执行。&lt;/p&gt;
&lt;p&gt;工作流程：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;当用户启动应用后，Uvicorn服务器开始监听localhost的8000端口。程序初始化时创建了一个包含200个工作线程的线程池，用于处理耗时的阻塞任务。当接收到HTTP GET请求访问/ver1或/ver2端点时，FastAPI在主线程（事件循环线程）中处理请求。&lt;/li&gt;
&lt;li&gt;对于/ver1端点，程序首先从请求中提取参数，获取当前的asyncio事件循环引用，并创建一个包含消息内容、异步事件和结果占位符的任务对象。然后定义一个在工作线程中执行的handle_task函数，该函数会进行消息处理（如转换为小写）并模拟2秒的阻塞操作。通过threadpool.submit()将任务提交到线程池后，主线程执行await task[&apos;event&apos;].wait()进入非阻塞等待状态，此时事件循环可以继续处理其他请求。工作线程执行完阻塞操作后，通过之前获取的事件循环引用，使用loop.call_soon_threadsafe()将回调函数安全地调度回主线程执行，在回调中触发事件解除主线程等待，最终返回处理结果。
&lt;em&gt;&lt;strong&gt;tips:在代码中，task 信息从主线程传递到工作线程的方式是通过闭包机制实现的。当 handle_task 函数被定义时，它会捕获外部作用域中的 task 变量引用，形成闭包。当 threadpool.submit(handle_task) 执行时，传递的是函数对象及其闭包环境。&lt;/strong&gt;&lt;/em&gt;&lt;/li&gt;
&lt;li&gt;对于/ver2端点，采用了更简洁的方式，通过loop.run_in_executor()直接将任务提交到线程池并await其结果，底层同样实现了异步等待和线程间通信，但代码更加简洁。&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;#   第二个版本
@app.get(&apos;/ver2&apos;)
async def ver2(request: Request):
    # 获取参数
    msg=request.query_params.get(&apos;msg&apos;)

    # 获取async io event loop
    loop=asyncio.get_event_loop()
    
    # 准备计算任务
    task={
        &apos;msg&apos;: msg,
    }
    
    # 计算函数
    def handle_task():
        print(&apos;task received:&apos;,task[&apos;msg&apos;])
        result=task[&apos;msg&apos;].lower()
        time.sleep(2) # 模拟线程阻塞
        return result
    
    # 提交并等待结果
    result=await loop.run_in_executor(threadpool,handle_task)
    return Response(result)
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（四）LangChain实现RAG检索增强</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-4/llm_blogs-4</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-4/llm_blogs-4</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 15:45:00 GMT</pubDate><content:encoded>&lt;p&gt;由于一些依赖库的版本更新太快了，导致出现不兼容的情况，我也是试了好久，最好是跟我一样安装下面的版本：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 大模型核心框架
vllm==0.4.2
vllm-flash-attn==2.5.9
vllm_nccl_cu12==2.18.1.0.4.0
transformers==4.41.0
transformers-stream-generator==0.0.4
langchain==0.1.6
langchain-community==0.0.20
langchain-core==0.1.46
langchain-openai==0.1.7
langgraph==1.0.2
langgraph-sdk==0.2.9
langsmith==0.0.83

# 量化相关
auto_gptq==0.7.1
compressed-tensors==0.8.1
peft==0.12.0
optimum==2.0.0

# 推理优化
xformers==0.0.26.post1
triton==2.3.0
flash-attn==2.5.9

# 模型服务相关
fastapi==0.121.0
uvicorn==0.38.0
openai==1.109.1

# 向量数据库
faiss-cpu==1.12.0

# Embedding相关
sentencepiece==0.2.1
tokenizers==0.19.1
tiktoken==0.6.0

# 模型下载与管理
huggingface-hub==0.36.0
modelscope==1.9.0

# PyTorch生态
torch==2.3.0+cu118
torchaudio==2.3.1+cu121
torchvision==0.18.0

# 其他重要依赖
pydantic==2.12.4
outlines==0.0.34
outlines_core==0.2.11
&lt;/code&gt;&lt;/pre&gt;
&lt;h1&gt;启动vLLM的openai兼容server&lt;/h1&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;export VLLM_USE_MODELSCOPE=True
python -m vllm.entrypoints.openai.api_server --model &apos;../Qwen-vllm/Models/qwen/Qwen-14B-Chat-Int4&apos; --trust-remote-code -q gptq --dtype float16 --gpu-memory-utilization 0.6
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;em&gt;tips：为什么要用 vllm.entrypoints.openai.api_server 启动？和我之前自己写的 FastAPI + vLLM 服务有什么区别？——涉及开发效率、兼容性、标准化和维护成本。核心目的是：让本地大模型“看起来就像 OpenAI API”，从而获得最大兼容性、最低接入成本和最强生态支持。&lt;/em&gt;&lt;/p&gt;
&lt;h1&gt;生成知识向量库&lt;/h1&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 解析PDF，切成chunk片段
pdf_loader=PyPDFLoader(&apos;LLM.pdf&apos;,extract_images=True)   # 使用OCR解析pdf中图片里面的文字
chunks=pdf_loader.load_and_split(text_splitter=RecursiveCharacterTextSplitter(chunk_size=100,chunk_overlap=10))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;使用PyPDFLoader工具解析PDF文件，extract_images为True表示会使用OCR解析图片文件中的文字，然后设置分块大小为100，重叠区域为10。
输出内容如下：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
一共432段文字，chunks列表里面存放的就是Document类数据，它包括page_content（也就是我们需要的文字），和meta信息，例如&lt;code&gt;metadata={&apos;source&apos;: &apos;LLM.pdf&apos;, &apos;page&apos;: 0}&lt;/code&gt;包括数据来源以及页码等等信息。&lt;/p&gt;
&lt;p&gt;接下来加载embedding模型，用于将chunk向量化，这里通过modelscope加载一些免费的model进行向量化：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 加载embedding模型，用于将chunk向量化
embeddings=ModelScopeEmbeddings(model_id=&apos;iic/nlp_corom_sentence-embedding_chinese-base&apos;) 
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里用的通用领域的，可以根据自己的需要进行选择：
&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
最后保存到脸书的faiss本地向量数据库中：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 将chunk插入到faiss本地向量数据库 
vector_db=FAISS.from_documents(chunks,embeddings)
vector_db.save_local(&apos;LLM.faiss&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;也可以用其他的，在&lt;a href=&quot;https://docs.langchain.com/oss/python/integrations/vectorstores&quot;&gt;langchain&lt;/a&gt;里面有很多：
&lt;img src=&quot;3.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
使用这些集成的工具可以帮助我们快速构建RAG，在具体的项目中这些工具可能会限制你，所以可以考虑自己实现底层逻辑。&lt;/p&gt;
&lt;h1&gt;构建RAG&lt;/h1&gt;
&lt;p&gt;首先加载同一个embedding模型，用于将Query向量化：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;embeddings=ModelScopeEmbeddings(model_id=&apos;iic/nlp_corom_sentence-embedding_chinese-base&apos;) 
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后加载本地faiss向量库，用于知识召回：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;vector_db=FAISS.load_local(&apos;LLM.faiss&apos;,embeddings)
retriever=vector_db.as_retriever(search_kwargs={&quot;k&quot;:5})
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码先加载向量库，然后将向量库作为一个检索器。这里指定top_k为5，返回5个最相关的向量。&lt;/p&gt;
&lt;p&gt;用vllm部署openai兼容的服务端接口，然后走ChatOpenAI客户端调用&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;os.environ[&apos;VLLM_USE_MODELSCOPE&apos;]=&apos;True&apos;
chat=ChatOpenAI(
    model=&quot;qwen/Qwen-7B-Chat-Int4&quot;,
    openai_api_key=&quot;EMPTY&quot;,
    openai_api_base=&apos;http://localhost:8000/v1&apos;,
    stop=[&apos;&amp;#x3C;|im_end|&gt;&apos;]
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再使用langchain提供的提示词模板，非常方便，就不需要我们像之前一样自己写了：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;system_prompt=SystemMessagePromptTemplate.from_template(&apos;You are a helpful assistant.&apos;)
user_prompt=HumanMessagePromptTemplate.from_template(&apos;&apos;&apos;
Answer the question based only on the following context:

{context}

Question: {query}
&apos;&apos;&apos;)
full_chat_prompt=ChatPromptTemplate.from_messages([system_prompt,MessagesPlaceholder(variable_name=&quot;chat_history&quot;),user_prompt])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到最终的full_chat_prompt由三部分组成，头部+腰部+尾部，完全符合之前的结构，它构成的内容也跟之前的ChatML结构是一样的：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.
&amp;#x3C;|im_end|&gt;
...
&amp;#x3C;|im_start|&gt;user
Answer the question based only on the following context:

{context}

Question: {query}
&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;assitant
......
&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后构建Chat chain&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;chat_chain={
        &quot;context&quot;: itemgetter(&quot;query&quot;) | retriever,
        &quot;query&quot;: itemgetter(&quot;query&quot;),
        &quot;chat_history&quot;:itemgetter(&quot;chat_history&quot;),
    }|full_chat_prompt|chat
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它的逻辑就是首先将query传入到检索器中，检索到5个最相关的向量作为context，然后依次传入query和chat_history，再将这三个内容传入到full_chat_prompt，构建起一个完整的Prompt，最后输入到我们的chat模型得到输出。&lt;/p&gt;
&lt;p&gt;最后进行对话：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 开始对话
chat_history=[]
while True:
    query=input(&apos;query:&apos;)
    response=chat_chain.invoke({&apos;query&apos;:query,&apos;chat_history&apos;:chat_history})
    chat_history.extend((HumanMessage(content=query),response))
    print(response.content)
    chat_history=chat_history[-20:] # 最新10轮对话
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;查看效果：
&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
我打印了相关的信息：
&lt;img src=&quot;5.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
可以看到，检索器能够检索到query问题相关的信息并召回当做context，由于是第一次检索，所以chat_history是空的，然后生成的Prompt指令，最后输入到大模型得到答复。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（三）大模型vLLM推理实践</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-3/llm_blogs-3</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-3/llm_blogs-3</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 15:40:00 GMT</pubDate><content:encoded>&lt;p&gt;首先查看Qwen的&lt;a href=&quot;https://github.com/QwenLM/Qwen/blob/main/README_CN.md&quot;&gt;README&lt;/a&gt;
可以看到使用modelscope调用一个大模型非常简单：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;from modelscope import AutoModelForCausalLM, AutoTokenizer
from modelscope import GenerationConfig

# 可选的模型包括: &quot;qwen/Qwen-7B-Chat&quot;, &quot;qwen/Qwen-14B-Chat&quot;
tokenizer = AutoTokenizer.from_pretrained(&quot;qwen/Qwen-7B-Chat&quot;, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(&quot;qwen/Qwen-7B-Chat&quot;, device_map=&quot;auto&quot;, trust_remote_code=True, fp16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(&quot;Qwen/Qwen-7B-Chat&quot;, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参

response, history = model.chat(tokenizer, &quot;你好&quot;, history=None)
print(response)
response, history = model.chat(tokenizer, &quot;浙江的省会在哪里？&quot;, history=history) 
print(response)
response, history = model.chat(tokenizer, &quot;它有什么好玩的景点&quot;, history=history)
print(response)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;只需要加载分词器，模型以及一些配置信息等，即可对话大模型了。可以看到上述对话中，不断将history传入到模型中，实现一个多轮对话的效果。但是当前的情况下只适合单机一个人使用，此时你传入的字符串，会经过chat方法映射成tokenID序列交给模型推理；当把模型交给推理服务端，作为一个服务的时候，此时将没有chat方法了，只会加载模型本身，模型的输入就是tokenID的序列。&lt;/p&gt;
&lt;p&gt;当我们在服务端是需要加速模型推理时，可以使用vLLM。它可以实现加载一次模型，并且在多线程的方式实现排队，对外提供http服务。当多个并发请求到来时，它会在内存里面把多个请求的prompt拼装成一个batch送入到达模型里面进行批推理，实现更高的吞吐。&lt;/p&gt;
&lt;p&gt;安装vLLM库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;pip install vllm
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;tips：在魔搭社区可以看到Qwen1.8B模型有多个版本：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
这里主要分文预训练模型和Chat模型，其中预训练模型是在超大规模的预训练数据上进行训练得到。预训练数据类型多样，覆盖广泛，包括大量网络文本、专业书籍、代码等。它与Chat模型的区别是：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;1.8B预训练版本，训练数据的方式：
首先给出语料，例如:鸟纲的特征是有羽毛、喙没有牙齿、蛋有硬壳、高代谢率、心脏有四室、轻盈但结实的骨骼。所有鸟类的前肢都进化成翼，大部分也能够飞翔。它们有独特的消化系统及呼吸系统，很适合飞行。

模型输入：鸟纲的特征是有羽毛、喙没有牙齿、蛋有硬壳、高代谢率、心脏有四室、轻盈但结实的骨骼。
模型输出：所有鸟类的前肢都进化成翼，大部分也能够飞翔。它们有独特的消化系统及呼吸系统，很适合飞行。&amp;#x3C;|endoftext|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它相当于将语料拆分成前后两段，给模型前一段文字，然后去预测后面一段文字。而Chat版本是在预训练版本进行微调训练数据，这个数据通常需要人工标注：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;模型输入：&amp;#x3C;|im_start|&gt;system\nyou are a helperful assitant!\n&amp;#x3C;|im_end|&gt;&amp;#x3C;|im_start|&gt;user\n了解鸟类特征吗？\n&amp;#x3C;|im_end|&gt;&amp;#x3C;|im_start|&gt;assitant:\n
模型输出：鸟纲的特征是有羽毛、喙没有牙齿、蛋有硬壳、高代谢率、心脏有四室、轻盈但结实的骨骼。&amp;#x3C;|im_end|&gt;&amp;#x3C;|endoftext|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到模型的输入和输出都是有一定格式控制的。首先是系统指令（给你一个身份），其次是用户的问题，最后是助手的回答，模型输出的内容就是助手的回答。当你需要进行微调时，你可以修改适用于自己的系统指令。&lt;/p&gt;
&lt;h1&gt;vLLM加载&lt;/h1&gt;
&lt;p&gt;了解完Qwen的模型输入输出结构后，接下来学习vLLM是如何进行推理的。
首先定义Qwen的特殊token:&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 通义千问的特殊token
IMSTART=&apos;&amp;#x3C;|im_start|&gt;&apos;  
IMEND=&apos;&amp;#x3C;|im_end|&gt;&apos;
ENDOFTEXT=&apos;&amp;#x3C;|endoftext|&gt;&apos;     # EOS以及PAD都是它
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来就是加载本地或者网络模型以及相关的内容：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if os.path.exists(model_dir):
    local_model_dir = model_dir
    print(f&quot;Using local model directory: {local_model_dir}&quot;)
else:
    print(f&quot;Model not found locally, downloading from ModelScope: {model_dir}&quot;)
    local_model_dir = snapshot_download(model_dir)
    # 模型下载
    snapshot_download(model_dir)

self.generation_config = GenerationConfig.from_pretrained(model_dir,trust_remote_code=True)

# 加载分词器
self.tokenizer=AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)
self.tokenizer.eos_token_id=self.generation_config.eos_token_id
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;准备好推理终止词，遇到这些词停止继续推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;self.stop_words_ids=[self.tokenizer.im_start_id,self.tokenizer.im_end_id,self.tokenizer.eos_token_id]
# stop_words_ids:{} [151644, 151645, 151643]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;使用vLLM加载模型：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;os.environ[&apos;VLLM_USE_MODELSCOPE&apos;]=&apos;True&apos; #这是一个环境变量设置，告诉vLLM 框架：“我接下来要加载的模型可能来自 ModelScope（阿里云的模型开放平台），请启用对 ModelScope 模型格式的支持。
self.model=LLM(model=model_dir,
	              tokenizer=model_dir,
	              tensor_parallel_size=tensor_parallel_size,
	              trust_remote_code=True,
	              quantization=quantization,
	              gpu_memory_utilization=gpu_memory_utilization, # 0.6
	              dtype=dtype)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;em&gt;tips：当发现模型启动时显存不足时，可能与gpu_memory_utilization有关。&lt;/em&gt;&lt;/p&gt;
&lt;h1&gt;vLLM模型推理&lt;/h1&gt;
&lt;p&gt;接下来进行聊天推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def chat(self,query,history=None,system=&quot;You are a helpful assistant.&quot;,extra_stop_words_ids=[]):
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;函数可以传入问题，历史对话，系统指令以及额外的推理停止词（如果你想在你指定的地方停止推理）。&lt;/p&gt;
&lt;p&gt;首先构造promt：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;prompt_text,prompt_tokens=_build_prompt(self.generation_config,self.tokenizer,query,history=history,system=system)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造完成会返回提示词文本和编码tokenID序列，打印结果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;提问:你好 
&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
你好&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;assistant

[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 108386, 151645, 198, 151644, 77091, 198]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来配置vLLM请求：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sampling_params=SamplingParams(stop_token_ids=stop_words_ids, 
                  early_stopping=False,
                  top_p=self.generation_config.top_p,
                  top_k=-1 if self.generation_config.top_k == 0 else self.generation_config.top_k,
                  temperature=self.generation_config.temperature,
                  repetition_penalty=self.generation_config.repetition_penalty,
                  max_tokens=self.generation_config.max_new_tokens)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这些就是模型的配置参数等信息，然后调用vLLM执行推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 调用vLLM执行推理（批次大小1）
req_outputs=self.model.generate(prompt_token_ids=[prompt_tokens],sampling_params=sampling_params,use_tqdm=False) # use_tqdm禁止进度条
req_output=req_outputs[0]    
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将prompt的tokenID序列和相关配置参数传入进行推理得到结果。打印返回信息查看结构：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# transformer模型的原生返回, 打开注释看一下原始推理结果
print(&quot;req_output:&quot;,req_output)
print(&quot;req_outputs:&quot;,req_outputs)
raw_response=req_output.outputs[0].text
print(&quot;raw_response:&quot;,raw_response)
#req_output: RequestOutput(request_id=0, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 108386, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=&apos;你好！很高兴为你解答问题。有什么我可以帮助你的吗？&apos;, token_ids=array(&apos;l&apos;, [108386, 6313, 112169, 106184, 106185, 86119, 1773, 104139, 109944, 100364, 103929, 101037, 11319, 151645]), cumulative_logprob=None, logprobs=None, finish_reason=stop, stop_reason=151645)], finished=True, metrics=RequestMetrics(arrival_time=1762318798.8988488, last_token_time=1762318798.8988488, first_scheduled_time=1762318798.899235, first_token_time=1762318798.9658892, time_in_queue=0.00038623809814453125, finished_time=1762318799.0134034, scheduler_time=0.001102657988667488, model_forward_time=None, model_execute_time=None), lora_request=None)
#req_outputs: [RequestOutput(request_id=0, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 108386, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=&apos;你好！很高兴为你解答问题。有什么我可以帮助你的吗？&apos;, token_ids=array(&apos;l&apos;, [108386, 6313, 112169, 106184, 106185, 86119, 1773, 104139, 109944, 100364, 103929, 101037, 11319, 151645]), cumulative_logprob=None, logprobs=None, finish_reason=stop, stop_reason=151645)], finished=True, metrics=RequestMetrics(arrival_time=1762318798.8988488, last_token_time=1762318798.8988488, first_scheduled_time=1762318798.899235, first_token_time=1762318798.9658892, time_in_queue=0.00038623809814453125, finished_time=1762318799.0134034, scheduler_time=0.001102657988667488, model_forward_time=None, model_execute_time=None), lora_request=None)]
#raw_response: 你好！很高兴为你解答问题。有什么我可以帮助你的吗？
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再移除返回的tokenID序列的停用词，再进行解码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 移除停用词        
response_token_ids=remove_stop_words(req_output.outputs[0].token_ids,stop_words_ids)
response=self.tokenizer.decode(response_token_ids)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;在模型生成的 token 序列中，一旦遇到预设的“停止词”（如对话结束标记），就立即截断后续所有内容，避免把控制标记或多余内容暴露给用户。&lt;/p&gt;
&lt;p&gt;最后整理历史对话：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 整理历史对话
history.append((query,response))
return response,history
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;到此，使用vLLM模型推理的过程就结束了。&lt;/p&gt;
&lt;p&gt;补充，接下来详细解读_build_prompt方法：
首先按ChatML格式构造千问的Prompt，下面是一个简单的CharML对话示例：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
Hello!&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;assistant
Hi! How can I help you today?&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
What&apos;s the weather?&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来，我们学习如何构建这么一个ChatML的Prompt：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 包裹发言内容的token
im_start,im_start_tokens=&apos;&amp;#x3C;|im_start|&gt;&apos;,[tokenizer.im_start_id]
im_end,im_end_tokens=&apos;&amp;#x3C;|im_end|&gt;&apos;,[tokenizer.im_end_id]
# 换行符token
nl_tokens=tokenizer.encode(&quot;\n&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;定义一个方法，用于编码system/user/assistant的一段发言, 格式{role}\n{content}：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def _tokenize_str(role,content): # 返回元组，下标0是文本，下标1是token ids
        return f&quot;{role}\n{content}&quot;,tokenizer.encode(role)+nl_tokens+tokenizer.encode(content)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;计算剩余token数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;left_token_space=generation_config.max_window_size
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后构造Prompt，这是重头戏，由三部分组成，头部+腰部+尾部，最重要的是头尾部，所以优先构造：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# prompt头部: system发言
system_text_part,system_tokens_part=_tokenize_str(&quot;system&quot;, system) # system_tokens_part --&gt;    system\nYou are a helpful assistant.
system_text=f&apos;{im_start}{system_text_part}{im_end}&apos;
system_tokens=im_start_tokens+system_tokens_part+im_end_tokens # &amp;#x3C;|im_start|&gt;system\nYou are a helpful assistant.&amp;#x3C;|im_end|&gt;
left_token_space-=len(system_tokens)
# system_text_part: system\nYou are a helpful assistant.
# system_text: &amp;#x3C;|im_start|&gt;system\nYou are a helpful assistant.&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造尾部：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# prompt尾部: user发言和assistant引导
query_text_part,query_tokens_part=_tokenize_str(&apos;user&apos;, query)
query_tokens_prefix=nl_tokens+ im_start_tokens
query_tokens_suffix=im_end_tokens+nl_tokens+im_start_tokens+tokenizer.encode(&apos;assistant&apos;)+nl_tokens
if len(query_tokens_prefix)+len(query_tokens_part)+len(query_tokens_suffix)&gt;left_token_space: # query太长截断
    query_token_len=left_token_space-len(query_tokens_prefix)-len(query_tokens_suffix)
    query_tokens_part=query_tokens_part[:query_token_len] #对用户提问部分进行截断
    query_text_part=tokenizer.decode(query_tokens_part)
query_tokens=query_tokens_prefix+query_tokens_part+query_tokens_suffix
query_text=f&quot;\n{im_start}{query_text_part}{im_end}\n{im_start}assistant\n&quot;
left_token_space-=len(query_tokens)
# query_text: &amp;#x3C;|im_start|&gt;user\n你好&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造腰部：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# prompt腰部: 历史user+assitant对话
history_text,history_tokens=&apos;&apos;,[]
for hist_query,hist_response in reversed(history):    # 优先采用最近的对话历史
    hist_query_text,hist_query_tokens_part=_tokenize_str(&quot;user&quot;,hist_query) # user\n历史提问
    hist_response_text,hist_response_tokens_part=_tokenize_str(&quot;assistant&quot;,hist_response) # assistant\n历史回答
    # 生成本轮对话
    cur_history_tokens=nl_tokens+im_start_tokens+hist_query_tokens_part+im_end_tokens+nl_tokens+im_start_tokens+hist_response_tokens_part+im_end_tokens
    cur_history_text=f&quot;\n{im_start}{hist_query_text}{im_end}\n{im_start}{hist_response_text}{im_end}&quot;
    # 储存多轮对话
    if len(cur_history_tokens)&amp;#x3C;=left_token_space:
        history_text=cur_history_text+history_text
        history_tokens=cur_history_tokens+history_tokens
        left_token_space-=len(cur_history_tokens)
    else:
        break 
# cur_history_text: &amp;#x3C;|im_start|&gt;user\n你好&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n你好！很高兴为你解答问题。有什么我可以帮助你的吗？&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里有个小细节，遍历历史的时候是反向遍历，优先选择最近的对话历史。&lt;/p&gt;
&lt;p&gt;最后生成完整的Prompt：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 生成完整Prompt
prompt_str=f&apos;{system_text}{history_text}{query_text}&apos;
prompt_tokens=system_tokens+history_tokens+query_tokens
# prompt_str: &amp;#x3C;|im_start|&gt;system\nYou are a helpful assistant.&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;user\n你好&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n你好！很高兴为你解答问题。有什么我可以帮助你的吗？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;user\n你是谁？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;停用词清理代码，从后往前清理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 停用词清理
def remove_stop_words(token_ids,stop_words_ids):
    token_ids=copy.deepcopy(token_ids)
    while len(token_ids)&gt;0:
        if token_ids[-1] in stop_words_ids:
            token_ids.pop(-1)
        else:
            break
    return token_ids
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（二）大模型vLLM推理</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-2/llm_blogs-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-2/llm_blogs-2</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 15:35:00 GMT</pubDate><content:encoded>&lt;p&gt;首先查看Qwen的&lt;a href=&quot;https://github.com/QwenLM/Qwen/blob/main/README_CN.md&quot;&gt;README&lt;/a&gt;
可以看到使用modelscope调用一个大模型非常简单：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;from modelscope import AutoModelForCausalLM, AutoTokenizer
from modelscope import GenerationConfig

# 可选的模型包括: &quot;qwen/Qwen-7B-Chat&quot;, &quot;qwen/Qwen-14B-Chat&quot;
tokenizer = AutoTokenizer.from_pretrained(&quot;qwen/Qwen-7B-Chat&quot;, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(&quot;qwen/Qwen-7B-Chat&quot;, device_map=&quot;auto&quot;, trust_remote_code=True, fp16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(&quot;Qwen/Qwen-7B-Chat&quot;, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参

response, history = model.chat(tokenizer, &quot;你好&quot;, history=None)
print(response)
response, history = model.chat(tokenizer, &quot;浙江的省会在哪里？&quot;, history=history) 
print(response)
response, history = model.chat(tokenizer, &quot;它有什么好玩的景点&quot;, history=history)
print(response)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;只需要加载分词器，模型以及一些配置信息等，即可对话大模型了。可以看到上述对话中，不断将history传入到模型中，实现一个多轮对话的效果。但是当前的情况下只适合单机一个人使用，此时你传入的字符串，会经过chat方法映射成tokenID序列交给模型推理；当把模型交给推理服务端，作为一个服务的时候，此时将没有chat方法了，只会加载模型本身，模型的输入就是tokenID的序列。&lt;/p&gt;
&lt;p&gt;当我们在服务端是需要加速模型推理时，可以使用vLLM。它可以实现加载一次模型，并且在多线程的方式实现排队，对外提供http服务。当多个并发请求到来时，它会在内存里面把多个请求的prompt拼装成一个batch送入到达模型里面进行批推理，实现更高的吞吐。&lt;/p&gt;
&lt;p&gt;安装vLLM库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;pip install vllm
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;tips：在魔搭社区可以看到Qwen1.8B模型有多个版本：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
这里主要分文预训练模型和Chat模型，其中预训练模型是在超大规模的预训练数据上进行训练得到。预训练数据类型多样，覆盖广泛，包括大量网络文本、专业书籍、代码等。它与Chat模型的区别是：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;1.8B预训练版本，训练数据的方式：
首先给出语料，例如:鸟纲的特征是有羽毛、喙没有牙齿、蛋有硬壳、高代谢率、心脏有四室、轻盈但结实的骨骼。所有鸟类的前肢都进化成翼，大部分也能够飞翔。它们有独特的消化系统及呼吸系统，很适合飞行。

模型输入：鸟纲的特征是有羽毛、喙没有牙齿、蛋有硬壳、高代谢率、心脏有四室、轻盈但结实的骨骼。
模型输出：所有鸟类的前肢都进化成翼，大部分也能够飞翔。它们有独特的消化系统及呼吸系统，很适合飞行。&amp;#x3C;|endoftext|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它相当于将语料拆分成前后两段，给模型前一段文字，然后去预测后面一段文字。而Chat版本是在预训练版本进行微调训练数据，这个数据通常需要人工标注：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;模型输入：&amp;#x3C;|im_start|&gt;system\nyou are a helperful assitant!\n&amp;#x3C;|im_end|&gt;&amp;#x3C;|im_start|&gt;user\n了解鸟类特征吗？\n&amp;#x3C;|im_end|&gt;&amp;#x3C;|im_start|&gt;assitant:\n
模型输出：鸟纲的特征是有羽毛、喙没有牙齿、蛋有硬壳、高代谢率、心脏有四室、轻盈但结实的骨骼。&amp;#x3C;|im_end|&gt;&amp;#x3C;|endoftext|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到模型的输入和输出都是有一定格式控制的。首先是系统指令（给你一个身份），其次是用户的问题，最后是助手的回答，模型输出的内容就是助手的回答。当你需要进行微调时，你可以修改适用于自己的系统指令。&lt;/p&gt;
&lt;h1&gt;vLLM加载&lt;/h1&gt;
&lt;p&gt;了解完Qwen的模型输入输出结构后，接下来学习vLLM是如何进行推理的。
首先定义Qwen的特殊token:&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 通义千问的特殊token
IMSTART=&apos;&amp;#x3C;|im_start|&gt;&apos;  
IMEND=&apos;&amp;#x3C;|im_end|&gt;&apos;
ENDOFTEXT=&apos;&amp;#x3C;|endoftext|&gt;&apos;     # EOS以及PAD都是它
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来就是加载本地或者网络模型以及相关的内容：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if os.path.exists(model_dir):
    local_model_dir = model_dir
    print(f&quot;Using local model directory: {local_model_dir}&quot;)
else:
    print(f&quot;Model not found locally, downloading from ModelScope: {model_dir}&quot;)
    local_model_dir = snapshot_download(model_dir)
    # 模型下载
    snapshot_download(model_dir)

self.generation_config = GenerationConfig.from_pretrained(model_dir,trust_remote_code=True)

# 加载分词器
self.tokenizer=AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)
self.tokenizer.eos_token_id=self.generation_config.eos_token_id
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;准备好推理终止词，遇到这些词停止继续推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;self.stop_words_ids=[self.tokenizer.im_start_id,self.tokenizer.im_end_id,self.tokenizer.eos_token_id]
# stop_words_ids:{} [151644, 151645, 151643]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;使用vLLM加载模型：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;os.environ[&apos;VLLM_USE_MODELSCOPE&apos;]=&apos;True&apos; #这是一个环境变量设置，告诉vLLM 框架：“我接下来要加载的模型可能来自 ModelScope（阿里云的模型开放平台），请启用对 ModelScope 模型格式的支持。
self.model=LLM(model=model_dir,
	              tokenizer=model_dir,
	              tensor_parallel_size=tensor_parallel_size,
	              trust_remote_code=True,
	              quantization=quantization,
	              gpu_memory_utilization=gpu_memory_utilization, # 0.6
	              dtype=dtype)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;em&gt;tips：当发现模型启动时显存不足时，可能与gpu_memory_utilization有关。&lt;/em&gt;&lt;/p&gt;
&lt;h1&gt;vLLM模型推理&lt;/h1&gt;
&lt;p&gt;接下来进行聊天推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def chat(self,query,history=None,system=&quot;You are a helpful assistant.&quot;,extra_stop_words_ids=[]):
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;函数可以传入问题，历史对话，系统指令以及额外的推理停止词（如果你想在你指定的地方停止推理）。&lt;/p&gt;
&lt;p&gt;首先构造promt：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;prompt_text,prompt_tokens=_build_prompt(self.generation_config,self.tokenizer,query,history=history,system=system)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造完成会返回提示词文本和编码tokenID序列，打印结果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;提问:你好 
&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
你好&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;assistant

[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 108386, 151645, 198, 151644, 77091, 198]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来配置vLLM请求：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sampling_params=SamplingParams(stop_token_ids=stop_words_ids, 
                  early_stopping=False,
                  top_p=self.generation_config.top_p,
                  top_k=-1 if self.generation_config.top_k == 0 else self.generation_config.top_k,
                  temperature=self.generation_config.temperature,
                  repetition_penalty=self.generation_config.repetition_penalty,
                  max_tokens=self.generation_config.max_new_tokens)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这些就是模型的配置参数等信息，然后调用vLLM执行推理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 调用vLLM执行推理（批次大小1）
req_outputs=self.model.generate(prompt_token_ids=[prompt_tokens],sampling_params=sampling_params,use_tqdm=False) # use_tqdm禁止进度条
req_output=req_outputs[0]    
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将prompt的tokenID序列和相关配置参数传入进行推理得到结果。打印返回信息查看结构：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# transformer模型的原生返回, 打开注释看一下原始推理结果
print(&quot;req_output:&quot;,req_output)
print(&quot;req_outputs:&quot;,req_outputs)
raw_response=req_output.outputs[0].text
print(&quot;raw_response:&quot;,raw_response)
#req_output: RequestOutput(request_id=0, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 108386, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=&apos;你好！很高兴为你解答问题。有什么我可以帮助你的吗？&apos;, token_ids=array(&apos;l&apos;, [108386, 6313, 112169, 106184, 106185, 86119, 1773, 104139, 109944, 100364, 103929, 101037, 11319, 151645]), cumulative_logprob=None, logprobs=None, finish_reason=stop, stop_reason=151645)], finished=True, metrics=RequestMetrics(arrival_time=1762318798.8988488, last_token_time=1762318798.8988488, first_scheduled_time=1762318798.899235, first_token_time=1762318798.9658892, time_in_queue=0.00038623809814453125, finished_time=1762318799.0134034, scheduler_time=0.001102657988667488, model_forward_time=None, model_execute_time=None), lora_request=None)
#req_outputs: [RequestOutput(request_id=0, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 108386, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=&apos;你好！很高兴为你解答问题。有什么我可以帮助你的吗？&apos;, token_ids=array(&apos;l&apos;, [108386, 6313, 112169, 106184, 106185, 86119, 1773, 104139, 109944, 100364, 103929, 101037, 11319, 151645]), cumulative_logprob=None, logprobs=None, finish_reason=stop, stop_reason=151645)], finished=True, metrics=RequestMetrics(arrival_time=1762318798.8988488, last_token_time=1762318798.8988488, first_scheduled_time=1762318798.899235, first_token_time=1762318798.9658892, time_in_queue=0.00038623809814453125, finished_time=1762318799.0134034, scheduler_time=0.001102657988667488, model_forward_time=None, model_execute_time=None), lora_request=None)]
#raw_response: 你好！很高兴为你解答问题。有什么我可以帮助你的吗？
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再移除返回的tokenID序列的停用词，再进行解码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 移除停用词        
response_token_ids=remove_stop_words(req_output.outputs[0].token_ids,stop_words_ids)
response=self.tokenizer.decode(response_token_ids)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;在模型生成的 token 序列中，一旦遇到预设的“停止词”（如对话结束标记），就立即截断后续所有内容，避免把控制标记或多余内容暴露给用户。&lt;/p&gt;
&lt;p&gt;最后整理历史对话：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 整理历史对话
history.append((query,response))
return response,history
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;到此，使用vLLM模型推理的过程就结束了。&lt;/p&gt;
&lt;p&gt;补充，接下来详细解读_build_prompt方法：
首先按ChatML格式构造千问的Prompt，下面是一个简单的CharML对话示例：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;&amp;#x3C;|im_start|&gt;system
You are a helpful assistant.&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
Hello!&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;assistant
Hi! How can I help you today?&amp;#x3C;|im_end|&gt;
&amp;#x3C;|im_start|&gt;user
What&apos;s the weather?&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来，我们学习如何构建这么一个ChatML的Prompt：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 包裹发言内容的token
im_start,im_start_tokens=&apos;&amp;#x3C;|im_start|&gt;&apos;,[tokenizer.im_start_id]
im_end,im_end_tokens=&apos;&amp;#x3C;|im_end|&gt;&apos;,[tokenizer.im_end_id]
# 换行符token
nl_tokens=tokenizer.encode(&quot;\n&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;定义一个方法，用于编码system/user/assistant的一段发言, 格式{role}\n{content}：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def _tokenize_str(role,content): # 返回元组，下标0是文本，下标1是token ids
        return f&quot;{role}\n{content}&quot;,tokenizer.encode(role)+nl_tokens+tokenizer.encode(content)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;计算剩余token数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;left_token_space=generation_config.max_window_size
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后构造Prompt，这是重头戏，由三部分组成，头部+腰部+尾部，最重要的是头尾部，所以优先构造：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# prompt头部: system发言
system_text_part,system_tokens_part=_tokenize_str(&quot;system&quot;, system) # system_tokens_part --&gt;    system\nYou are a helpful assistant.
system_text=f&apos;{im_start}{system_text_part}{im_end}&apos;
system_tokens=im_start_tokens+system_tokens_part+im_end_tokens # &amp;#x3C;|im_start|&gt;system\nYou are a helpful assistant.&amp;#x3C;|im_end|&gt;
left_token_space-=len(system_tokens)
# system_text_part: system\nYou are a helpful assistant.
# system_text: &amp;#x3C;|im_start|&gt;system\nYou are a helpful assistant.&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造尾部：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# prompt尾部: user发言和assistant引导
query_text_part,query_tokens_part=_tokenize_str(&apos;user&apos;, query)
query_tokens_prefix=nl_tokens+ im_start_tokens
query_tokens_suffix=im_end_tokens+nl_tokens+im_start_tokens+tokenizer.encode(&apos;assistant&apos;)+nl_tokens
if len(query_tokens_prefix)+len(query_tokens_part)+len(query_tokens_suffix)&gt;left_token_space: # query太长截断
    query_token_len=left_token_space-len(query_tokens_prefix)-len(query_tokens_suffix)
    query_tokens_part=query_tokens_part[:query_token_len] #对用户提问部分进行截断
    query_text_part=tokenizer.decode(query_tokens_part)
query_tokens=query_tokens_prefix+query_tokens_part+query_tokens_suffix
query_text=f&quot;\n{im_start}{query_text_part}{im_end}\n{im_start}assistant\n&quot;
left_token_space-=len(query_tokens)
# query_text: &amp;#x3C;|im_start|&gt;user\n你好&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造腰部：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# prompt腰部: 历史user+assitant对话
history_text,history_tokens=&apos;&apos;,[]
for hist_query,hist_response in reversed(history):    # 优先采用最近的对话历史
    hist_query_text,hist_query_tokens_part=_tokenize_str(&quot;user&quot;,hist_query) # user\n历史提问
    hist_response_text,hist_response_tokens_part=_tokenize_str(&quot;assistant&quot;,hist_response) # assistant\n历史回答
    # 生成本轮对话
    cur_history_tokens=nl_tokens+im_start_tokens+hist_query_tokens_part+im_end_tokens+nl_tokens+im_start_tokens+hist_response_tokens_part+im_end_tokens
    cur_history_text=f&quot;\n{im_start}{hist_query_text}{im_end}\n{im_start}{hist_response_text}{im_end}&quot;
    # 储存多轮对话
    if len(cur_history_tokens)&amp;#x3C;=left_token_space:
        history_text=cur_history_text+history_text
        history_tokens=cur_history_tokens+history_tokens
        left_token_space-=len(cur_history_tokens)
    else:
        break 
# cur_history_text: &amp;#x3C;|im_start|&gt;user\n你好&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n你好！很高兴为你解答问题。有什么我可以帮助你的吗？&amp;#x3C;|im_end|&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里有个小细节，遍历历史的时候是反向遍历，优先选择最近的对话历史。&lt;/p&gt;
&lt;p&gt;最后生成完整的Prompt：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 生成完整Prompt
prompt_str=f&apos;{system_text}{history_text}{query_text}&apos;
prompt_tokens=system_tokens+history_tokens+query_tokens
# prompt_str: &amp;#x3C;|im_start|&gt;system\nYou are a helpful assistant.&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;user\n你好&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant\n你好！很高兴为你解答问题。有什么我可以帮助你的吗？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;user\n你是谁？&amp;#x3C;|im_end|&gt;\n&amp;#x3C;|im_start|&gt;assistant
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;停用词清理代码，从后往前清理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 停用词清理
def remove_stop_words(token_ids,stop_words_ids):
    token_ids=copy.deepcopy(token_ids)
    while len(token_ids)&gt;0:
        if token_ids[-1] in stop_words_ids:
            token_ids.pop(-1)
        else:
            break
    return token_ids
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>大模型学习（一）通义千问1.8B大模型微调</title><link>https://astro-pure.js.org/blog/llm_blogs/llm_blogs-1/llm_blogs-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/llm_blogs/llm_blogs-1/llm_blogs-1</guid><description>记录LLM的学习。</description><pubDate>Fri, 27 Feb 2026 15:30:00 GMT</pubDate><content:encoded>&lt;p&gt;感谢B站UP的教程&lt;a href=&quot;https://www.bilibili.com/video/BV16a4y1z7LY?vd_source=52455a50a39ab9ee183496a6de048a09&amp;#x26;spm_id_from=333.788.videopod.sections&quot;&gt;大模型系列&lt;/a&gt;，老师讲的很不错！&lt;/p&gt;
&lt;h1&gt;环境配置&lt;/h1&gt;
&lt;p&gt;当前使用的环境为：Python3.10 torch2.4.2 modelscope1.13.0 transformers4.57.1 vllm0.6.0 CUDA12.1
显卡使用RTX 3090&lt;/p&gt;
&lt;h1&gt;下载模型&lt;/h1&gt;
&lt;p&gt;我是使用国内的魔搭社区下载Qwen1.8B的量化模型，下载速度比较快：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = snapshot_download(&apos;qwen/Qwen-1_8B-Chat-Int4&apos;, cache_dir=&quot;./Models&quot;)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map=&quot;auto&quot;,
    trust_remote_code=True
).eval()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码中，指定下载的模型是“&lt;strong&gt;qwen/Qwen-1_8B-Chat-Int4&lt;/strong&gt;”，下载到指定目录“&lt;strong&gt;./Models&lt;/strong&gt;”中。&lt;/p&gt;
&lt;p&gt;&lt;em&gt;tips：模型量化是一种模型压缩和加速技术，通过将模型中的高精度数值（通常是32位浮点数，FP32）转换为低精度表示（如8位整数，INT8，或更低的位数），来减少模型的存储需求、计算量和推理延迟，同时尽量保持模型的精度。例如本实验使用的就是Int 4量化模型。量化具体介绍可以看&lt;a href=&quot;https://blog.csdn.net/u013172930/article/details/147386286?ops_request_misc=%257B%2522request%255Fid%2522%253A%25225eef35c1bf9fb2940b66fc701d6c8ee7%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&amp;#x26;request_id=5eef35c1bf9fb2940b66fc701d6c8ee7&amp;#x26;biz_id=0&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-147386286-null-null.142%5Ev102%5Epc_search_result_base7&amp;#x26;utm_term=%E4%BB%80%E4%B9%88%E6%98%AF%E6%A8%A1%E5%9E%8B%E9%87%8F%E5%8C%96&amp;#x26;spm=1018.2226.3001.4187&quot;&gt;博客&lt;/a&gt;。&lt;strong&gt;挖个坑，后面补一下量化的实验&lt;/strong&gt;&lt;/em&gt;。&lt;/p&gt;
&lt;p&gt;下载好的文件内容如下：
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
assets/：存放模型运行所需的资源文件，如 CUDA 自定义算子等二进制依赖。
cache_autogpt_cuda_256.cpp 和 cache_autogpt_cuda_kernel_2...：用于 GPU 加速的 C++/CUDA 内核源码，优化 INT4 量化模型在显卡上的推理性能。
configuration_qwen.py：定义 Qwen 模型的 Python 配置类，包含模型结构参数。
config.json：以 JSON 格式存储模型的核心配置信息，如层数、隐藏维度等。
cpp_kernels.py：Python 接口文件，用于加载和调用上述 CUDA 内核。
generation_config.json：指定文本生成时的默认参数，如最大长度、温度、top_p 等。
LICENSE.md：模型的开源许可证文件，说明使用条款。
model.safetensors：模型权重文件，采用安全高效的 safetensors 格式存储（替代传统的 .bin）。
modeling_qwen.py：Qwen 模型的主实现代码，包含前向传播逻辑和网络结构。
NOTICE.md：法律或技术声明文件，通常包含版权、第三方依赖等信息。
qwen_generation_utils.py：提供文本生成相关的工具函数，如采样、停止条件判断等。
qwen.tokenized：分词器使用的词汇表或缓存文件，辅助快速分词。
quantize_config.json：记录模型的量化配置，如量化位宽（INT4）、算法类型（如 GPTQ）等。
README.md：模型的使用说明文档，包含简介、安装、示例等信息。
tokenization_qwen.py：Qwen 分词器的实现代码，负责文本与 token ID 之间的转换。
tokenizer_config.json：分词器的配置文件，定义特殊 token、填充策略等参数。&lt;/p&gt;
&lt;h1&gt;提示词工程&lt;/h1&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 城市数据
with open(&apos;city.txt&apos;,&apos;r&apos;,encoding=&apos;utf-8&apos;) as fp:
    city_list=fp.readlines()
    city_list=[line.strip().split(&apos; &apos;)[1] for line in city_list]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;首先读取city.txt文件，文件内容如下：
&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
strip()是一个字符串方法，会跳过字符串前后的空格、\tab、\n等，例如&quot;  Hello, World! \n&quot;经过处理后成为&quot;Hello, World!&quot;。最终得到的city_list格式如下：
&lt;img src=&quot;3.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
接下来进行提示词生成，首先定义一个提示词模版：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;prompt_template=&apos;&apos;&apos;
给定一句话：“%s”，请你按步骤要求工作。

步骤1：识别这句话中的城市和日期共2个信息
步骤2：根据城市和日期信息，生成JSON字符串，格式为{&quot;city&quot;:城市,&quot;date&quot;:日期}

请问，这个JSON字符串是：
&apos;&apos;&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来我们可以往提示词模板里面的%s填充内容：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Q=&apos;青岛4月6日下雨么?&apos;
prompt_template%(Q,)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;得到的结果如下：
&lt;img src=&quot;4.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
接下来我们需要调用大模型生成微调数据，在此之前需要先了解通义千问数据集的格式：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Qwen的SFT数据格式要求:

[
  {
    &quot;id&quot;: &quot;identity_0&quot;,
    &quot;conversations&quot;: [
      {
        &quot;from&quot;: &quot;user&quot;,
        &quot;value&quot;: &quot;你好&quot;
      },
      {
        &quot;from&quot;: &quot;assistant&quot;,
        &quot;value&quot;: &quot;我是一个语言模型，我叫通义千问。&quot;
      }
    ]
  }
]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;em&gt;tips:
必须交替出现：对话必须严格按 user → assistant → user → assistant → ... 的顺序。
必须以 user 开头：每段对话应由用户发起。
不能连续两个 user 或两个 assistant：这会被视为格式错误。
每轮对话是一个独立对象：包含 &quot;from&quot; 和 &quot;value&quot; 两个字段。&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;接下来生成微调数据集：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;import random
import json
import time 

train_data=[]

Q_list=[
    (&apos;{city}{year}年{month}月{day}日的天气&apos;,&apos;%Y-%m-%d&apos;),
    (&apos;{city}{year}年{month}月{day}号的天气&apos;,&apos;%Y-%m-%d&apos;),
    (&apos;{city}{month}月{day}日的天气&apos;,&apos;%m-%d&apos;),
    (&apos;{city}{month}月{day}号的天气&apos;,&apos;%m-%d&apos;),

    (&apos;{year}年{month}月{day}日{city}的天气&apos;,&apos;%Y-%m-%d&apos;),
    (&apos;{year}年{month}月{day}号{city}的天气&apos;,&apos;%Y-%m-%d&apos;),
    (&apos;{month}月{day}日{city}的天气&apos;,&apos;%m-%d&apos;),
    (&apos;{month}月{day}号{city}的天气&apos;,&apos;%m-%d&apos;),

    (&apos;你们{year}年{month}月{day}日去{city}玩吗？&apos;,&apos;%Y-%m-%d&apos;),
    (&apos;你们{year}年{month}月{day}号去{city}玩么？&apos;,&apos;%Y-%m-%d&apos;),
    (&apos;你们{month}月{day}日去{city}玩吗？&apos;,&apos;%m-%d&apos;),
    (&apos;你们{month}月{day}号去{city}玩吗？&apos;,&apos;%m-%d&apos;),
]

# 生成一批&quot;1月2号&quot;、&quot;1月2日&quot;、&quot;2023年1月2号&quot;, &quot;2023年1月2日&quot;, &quot;2023-02-02&quot;, &quot;03-02&quot;之类的话术, 教会它做日期转换
for i in range(1000):
    Q=Q_list[random.randint(0,len(Q_list)-1)]
    city=city_list[random.randint(0,len(city_list)-1)]
    year=random.randint(1990,2025)
    month=random.randint(1,12)
    day=random.randint(1,28)
    time_str=&apos;{}-{}-{}&apos;.format(year,month,day)
    date_field=time.strftime(Q[1],time.strptime(time_str,&apos;%Y-%m-%d&apos;))
    Q=Q[0].format(city=city,year=year,month=month,day=day) # 问题
    A=json.dumps({&apos;city&apos;:city,&apos;date&apos;:date_field},ensure_ascii=False)  # 回答

    example={
        &apos;id&apos;: &apos;identity_{}&apos;.format(i),
        &apos;conversations&apos;:[
            {
                &apos;from&apos;: &apos;user&apos;,
                &apos;value&apos;: prompt_template%(Q,),
            },
            {
                &apos;from&apos;: &apos;assistant&apos;,
                &apos;value&apos;: A,
            }
        ]
    }
    train_data.append(example)

with open(&apos;train.txt&apos;,&apos;w&apos;,encoding=&apos;utf-8&apos;) as fp:
    fp.write(json.dumps(train_data))
print(&quot;样本数量：&quot;,len(train_data))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;解析json格式的数据集格式如下：
&lt;img src=&quot;5.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h1&gt;微调模型&lt;/h1&gt;
&lt;p&gt;接下来微调模型，生成到output_qwen。Qwen1.8B提供了多种微调方法：
&lt;img src=&quot;6.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;ds_config_zero3.json&lt;/code&gt;&lt;/strong&gt;：DeepSpeed ZeRO-3 分布式训练的配置文件，用于多卡高效训练。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;finetune_ds.sh&lt;/code&gt;&lt;/strong&gt;：使用 DeepSpeed 启动全参数分布式微调的训练脚本。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;finetune_lora_single_gpu.sh&lt;/code&gt;&lt;/strong&gt;：在单张 GPU 上使用 LoRA 进行参数高效微调的脚本。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;finetune_qlora_single_gpu.sh&lt;/code&gt;&lt;/strong&gt;：在单张 GPU 上结合 4-bit 量化与 LoRA（即 QLoRA）进行超低显存微调的脚本。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;finetune_lora_multi_gpu.sh&lt;/code&gt;&lt;/strong&gt;：在多张 GPU 上使用 LoRA 进行并行微调的脚本。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;&lt;code&gt;finetune_qlora_multi_gpu.sh&lt;/code&gt;&lt;/strong&gt;：在多张 GPU 上使用 QLoRA（量化 + LoRA）进行并行微调的脚本。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这里我们使用量化的微调脚本&lt;code&gt;finetune_qlora_single_gpu.sh&lt;/code&gt;，代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;python finetune.py \
  --model_name_or_path $MODEL \
  --data_path $DATA \
  --fp16 True \
  --output_dir output_qwen \
  --num_train_epochs 10 \
  --per_device_train_batch_size 5 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --evaluation_strategy &quot;no&quot; \
  --save_strategy &quot;steps&quot; \
  --save_steps 1000 \
  --save_total_limit 10 \
  --learning_rate 3e-4 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type &quot;cosine&quot; \
  --logging_steps 1 \
  --report_to &quot;none&quot; \
  --model_max_length 512 \
  --lazy_preprocess True \
  --gradient_checkpointing \
  --use_lora \
  --q_lora \
  # --deepspeed finetune/ds_config_zero2.json  不使用分布式训练，所以注释
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;以下是每个参数的解释：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--model_name_or_path $MODEL&lt;/code&gt;&lt;/strong&gt;：指定预训练模型的路径或 Hugging Face 模型 ID，用于加载基础模型权重。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--data_path $DATA&lt;/code&gt;&lt;/strong&gt;：指定训练数据文件或目录的路径，通常为 JSON 或 JSONL 格式的指令微调数据集。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--fp16 True&lt;/code&gt;&lt;/strong&gt;：启用半精度浮点数（FP16）训练，减少显存占用并加速计算（需 GPU 支持）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--output_dir output_qwen&lt;/code&gt;&lt;/strong&gt;：设置模型训练过程中检查点和最终结果的保存目录。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--num_train_epochs 10&lt;/code&gt;&lt;/strong&gt;：指定整个训练数据集将被遍历训练 10 轮。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--per_device_train_batch_size 5&lt;/code&gt;&lt;/strong&gt;：每个 GPU 设备在训练时每次处理 5 个样本。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--per_device_eval_batch_size 1&lt;/code&gt;&lt;/strong&gt;：每个 GPU 设备在评估时每次处理 1 个样本（通常因显存限制设得较小）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--gradient_accumulation_steps 8&lt;/code&gt;&lt;/strong&gt;：每 8 个 mini-batch 累积一次梯度再更新，等效于增大 batch size。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--evaluation_strategy &quot;no&quot;&lt;/code&gt;&lt;/strong&gt;：训练过程中不进行验证评估。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--save_strategy &quot;steps&quot;&lt;/code&gt;&lt;/strong&gt;：按训练步数（而非 epoch）保存模型检查点。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--save_steps 1000&lt;/code&gt;&lt;/strong&gt;：每训练 1000 步保存一次模型。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--save_total_limit 10&lt;/code&gt;&lt;/strong&gt;：最多保留最近的 10 个检查点，避免磁盘爆满。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--learning_rate 3e-4&lt;/code&gt;&lt;/strong&gt;：优化器的学习率设为 0.0003，控制参数更新幅度。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--weight_decay 0.1&lt;/code&gt;&lt;/strong&gt;：L2 正则化系数为 0.1，用于防止过拟合。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--adam_beta2 0.95&lt;/code&gt;&lt;/strong&gt;：Adam 优化器的 β2 参数，控制梯度平方的指数衰减率。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--warmup_ratio 0.01&lt;/code&gt;&lt;/strong&gt;：学习率预热比例，前 1% 的训练步数中线性增加学习率。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--lr_scheduler_type &quot;cosine&quot;&lt;/code&gt;&lt;/strong&gt;：使用余弦退火学习率调度策略，平滑降低学习率。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--logging_steps 1&lt;/code&gt;&lt;/strong&gt;：每 1 步就在日志中记录训练指标（如 loss）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--report_to &quot;none&quot;&lt;/code&gt;&lt;/strong&gt;：不将训练日志上报到任何外部平台（如 TensorBoard、W&amp;#x26;B）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--model_max_length 512&lt;/code&gt;&lt;/strong&gt;：设定模型输入序列的最大长度为 512 个 token，超长部分会被截断。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--lazy_preprocess True&lt;/code&gt;&lt;/strong&gt;：启用懒加载预处理，仅在需要时对数据进行 tokenize，节省内存。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--gradient_checkpointing&lt;/code&gt;&lt;/strong&gt;：启用梯度检查点技术，用时间换空间，显著降低显存占用。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--use_lora&lt;/code&gt;&lt;/strong&gt;：启用 LoRA（低秩适配）微调，只训练少量新增参数，冻结原始模型权重。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;--q_lora&lt;/code&gt;&lt;/strong&gt;：启用 QLoRA，在 4-bit 量化模型基础上应用 LoRA，实现极低显存微调。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;# --deepspeed ...&lt;/code&gt;&lt;/strong&gt;：注释掉的 DeepSpeed 配置，表示当前不使用分布式训练框架。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;接下来指定模型和数据集进行微调：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;bash finetune/finetune_qlora_single_gpu.sh -m /root/shared-nvme/LLM-Learning/Qwen-SFT/Models/qwen/Qwen-1_8B-Chat-Int4 -d ../train.txt
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;该命令使用 QLoRA（量化低秩适应）方法对 Qwen-1.8B-Chat-Int4 模型进行微调，训练数据来自指定的文本文件，微调结果将保存到 output_qwen 目录。
训练结束过程如下：
&lt;img src=&quot;7.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
保存后的模型结构如下：
&lt;img src=&quot;8.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
微调结束后保存的模型目录中各个文件的详细解释：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;adapter_config.json&lt;/code&gt;&lt;/strong&gt;：记录 LoRA 微调的配置信息，包括 LoRA 的秩（rank）、alpha、dropout 等参数，用于在推理时正确加载和应用适配器权重。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;qwen_model.safetensors&lt;/code&gt;&lt;/strong&gt;：存储原始 Qwen 模型的主权重文件，采用 safetensors 格式，安全高效，包含冻结的原始模型参数，通常不被修改。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;qwen.tokenized&lt;/code&gt;&lt;/strong&gt;：分词器使用的词汇表或 token 映射文件，用于将文本转换为 token ID，是分词过程的基础数据。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;README.md&lt;/code&gt;&lt;/strong&gt;：模型说明文档，包含模型简介、微调方法、使用说明、依赖项等信息，方便他人理解和复现。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;special_tokens_map.json&lt;/code&gt;&lt;/strong&gt;：定义特殊 token（如 &lt;code&gt;[PAD]&lt;/code&gt;, &lt;code&gt;[CLS]&lt;/code&gt;, &lt;code&gt;[SEP]&lt;/code&gt; 等）的映射关系，确保分词器能正确处理这些标记。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;tokenization_qwen.py&lt;/code&gt;&lt;/strong&gt;：Qwen 分词器的 Python 实现代码，负责文本与 token ID 之间的转换逻辑。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;tokenizer_config.json&lt;/code&gt;&lt;/strong&gt;：分词器的配置文件，包含分词器类型、最大长度、特殊 token 设置等参数。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;trainer_state.json&lt;/code&gt;&lt;/strong&gt;：训练器状态文件，记录训练过程中的步数、学习率、epoch 进度等元信息，可用于恢复训练。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;&lt;code&gt;training_args.bin&lt;/code&gt;&lt;/strong&gt;： 训练参数的二进制文件，保存了训练时的所有超参数（如 batch size、learning rate 等），由 Hugging Face Trainer 自动生成。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h1&gt;加载SFT后的模型&lt;/h1&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    &apos;./Qwen/output_qwen&apos;, # path to the output directory
    device_map=&quot;auto&quot;,
    trust_remote_code=True
).eval()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;em&gt;tips：这里加载模型使用 AutoPeftModelForCausalLM（而不是普通的 AutoModelForCausalLM）来加载模型，是因为你微调时使用了 PEFT（Parameter-Efficient Fine-Tuning）技术，比如 LoRA 或 QLoRA。&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;经过10个epoch后的训练，效果非常可以了：
&lt;img src=&quot;9.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
同样的，模型也不会只回答这种格式的内容：
&lt;img src=&quot;10.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/llm_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>Agent实战（五）MCP的demo</title><link>https://astro-pure.js.org/blog/agent_blogs/agent_blogs-5/agent_blogs-5</link><guid isPermaLink="true">https://astro-pure.js.org/blog/agent_blogs/agent_blogs-5/agent_blogs-5</guid><description>Agent实战</description><pubDate>Wed, 11 Feb 2026 20:40:05 GMT</pubDate><content:encoded>&lt;p&gt;如何让大型语言模型（LLM）具备与外部工具和资源交互的能力变得至关重要。Model Context Protocol（MCP）作为一种新兴的标准化协议，为我们提供了构建这类智能应用的新途径。本文将详细介绍如何使用FastMCP框架构建一个简单的MCP服务，并与Cherry Studio进行集成测试。&lt;/p&gt;
&lt;h1&gt;什么是MCP&lt;/h1&gt;
&lt;p&gt;Model Context Protocol（MCP）是一种开放协议，旨在为AI模型与其运行环境之间建立标准化的通信机制。通过MCP，AI模型可以发现和使用各种工具、资源和提示词，从而扩展其能力边界。MCP的主要特点包括：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;标准化接口&lt;/strong&gt;：提供统一的API规范，便于不同组件间的互操作&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;可扩展性&lt;/strong&gt;：支持自定义工具、资源和提示词的添加&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多传输协议&lt;/strong&gt;：支持多种通信方式，包括HTTP、SSE、stdio等&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;安全性&lt;/strong&gt;：通过明确定义的接口控制模型对系统资源的访问&lt;/li&gt;
&lt;/ul&gt;
&lt;h1&gt;FastMCP框架简介&lt;/h1&gt;
&lt;p&gt;FastMCP是一个基于Python的MCP实现框架，它大大简化了MCP服务的开发过程。通过装饰器模式，开发者可以快速地将普通函数转换为MCP工具、资源或提示词，而无需深入了解底层协议细节。&lt;/p&gt;
&lt;p&gt;FastMCP的核心优势包括：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;简洁的API设计&lt;/li&gt;
&lt;li&gt;多种传输协议支持&lt;/li&gt;
&lt;li&gt;自动化的接口文档生成&lt;/li&gt;
&lt;li&gt;易于集成到现有项目中&lt;/li&gt;
&lt;/ul&gt;
&lt;h1&gt;环境准备&lt;/h1&gt;
&lt;p&gt;在开始之前，我们需要准备以下环境：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Python 3.8+&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;安装FastMCP及相关依赖：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;pip install mcp[cli]
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;安装Cherry Studio（或其他支持MCP的客户端）&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;h1&gt;构建MCP服务端&lt;/h1&gt;
&lt;h2&gt;创建基础服务器&lt;/h2&gt;
&lt;p&gt;我们首先创建一个基本的MCP服务器：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from mcp.server.fastmcp import FastMCP

# 创建一个MCP服务器
mcp = FastMCP(&quot;Demo&quot;, json_response=True)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里我们创建了一个名为&quot;Demo&quot;的MCP服务器实例，并启用了JSON响应格式。&lt;/p&gt;
&lt;h2&gt;添加工具功能&lt;/h2&gt;
&lt;p&gt;工具是MCP中最常用的功能之一，它允许模型执行特定的操作。我们添加一个简单的加法工具：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@mcp.tool()
def add(a: int, b: int) -&gt; int:
    &quot;&quot;&quot;将两个数字相加&quot;&quot;&quot;
    return a + b 
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;通过&lt;code&gt;@mcp.tool()&lt;/code&gt;装饰器，我们将普通的Python函数转换为MCP工具，模型可以在需要时调用此工具执行加法运算。&lt;/p&gt;
&lt;h2&gt;添加资源功能&lt;/h2&gt;
&lt;p&gt;资源是可供模型读取的数据源。我们创建一个动态问候资源：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@mcp.resource(&quot;greeting://{name}&quot;)
def get_greeting(name: str) -&gt; str:
    &quot;&quot;&quot;获取个性化问候语&quot;&quot;&quot;
    return f&quot;Hello, {name}!我听说你在学MCP！真是泰裤辣！&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个资源使用URI模板格式定义，其中&lt;code&gt;{name}&lt;/code&gt;是可变参数，允许模型根据不同的名字获取个性化的问候语。&lt;/p&gt;
&lt;h2&gt;添加提示词功能&lt;/h2&gt;
&lt;p&gt;提示词功能可以帮助模型生成更加规范和一致的文本内容：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@mcp.prompt()
def greet_user(name: str, style: str = &quot;friendly&quot;) -&gt; str:
    &quot;&quot;&quot;生成问候语提示词&quot;&quot;&quot;
    styles = {
        &quot;friendly&quot;: &quot;请写一个温暖友好的问候语&quot;,
        &quot;formal&quot;: &quot;请写一个正式专业的问候语&quot;,
        &quot;casual&quot;: &quot;请写一个随意轻松的问候语&quot;,
    }
    
    return f&quot;{styles.get(style, styles[&apos;friendly&apos;])} 给一个叫{name}的人。&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;通过这个提示词功能，模型可以根据不同的风格要求生成相应的问候语。&lt;/p&gt;
&lt;h1&gt;运行MCP服务&lt;/h1&gt;
&lt;p&gt;完成上述功能定义后，我们需要启动MCP服务：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;if __name__ == &quot;__main__&quot;:
    mcp.run(transport=&quot;sse&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里我们选择使用SSE（Server-Sent Events）传输协议运行服务。FastMCP还支持其他传输方式，如&lt;code&gt;streamable-http&lt;/code&gt;和&lt;code&gt;stdio&lt;/code&gt;。&lt;/p&gt;
&lt;h1&gt;与Cherry Studio集成&lt;/h1&gt;
&lt;h2&gt;连接MCP服务&lt;/h2&gt;
&lt;p&gt;在Cherry Studio中，我们需要配置MCP连接参数以连接到我们刚刚创建的服务。
&lt;img src=&quot;1.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;调用工具&lt;/h2&gt;
&lt;p&gt;连接成功后，在Cherry Studio中我们可以看到已注册的[add]工具。通过自然语言指令，如&quot;计算10和20的和&quot;，Cherry Studio会自动调用我们的加法工具并返回结果。
&lt;img src=&quot;2.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;访问资源&lt;/h2&gt;
&lt;p&gt;对于资源访问，在SSE传输模式下，我们需要通过MCP协议的标准消息格式来请求资源。在Cherry Studio中，可以通过类似&quot;获取给Alice的问候语&quot;这样的指令来触发对&lt;code&gt;greeting://Alice&lt;/code&gt;资源的访问。&lt;/p&gt;
&lt;h2&gt;使用提示词&lt;/h2&gt;
&lt;p&gt;提示词功能可以帮助模型生成更加规范的内容。当我们请求&quot;用正式的风格给张三写一个问候语&quot;时，Cherry Studio会使用我们的[greet_user]提示词功能来指导模型生成相应的内容。&lt;/p&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;通过本文的介绍，我们学习了如何使用FastMCP框架构建一个简单的MCP服务，并与Cherry Studio进行集成。MCP为我们提供了一种标准化的方式来扩展AI模型的能力，使其能够与外部工具和资源进行交互。&lt;/p&gt;
&lt;p&gt;随着MCP生态的发展，我们可以期待更多创新的应用场景出现。无论是构建智能助手、自动化工作流还是复杂的AI代理系统，MCP都将成为重要的技术基石。&lt;/p&gt;
&lt;p&gt;在未来的学习和实践中，建议深入研究MCP协议的更多高级特性，如权限控制、状态管理、异步操作等，以充分发挥其潜力。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>Agent实战（四）深入理解Agent从输入到输出的信息流</title><link>https://astro-pure.js.org/blog/agent_blogs/agent_blogs-4</link><guid isPermaLink="true">https://astro-pure.js.org/blog/agent_blogs/agent_blogs-4</guid><description>Agent实战</description><pubDate>Wed, 11 Feb 2026 20:40:04 GMT</pubDate><content:encoded>&lt;h2&gt;Agent信息流概述&lt;/h2&gt;
&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/Agent_Project&quot;&gt;Github&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;我们需要了解，当我们给Agent输入一个query时，它会经过哪些处理，最终得到输出呢？通过了解这整个过程，可以使我们对Agent有一个更深入的理解。&lt;/p&gt;
&lt;p&gt;下面以my_agent1.py为例详细解释，测试代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@tool
def send_email(to: str, subject: str, body: str):
    &quot;&quot;&quot;发送邮件 - 该工具可以发送电子邮件给指定收件人
    
    Args:
        to: 收件人邮箱地址或姓名
        subject: 邮件主题
        body: 邮件正文内容
    &quot;&quot;&quot;
    email = {
        &quot;to&quot;: to,
        &quot;subject&quot;: subject,
        &quot;body&quot;: body
    }
    # ...邮件发送逻辑
    print(f&quot;📧 工具执行: send_email(to=&apos;{to}&apos;, subject=&apos;{subject}&apos;, body=&apos;{body}&apos;)&quot;)

    return f&quot;邮件已发送至 {to}&quot;


# 创建 React Agent
agent_executor = create_agent(
    model=llm,
    tools=[send_email],
    system_prompt=&quot;你是一个邮件助手。&quot;
)

# 添加测试输入
if __name__ == &quot;__main__&quot;:
    import asyncio
    async def main():
        # 测试输入
        inputs = {
            &quot;messages&quot;: [
                {&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;请帮我给张三发一封邮件，告诉他会议时间改到明天下午3点了，主题是项目进度同步。&quot;}
            ]
        }
        print(&quot;用户请求:&quot;, inputs[&quot;messages&quot;][0][&quot;content&quot;])
        print(&quot;\nAgent执行过程:&quot;)
        # 异步流式执行
        async for chunk in agent_executor.astream(inputs, stream_mode=&quot;updates&quot;):
            print(chunk)
        print(&quot;\n执行完成&quot;)
    # 运行异步主函数
    asyncio.run(main())
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;Graph初始化过程&lt;/h2&gt;
&lt;p&gt;首先经过下面的代码会创建一个包含模型节点和工具节点的graph图：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;agent_executor = create_agent(
    model=llm,
    tools=[send_email],
    system_prompt=&quot;你是一个邮件助手。&quot;
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;在这个过程中会添加节点：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;1. 添加模型节点：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;graph.add_node(&quot;model&quot;, RunnableCallable(model_node, amodel_node, trace=False))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;2. 添加工具节点：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;graph.add_node(&quot;tools&quot;, tool_node)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;3. 确定入口节点&lt;/strong&gt;，代码如下，这里就是中间件发挥作用的地方，可以在query输入到模型之前进行一些预处理：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 确定入口节点（在开始时运行一次）：before_agent -&gt; before_model -&gt; model
if middleware_w_before_agent:
    entry_node = f&quot;{middleware_w_before_agent[0].name}.before_agent&quot;
elif middleware_w_before_model:
    entry_node = f&quot;{middleware_w_before_model[0].name}.before_model&quot;
else:
    entry_node = &quot;model&quot;
print(f&quot;🏁 确定入口节点: {entry_node}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🏁 确定入口节点: model
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;4. 确定循环节点的入口：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if middleware_w_before_model:
    loop_entry_node = f&quot;{middleware_w_before_model[0].name}.before_model&quot;
else:
    loop_entry_node = &quot;model&quot;
print(f&quot;🔄 确定循环入口节点: {loop_entry_node}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🔄 确定循环入口节点: model
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;5. 确定循环节点的出口：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 确定循环出口节点（每次迭代结束，可以运行多次）
# 这是after_model或model，但不是after_agent
if middleware_w_after_model:
    loop_exit_node = f&quot;{middleware_w_after_model[0].name}.after_model&quot;
else:
    loop_exit_node = &quot;model&quot;
print(f&quot;🚪 确定循环出口节点: {loop_exit_node}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🚪 确定循环出口节点: model
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;6. 确定出口节点：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 确定出口节点（最后运行一次）：after_agent或END
if middleware_w_after_agent:
    exit_node = f&quot;{middleware_w_after_agent[-1].name}.after_agent&quot;
else:
    exit_node = END
print(f&quot;🔚 确定最终出口节点: {exit_node}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🔚 确定最终出口节点: __end__
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;7. 添加起始边：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 添加起始边
print(f&quot;🔗 添加起始边: START -&gt; {entry_node}&quot;)
graph.add_edge(START, entry_node)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🔗 添加起始边: START -&gt; model
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;规定了图的起始边是&lt;code&gt;model&lt;/code&gt;，所以&lt;code&gt;query&lt;/code&gt;首先会传递给&lt;code&gt;amodel_node&lt;/code&gt;函数。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;起始边 vs 入口节点 vs 循环入口节点&lt;/strong&gt;
| 概念 | 作用域 | 执行次数 | 主要用途 |
|------|--------|----------|----------|
| &lt;strong&gt;起始边&lt;/strong&gt; | 整个图的物理起点 | 仅1次 | 定义图的执行起点 |
| &lt;strong&gt;入口节点&lt;/strong&gt; | 工作流的逻辑起点 | 仅1次 | 决定从哪个业务节点开始 |
| &lt;strong&gt;循环入口节点&lt;/strong&gt; | 循环迭代的起点 | 可能多次 | Agent思考-行动循环的起点 |&lt;/p&gt;
&lt;p&gt;在我的例子中，流程如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 只有一个工作流，包含循环逻辑
START → 入口节点 → 循环入口节点 → 模型节点
                    ↑          ↓
                    └── 工具节点 ←─┘
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;8. 添加工具相关条件边：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;graph.add_conditional_edges(
    &quot;tools&quot;,  # 起始节点：名为 &quot;tools&quot; 的节点（负责调用工具）
    RunnableCallable(
        _make_tools_to_model_edge(...),  # 条件路由函数：决定下一步去哪
        trace=False,
    ),
    tools_to_model_destinations,  # 允许的目标节点列表（合法跳转范围）
)
print(f&quot;🔗 工具节点条件边目标: {tools_to_model_destinations}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;代码解释：&lt;/strong&gt; 这段代码定义了当&quot;工具节点&quot;（tools）执行完毕后，下一步应该跳转到哪个节点——是回到大模型继续推理，还是直接结束流程。&lt;/p&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🔗 工具节点条件边目标: [&apos;model&apos;]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果解释：&lt;/strong&gt; 输出显示工具节点执行后只能跳转到model节点，这意味着工具执行完成后必须返回模型进行下一步决策。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;流程图表示：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;[model] 
   │
   ├─(需调用工具)──→ [tools] ──→ [model] ──→ ...
   │
   └─(无需工具/完成)──→ [__end__]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;9. 添加模型节点到工具节点的条件边：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 添加从模型节点到工具节点的条件边（核心代理循环）
graph.add_conditional_edges(
    loop_exit_node,
    RunnableCallable(
        _make_model_to_tools_edge(
            model_destination=loop_entry_node,
            structured_output_tools=structured_output_tools,
            end_destination=exit_node,
        ),
        trace=False,
    ),
    model_to_tools_destinations,
)
print(f&quot;🔗 模型节点条件边目标: {model_to_tools_destinations}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🔗 模型节点条件边目标: [&apos;tools&apos;, &apos;__end__&apos;]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果解释：&lt;/strong&gt; 模型节点执行完毕后可以选择调用工具（tools）或者直接结束（&lt;strong&gt;end&lt;/strong&gt;），这是Agent决策的核心逻辑。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;10. 添加其他边：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 添加before_agent中间件边
if middleware_w_before_agent:
    print(&quot;🔗 添加before_agent中间件边&quot;)

# 添加before_model中间件边
if middleware_w_before_model:
    print(&quot;🔗 添加before_model中间件边&quot;)

# 添加after_model中间件边
if middleware_w_after_model:
    print(&quot;🔗 添加after_model中间件边&quot;)

# 添加after_agent中间件边
if middleware_w_after_agent:
    print(&quot;🔗 添加after_agent中间件边&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;11. 最后编译，返回图：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 编译并返回图
print(&quot;✅ 图构建完成，准备编译&quot;)
return graph.compile(
    checkpointer=checkpointer,
    store=store,
    interrupt_before=interrupt_before,
    interrupt_after=interrupt_after,
    debug=debug,
    name=name,
    cache=cache,
).with_config({&quot;recursion_limit&quot;: 10_000})
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;Agent运行过程&lt;/h2&gt;
&lt;h3&gt;第一阶段：初始模型调用&lt;/h3&gt;
&lt;p&gt;当我们将&lt;code&gt;input&lt;/code&gt;传入给&lt;code&gt;graph&lt;/code&gt;时（这里我们以异步为例），会传给&lt;code&gt;amodel_node&lt;/code&gt;方法，在这里，你的&lt;code&gt;query&lt;/code&gt;和&lt;code&gt;system prompt&lt;/code&gt;被包装成一个&lt;code&gt;ModelRequest&lt;/code&gt;对象。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;ModelRequest对象结构（简化）：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;ModelRequest(
    model=ChatOpenAI(
        model_name=&apos;Qwen3-32B&apos;,
        openai_api_base=&apos;https://llmapi.paratera.com/v1/&apos;,
        openai_api_key=&apos;**********&apos;,  # 已脱敏
        default_headers={&apos;Accept&apos;: &apos;application/json&apos;},
        extra_body={&apos;enable_thinking&apos;: False}
    ),
    messages=[
        HumanMessage(
            content=&apos;请帮我给张三发一封邮件，告诉他会议时间改到明天下午3点了，主题是项目进度同步。&apos;
        )
    ],
    system_message=SystemMessage(
        content=&apos;你是一个邮件助手。&apos;
    ),
    tools=[
        StructuredTool(
            name=&apos;send_email&apos;,
            description=&apos;发送邮件 - 该工具可以发送电子邮件给指定收件人\n\nArgs:\n    to: 收件人邮箱地址或姓名\n    subject: 邮件主题\n    body: 邮件正文内容&apos;,
            args_schema=send_email (Pydantic model)
        )
    ],
    tool_choice=None,
    response_format=None,
    state={
        &apos;messages&apos;: [ ... ]  # 与 messages 相同，略
    }
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;代码解释：&lt;/strong&gt; ModelRequest包含了用户query、系统提示词和结构化工具描述等信息，为模型调用做好准备。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;模型绑定过程输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🔧 开始模型绑定过程...
📋 准备绑定工具到模型，工具数量: 1
   1. 工具名: send_email
      描述: 发送邮件 - 该工具可以发送电子邮件给指定收件人

    Args:
        to: 收件人邮箱地址或姓名
        subject: 邮件主题
        body: 邮件正文内容
      参数: [&apos;to&apos;, &apos;subject&apos;, &apos;body&apos;]
🛠️  工具定义详情（将通过API参数传递给模型）:
   工具 1: {&apos;type&apos;: &apos;function&apos;, &apos;function&apos;: {&apos;name&apos;: &apos;send_email&apos;, &apos;description&apos;: &apos;发送邮件 - 该工具可以发送电子邮件给指定收件人\n\n    Args:\n        to: 收件人邮箱地址或姓名\n        subject: 邮件主题\n        body: 邮件正文内容&apos;, &apos;parameters&apos;: {&apos;description&apos;: &apos;发送邮件 - 该工具可 以发送电子邮件给指定收件人\n\nArgs:\n    to: 收件人邮箱地址或姓名\n    subject: 邮件主题\n    body: 邮件正文内容&apos;, &apos;properties&apos;: {&apos;to&apos;: {&apos;title&apos;: &apos;To&apos;, &apos;type&apos;: &apos;string&apos;}, &apos;subject&apos;: {&apos;title&apos;: &apos;Subject&apos;, &apos;type&apos;: &apos;string&apos;}, &apos;body&apos;: {&apos;title&apos;: &apos;Body&apos;, &apos;type&apos;: &apos;string&apos;}}, &apos;required&apos;: [&apos;to&apos;, &apos;subject&apos;, &apos;body&apos;], &apos;title&apos;: &apos;send_email&apos;, &apos;type&apos;: &apos;object&apos;}}}
📦 使用标准方式绑定模型和工具, tool_choice: None
   注意: 工具信息不会注入到prompt中，而是通过API参数传递
   模型绑定参数:
     tool_choice: None
     model_settings: {}
✅ 模型和工具绑定完成 (标准方式)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果解释：&lt;/strong&gt; 这里展示了工具如何被绑定到模型，工具的描述信息会通过API参数传递给大模型，而不是直接注入到prompt中。这里需要注意，为什么在下面的模型调用输出的&lt;code&gt;content&lt;/code&gt;为空，而将工具信息都存储在&lt;code&gt;additional_kwargs&lt;/code&gt;中，这是因为通过使用 &lt;code&gt;bind_tools&lt;/code&gt; 方法时：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;bound_model = request.model.bind_tools(
                final_tools, tool_choice=tool_choice, **request.model_settings
            )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个过程告诉模型：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;有哪些工具可以使用（&lt;code&gt;final_tools&lt;/code&gt;就是工具列表）&lt;/li&gt;
&lt;li&gt;每个工具的名称、描述和参数格式&lt;/li&gt;
&lt;li&gt;工具调用的规范格式&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;现代大语言模型（如Qwen3-32B）被训练成能够生成特定结构的输出。当您绑定工具后，模型知道需要按照特定格式返回工具调用，结构如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;{
  &quot;tool_calls&quot;: [
    {
      &quot;id&quot;: &quot;call_xxx&quot;,
      &quot;function&quot;: {
        &quot;name&quot;: &quot;工具名称&quot;,
        &quot;arguments&quot;: &quot;{\&quot;参数1\&quot;: \&quot;值1\&quot;, \&quot;参数2\&quot;: \&quot;值2\&quot;}&quot;
      },
      &quot;type&quot;: &quot;function&quot;
    }
  ]
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;模型调用输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;💬 准备发送2条消息给模型
📄 发送给模型的消息内容:
   [1] 你是一个邮件助手。
   [2] 请帮我给张三发一封邮件，告诉他会议时间改到明天下午3点了，主题是项目进度同步。
📥 模型返回结果，类型: AIMessage
📋 模型返回详细内容:
   content: &apos;&apos;
   additional_kwargs: {&apos;tool_calls&apos;: [{&apos;id&apos;: &apos;call_79b217f7070943b3bd01bf&apos;, &apos;function&apos;: {&apos;arguments&apos;: &apos;{&quot;to&quot;: &quot;zhangsan@example.com&quot;, &quot;subject&quot;: &quot;项目进度同步&quot;, &quot;body&quot;: &quot;张三，你好！会议时间已经调整到明天下午3点，请准时参加。谢谢！&quot;}&apos;, &apos;name&apos;: &apos;send_email&apos;}, &apos;type&apos;: &apos;function&apos;, &apos;index&apos;: 0}], &apos;refusal&apos;: None}
   检测到 1 个工具调用:
     [1] 工具名: send_email
         ID: call_79b217f7070943b3bd01bf
         参数: {&quot;to&quot;: &quot;zhangsan@example.com&quot;, &quot;subject&quot;: &quot;项目进度同步&quot;, &quot;body&quot;: &quot;张三，你好！会议时间已经调整到明天下午3点，请准时参加。谢谢 ！&quot;}
   response_metadata: {&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 56, &apos;prompt_tokens&apos;: 248, &apos;total_tokens&apos;: 304, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-5efb3cb7-7127-499e-a1b1-84692858dbca&apos;, &apos;service_tier&apos;: None, &apos;finish_reason&apos;: &apos;tool_calls&apos;, &apos;logprobs&apos;: None}
🛠️  模型返回1个工具调用
   工具调用 1: send_email({&apos;to&apos;: &apos;zhangsan@example.com&apos;, &apos;subject&apos;: &apos;项目进度同步&apos;, &apos;body&apos;: &apos;张三，你好！会议时间已经调整到明天下午3点，请准时参加。谢谢！&apos;})
🔄 开始处理模型输出...
   输出类型: AIMessage
   effective_response_format: None
📦 返回标准消息格式
📦 模型输出处理完成，消息数: 1
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果解释：&lt;/strong&gt; 模型识别到需要调用工具，返回了一个包含工具调用的AIMessage，其中content为空，工具调用信息存储在&lt;code&gt;additional_kwargs&lt;/code&gt;中。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;amodel_node返回：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;📤 amodel_node返回更新: [&apos;messages&apos;]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;重要说明：&lt;/strong&gt; 此时模型并没有输出任何文本内容，而是通过&lt;code&gt;additional_kwargs&lt;/code&gt;附带了一个工具调用请求。&lt;/p&gt;
&lt;h3&gt;第二阶段：工具调用决策&lt;/h3&gt;
&lt;p&gt;然后进入条件边决策函数&lt;code&gt;model_to_tools&lt;/code&gt;，决定是否调用工具：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;决策过程输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;🧭 进入条件边决策函数model_to_tools
📨 最后AI消息的工具调用数: 1
🔧 已处理的工具消息数: 0
⏳ 待处理工具调用数: 1
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;决策逻辑代码：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;if pending_tool_calls:
   print(f&quot;🔧 存在待处理工具调用，转向工具节点&quot;)
   result = [
       Send(
           &quot;tools&quot;,
           ToolCallWithContext(
               __type=&quot;tool_call_with_context&quot;,
               tool_call=tool_call,
               state=state,
           ),
       )
       for tool_call in pending_tool_calls
   ]
   print(f&quot;📍 决策结果: 发送{len(result)}个工具调用到tools节点&quot;)
   return result
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;代码解释：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;对每个待处理的&lt;code&gt;tool_call&lt;/code&gt;，创建一个&lt;code&gt;ToolCallWithContext&lt;/code&gt;对象，包含工具调用本身和当前状态上下文&lt;/li&gt;
&lt;li&gt;使用&lt;code&gt;Send(&quot;tools&quot;, ...)&lt;/code&gt;将对象发送到名为&quot;tools&quot;的节点&lt;/li&gt;
&lt;li&gt;最终返回一个Send对象列表&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;决策结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;📍 决策结果: 发送1个工具调用到tools节点
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;第三阶段：工具执行&lt;/h3&gt;
&lt;p&gt;调用我们创建的&lt;code&gt;send_email&lt;/code&gt;方法：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;工具执行输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;📧 工具执行: send_email(to=&apos;zhangsan@example.com&apos;, subject=&apos;项目进度同步&apos;, body=&apos;张三，你好！会议时间已经调整到明天下午3点，请准时参加。谢谢！&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;工具返回结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;{&apos;tools&apos;: {&apos;messages&apos;: [ToolMessage(content=&apos;邮件已发送至 zhangsan@example.com&apos;, name=&apos;send_email&apos;, id=&apos;2c9e0cf3-d138-4e0b-ae48-d7fc063425d0&apos;, tool_call_id=&apos;call_79b217f7070943b3bd01bf&apos;)]}}
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;第四阶段：第二次模型调用&lt;/h3&gt;
&lt;p&gt;此时又要进入amodel_node节点，此时该节点获得了3条消息：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;消息历史：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-json&quot;&gt;{
  &quot;messages&quot;: [
    {
      &quot;type&quot;: &quot;HumanMessage&quot;,
      &quot;content&quot;: &quot;请帮我给张三发一封邮件，告诉他会议时间改到明天下午3点了，主题是项目进度同步。&quot;,
      &quot;id&quot;: &quot;7b1ccde5-6917-4324-b345-3273221fa874&quot;
    },
    {
      &quot;type&quot;: &quot;AIMessage&quot;,
      &quot;content&quot;: &quot;&quot;,
      &quot;additional_kwargs&quot;: {
        &quot;tool_calls&quot;: [
          {
            &quot;id&quot;: &quot;call_79b217f7070943b3bd01bf&quot;,
            &quot;function&quot;: {
              &quot;name&quot;: &quot;send_email&quot;,
              &quot;arguments&quot;: &quot;{\&quot;to\&quot;: \&quot;zhangsan@example.com\&quot;, \&quot;subject\&quot;: \&quot;项目进度同步\&quot;, \&quot;body\&quot;: \&quot;张三，你好！会议时间已经调整到明天下午3点，请准时参加。谢谢！\&quot;}&quot;
            },
            &quot;type&quot;: &quot;function&quot;,
            &quot;index&quot;: 0
          }
        ],
        &quot;refusal&quot;: null
      },
      &quot;response_metadata&quot;: {
        &quot;token_usage&quot;: {
          &quot;completion_tokens&quot;: 56,
          &quot;prompt_tokens&quot;: 248,
          &quot;total_tokens&quot;: 304
        },
        &quot;model_name&quot;: &quot;Qwen3-32B&quot;,
        &quot;finish_reason&quot;: &quot;tool_calls&quot;
      },
      &quot;id&quot;: &quot;lc_run--4dbc579d-ae7d-4d90-8786-9eee8ae767b2-0&quot;,
      &quot;tool_calls&quot;: [
        {
          &quot;name&quot;: &quot;send_email&quot;,
          &quot;args&quot;: {
            &quot;to&quot;: &quot;zhangsan@example.com&quot;,
            &quot;subject&quot;: &quot;项目进度同步&quot;,
            &quot;body&quot;: &quot;张三，你好！会议时间已经调整到明天下午3点，请准时参加。谢谢！&quot;
          },
          &quot;id&quot;: &quot;call_79b217f7070943b3bd01bf&quot;,
          &quot;type&quot;: &quot;tool_call&quot;
        }
      ]
    },
    {
      &quot;type&quot;: &quot;ToolMessage&quot;,
      &quot;content&quot;: &quot;邮件已发送至 zhangsan@example.com&quot;,
      &quot;name&quot;: &quot;send_email&quot;,
      &quot;id&quot;: &quot;2c9e0cf3-d138-4e0b-ae48-d7fc063425d0&quot;,
      &quot;tool_call_id&quot;: &quot;call_79b217f7070943b3bd01bf&quot;
    }
  ]
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;代码解释：&lt;/strong&gt; 现在消息历史包含了完整的对话上下文：用户请求、AI的工具调用决策、工具执行结果。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;模型调用输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;💬 准备发送4条消息给模型
📄 发送给模型的消息内容:
   [1] 你是一个邮件助手。
   [2] 请帮我给张三发一封邮件，告诉他会议时间改到明天下午3点了，主题是项目进度同步。
   [3] 包含1个工具调用
   [4] 邮件已发送至 zhangsan@example.com
📥 模型返回结果，类型: AIMessage
📋 模型返回详细内容:
   content: &apos;邮件已经成功发送给张三，告诉他会议时间调整到了明天下午3点。如果有其他需要，请随时告诉我！&apos;
   additional_kwargs: {&apos;refusal&apos;: None}
   response_metadata: {&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 25, &apos;prompt_tokens&apos;: 327, &apos;total_tokens&apos;: 352, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-0a365738-5955-4439-bbec-a002c2983368&apos;, &apos;service_tier&apos;: None, &apos;finish_reason&apos;: &apos;stop&apos;, &apos;logprobs&apos;: None}
📝 模型返回内容: 邮件已经成功发送给张三，告诉他会议时间调整到了明天下午3点。如果有其他需要，请随时告诉我！
🔄 开始处理模型输出...
   输出类型: AIMessage
   effective_response_format: None
📦 返回标准消息格式
📦 模型输出处理完成，消息数: 1
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果解释：&lt;/strong&gt; 这次模型基于完整的对话历史，生成了最终的用户响应，确认任务已完成。&lt;/p&gt;
&lt;h3&gt;第五阶段：流程结束&lt;/h3&gt;
&lt;p&gt;再次进入条件边决策函数&lt;code&gt;model_to_tools&lt;/code&gt;：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;决策输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;📨 最后AI消息的工具调用数: 0
🔧 已处理的工具消息数: 0
🔚 无工具调用，流程结束，目标: __end__
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果解释：&lt;/strong&gt; 由于没有待处理的工具调用，决策函数决定结束流程，Agent执行完成。&lt;/p&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;整个Agent信息流可以概括为以下步骤：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;初始化&lt;/strong&gt;：构建包含模型节点和工具节点的有向图&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型推理&lt;/strong&gt;：LLM分析用户请求并决定是否需要调用工具&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;工具执行&lt;/strong&gt;：执行具体的工具函数&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;结果整合&lt;/strong&gt;：将工具执行结果返回给模型进行最终响应&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;流程结束&lt;/strong&gt;：当没有更多工具需要调用时结束流程&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;这种设计使得Agent能够灵活地在模型推理和工具执行之间循环，完成复杂的多步任务。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>Agent实战（三）Agent的短期记忆</title><link>https://astro-pure.js.org/blog/agent_blogs/agent_blogs-3</link><guid isPermaLink="true">https://astro-pure.js.org/blog/agent_blogs/agent_blogs-3</guid><description>Agent实战</description><pubDate>Wed, 11 Feb 2026 20:40:03 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/Agent_Project&quot;&gt;Github&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;在构建智能 Agent 时，一个关键挑战是如何让模型在多轮对话中&lt;strong&gt;记住上下文信息&lt;/strong&gt;——比如用户的名字、偏好或之前的请求。然而，大语言模型（LLM）本身是无状态的，每次调用都是独立的。因此，我们需要一种机制来&lt;strong&gt;持久化对话历史&lt;/strong&gt;，并在后续交互中将其重新注入上下文。&lt;/p&gt;
&lt;p&gt;LangChain 结合 LangGraph 提供了强大的 &lt;strong&gt;短期记忆（Short-Term Memory）&lt;/strong&gt; 能力，它不仅支持基础的记忆功能，还能灵活地实现&lt;strong&gt;消息裁剪、主动遗忘、上下文摘要&lt;/strong&gt;等高级策略，从而在保持上下文连贯性的同时，避免超出 LLM 的上下文窗口限制。&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;📌 &lt;strong&gt;核心思想&lt;/strong&gt;：与 LangGraph 中的状态图一样，Agent 的短期记忆依赖于&lt;strong&gt;检查点（Checkpointer）&lt;/strong&gt; 机制。通过将对话历史持久化到内存或数据库中，Agent 可以在中断后恢复会话，实现真正的“有记忆”的交互。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h1&gt;1. 使用短期记忆：记住用户的名字&lt;/h1&gt;
&lt;p&gt;最简单的短期记忆场景，就是让 Agent 记住用户在对话初期提供的信息。&lt;/p&gt;
&lt;p&gt;使用 &lt;code&gt;InMemorySaver&lt;/code&gt; 作为检查点存储（适用于测试），并为会话指定唯一的 &lt;code&gt;thread_id&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver
import sys
import os
# 直接指定 src 路径
src_path = &quot;/root/shared-nvme/LLM-Learning/Agent-Learning/src&quot;
sys.path.insert(0, src_path)
from Models import ModelManager

model_manager = ModelManager()
llm = model_manager.get_qwen_model()
checkpointer = InMemorySaver()

agent = create_agent(
    model=llm,
    tools=[],
    checkpointer=checkpointer,
)

config = {&quot;configurable&quot;: {&quot;thread_id&quot;: &quot;session-1&quot;}}

# 第一轮：用户告诉名字
agent.invoke({&quot;messages&quot;: &quot;你好！我叫小李。&quot;}, config)

# 第二轮：随便聊点别的
agent.invoke({&quot;messages&quot;: &quot;今天天气真不错。&quot;}, config)

# 第三轮：问名字，看是否还记得
final = agent.invoke({&quot;messages&quot;: &quot;我刚才告诉你我叫什么名字？&quot;}, config)

print(&quot;\n✅ 最终回答：&quot;)
final[&quot;messages&quot;][-1].pretty_print()
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;说明&lt;/strong&gt;：只要使用了 &lt;code&gt;checkpointer&lt;/code&gt; 并传入相同的 &lt;code&gt;thread_id&lt;/code&gt;，Agent 就会在每次调用时自动加载完整的对话历史作为上下文。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;预期效果&lt;/strong&gt;：即使中间插入无关对话，Agent 仍能准确回忆起“小李”这个名字。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;✅ 最终回答：
================================== Ai Message ==================================

是的，你刚才告诉我你叫小李。我记住了！😊 有什么我可以帮你或和你聊的吗，小李？
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;2. 自动裁剪消息：防止上下文爆炸&lt;/h2&gt;
&lt;p&gt;随着对话轮次增加，消息列表会不断增长。若不加控制，很容易超出 LLM 的上下文长度限制（如 32K tokens），导致性能下降甚至报错。&lt;/p&gt;
&lt;p&gt;LangChain 允许我们通过 &lt;strong&gt;中间件（middleware）&lt;/strong&gt; 在模型调用前动态调整消息历史。下面的例子实现了“保留第一条 + 最近 3 条”的策略：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from typing import Any, Optional
from langchain_core.messages import RemoveMessage
from langgraph.checkpoint.memory import MemorySaver 
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langgraph.graph.message import REMOVE_ALL_MESSAGES


@before_model
def trim_messages(state: AgentState, runtime) -&gt; Optional[dict[str, Any]]:
    &quot;&quot;&quot;保留第一条 + 最近 3 条消息，防止上下文过长&quot;&quot;&quot;
    messages = state[&quot;messages&quot;]
    
    if len(messages) &amp;#x3C;= 4:
        return None  
    
    first_msg = messages[0]
    recent_msgs = messages[-3:]
    new_messages = [first_msg] + recent_msgs

    return {
        &quot;messages&quot;: [
            RemoveMessage(id=REMOVE_ALL_MESSAGES),
            *new_messages
        ]
    }

tools = []

agent = create_agent(
    model=llm,
    tools=tools,
    middleware=[trim_messages],
    checkpointer=MemorySaver(),
)

config = {&quot;configurable&quot;: {&quot;thread_id&quot;: &quot;1&quot;}}

# 测试对话
agent.invoke({&quot;messages&quot;: &quot;你好，我叫小陈。&quot;}, config)
agent.invoke({&quot;messages&quot;: &quot;我全名叫陈小春。&quot;}, config)
agent.invoke({&quot;messages&quot;: &quot;帮我写一首关于春天的诗。&quot;}, config)
agent.invoke({&quot;messages&quot;: &quot;再写一首关于夏天的。&quot;}, config)
agent.invoke({&quot;messages&quot;: &quot;秋天呢？也来一首。&quot;}, config)
final = agent.invoke({&quot;messages&quot;: &quot;我叫什么名字？你知道我的全名吗？只需要回答我的问题，不需要回答其他内容&quot;}, config)

print(&quot;\n✅ 最终回答：&quot;)
final[&quot;messages&quot;][-1].pretty_print()
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;🔧 &lt;strong&gt;原理&lt;/strong&gt;：&lt;code&gt;@before_model&lt;/code&gt; 装饰器定义了一个在每次 LLM 调用前执行的钩子函数。我们在此函数中构造新的消息列表，并通过 &lt;code&gt;RemoveMessage(id=REMOVE_ALL_MESSAGES)&lt;/code&gt; 清空旧历史，再插入精简后的消息。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;预期效果&lt;/strong&gt;：尽管发送了 5 条消息，但 Agent 实际看到的上下文只有 4 条（首条 + 最近 3 条），因此不能记住“陈小春”这个全名。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;✅ 最终回答：
================================== Ai Message ==================================

你叫小陈，我没有你的全名。
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;3. 主动删除消息：强制“遗忘”早期信息&lt;/h2&gt;
&lt;p&gt;有时我们希望 Agent &lt;strong&gt;主动丢弃某些敏感或过时的信息&lt;/strong&gt;。下面的例子演示了如何在每轮对话后&lt;strong&gt;自动删除最早的两条消息&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@before_model
def post_model_hook(state: AgentState, runtime) -&gt; Optional[dict[str, Any]]:
    messages = state[&quot;messages&quot;]
    # 如果消息超过 2 条，就删除最早的 2 条
    if len(messages) &gt; 2:
        return {&quot;messages&quot;: [RemoveMessage(id=m.id) for m in messages[:2]]}
    return {}

agent = create_agent(
    model=llm,
    tools=[],
    middleware=[post_model_hook],
    checkpointer=InMemorySaver(),
)

config = {&quot;configurable&quot;: {&quot;thread_id&quot;: &quot;session-3&quot;}}

# 发送三条消息
agent.invoke({&quot;messages&quot;: &quot;我是张敏。&quot;}, config)
agent.invoke({&quot;messages&quot;: &quot;我今年 28 岁。&quot;}, config)
agent.invoke({&quot;messages&quot;: &quot;你能记住我的年龄吗？&quot;}, config) 

# 再问名字或年龄（应已丢失）
response = agent.invoke({&quot;messages&quot;: &quot;我叫什么名字？&quot;}, config)
print(&quot;\n✅ 删除后回答（应不记得名字）：&quot;)
response[&quot;messages&quot;][-1].pretty_print()
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;⚠️ &lt;strong&gt;注意&lt;/strong&gt;：这里使用的是 &lt;code&gt;@before_model&lt;/code&gt;（实际应在模型调用前处理），但逻辑上模拟了“每轮后清理”的行为。严格来说，若要在&lt;strong&gt;模型响应后&lt;/strong&gt;删除，应使用 &lt;code&gt;@after_model&lt;/code&gt; 或 &lt;code&gt;post_model_hook&lt;/code&gt;。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;&lt;strong&gt;预期效果&lt;/strong&gt;：当第三条消息被处理时，前两条（名字和年龄）已被删除，因此 Agent 无法回答“我叫什么名字”。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;✅ 删除后回答（应不记得名字）：
================================== Ai Message ==================================

你叫“用户”。在我们这次对话中，这是你的默认身份。如果你想告诉我你的名字，我很乐意记住并使用它！
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;LangChain 的短期记忆机制为 Agent 提供了强大而灵活的上下文管理能力：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;✅ &lt;strong&gt;基础记忆&lt;/strong&gt;：通过 &lt;code&gt;checkpointer&lt;/code&gt; 实现跨轮次状态保持；&lt;/li&gt;
&lt;li&gt;✅ &lt;strong&gt;智能裁剪&lt;/strong&gt;：利用中间件动态精简上下文，避免 token 耗尽；&lt;/li&gt;
&lt;li&gt;✅ &lt;strong&gt;主动遗忘&lt;/strong&gt;：按需删除敏感或冗余信息，提升安全性与效率。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这些能力使得 Agent 不仅“聪明”，而且“可靠”——既能记住重要信息，又不会被冗长的历史拖累。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>Agent实战（二）构建智能体的完整实践指南：工具调用、错误处理、动态模型选择与记忆扩展</title><link>https://astro-pure.js.org/blog/agent_blogs/agent_blogs-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/agent_blogs/agent_blogs-2</guid><description>Agent实战</description><pubDate>Wed, 11 Feb 2026 20:40:01 GMT</pubDate><content:encoded>&lt;p&gt;本章节参考&lt;a href=&quot;https://blog.csdn.net/zhangbaolin/article/details/154131591?ops_request_misc=&amp;#x26;request_id=&amp;#x26;biz_id=102&amp;#x26;utm_term=agent%E9%80%9A%E8%BF%87state_schema%E5%AE%9E%E7%8E%B0%E8%AE%B0%E5%BF%86&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-1-154131591.142%5Ev102%5Epc_search_result_base7&amp;#x26;spm=1018.2226.3001.4187&quot;&gt;博客&lt;/a&gt;
代码开源&lt;a href=&quot;https://github.com/SoupCola/Agent_Project&quot;&gt;Github&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;在大模型应用开发中，&lt;strong&gt;智能体（Agent）&lt;/strong&gt; 是连接语言模型与现实世界能力的核心桥梁。它不仅能推理，还能调用工具、处理错误、动态切换模型、记住用户上下文，甚至按需返回结构化数据。本文将通过一系列递进式实验，初步了解 LangChain/LangGraph 中 Agent 的高级用法。&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;1. 基础 Agent 与工具调用&lt;/h1&gt;
&lt;p&gt;首先创建一个基础 Agent，并为其配备三个工具：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;网络搜索（&lt;code&gt;MyWebSearchTool&lt;/code&gt;）&lt;/li&gt;
&lt;li&gt;RAG 检索（&lt;code&gt;RAGTool&lt;/code&gt;）&lt;/li&gt;
&lt;li&gt;一个简单的除法函数 &lt;code&gt;divide&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;model_manager = ModelManager()
llm = model_manager.get_qwen_model()

web_search_tool = MyWebSearchTool()
rag_tool = RAGTool()

@tool(&apos;divide&apos;, parse_docstring=True)
def divide(a: int, b: int) -&gt; int:
    &quot;&quot;&quot;计算两个数值的除法结果。
    
    Args:
        a (int): 被除数，
        b (int): 除数。
    
    Returns:
        int: 计算结果。
    &quot;&quot;&quot;
    return a / b

tools = [web_search_tool, rag_tool, divide]

agent = create_agent(model=llm, tools=tools)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后发起一个涉及计算的问题：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;state = {&quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;湖南省和河南省各有多少人口？它们之间的人口比例是多少？&quot;}]}
for event in agent.invoke(state)[&apos;messages&apos;]:
    event.pretty_print()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;执行我的Web搜索工具，输入的参数为: 湖南省和河南省各有多少人口？它们之间的人口比例是多少？
================================ Human Message =================================

湖南省和河南省各有多少人口？它们之间的人口比例是多少？?
================================== Ai Message ==================================
Tool Calls:
  web_search_tool (call_c2540c1c985b4312b6cb89)
 Call ID: call_c2540c1c985b4312b6cb89
  Args:
    query: 湖南省和河南省各有多少人口？它们之间的人口比例是多少？
================================= Tool Message =================================
Name: web_search_tool

3月30日，河南省统计局网站发布了《2023年河南省国民经济和社会发展统计公报》。至此，2023年中部六省常住人口变动情况悉数公布.....
================================== Ai Message ==================================

根据2023年末的数据：
1. **河南省**的常住人口为 **9815万人**。
2. **湖南省**的常住人口为 **6568万人**。
两者之间的人口比例可以通过计算得出：
$$
\text{人口比例} = \frac{\text{湖南省人口}}{\text{河南省人口}} = \frac{6568}{9815}
$$

接下来我将进行计算，以提供更具体的比例信息。
Tool Calls:
  divide (call_2024a723fd454aecb75514)
 Call ID: call_2024a723fd454aecb75514
  Args:
    a: 6568
    b: 9815
================================= Tool Message =================================
Name: divide

0.6691798267957209
================================== Ai Message ==================================

根据计算结果：

- 湖南省和河南省的人口比例约为 **0.67:1**，即湖南省人口大约是河南省人口的 **67%**。
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h1&gt;2. 自定义工具错误处理&lt;/h1&gt;
&lt;p&gt;工具执行可能失败（如除零错误）。我们可以通过中间件拦截异常并返回友好提示。&lt;/p&gt;
&lt;h2&gt;定义错误处理中间件：&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.agents.middleware import wrap_tool_call
from langchain_core.messages import ToolMessage

@wrap_tool_call
def handle_tool_errors(request, handler):
    try:
        return handler(request)
    except Exception as e:
        return ToolMessage(
            content=f&quot;Tool error: Please check your input and try again. ({str(e)})&quot;,
            tool_call_id=request.tool_call[&quot;id&quot;]
        )
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;模拟错误工具：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@tool(&apos;divide_error&apos;, parse_docstring=True)
def divide_error(a: int, b: int) -&gt; int:
    return 1 / 0  # 故意制造错误
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;创建带错误处理的 Agent：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;tools = [web_search_tool, rag_tool, divide_error]
agent = create_agent(
			model=llm, 
			tools=tools, 
			middleware=[handle_tool_errors]
			)

# 调用
for event in agent.invoke(state)[&apos;messages&apos;]:
    event.pretty_print()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果&lt;/strong&gt;：&lt;br&gt;
当 &lt;code&gt;divide_error&lt;/code&gt; 被调用时，Agent 收到的是我们自定义的错误消息，而非原始异常堆栈。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;执行我的Web搜索工具，输入的参数为: 湖南省和河南省各有多少人口？
================================ Human Message =================================

湖南省和河南省各有多少人口？它们之间的人口比例是多少？?
================================== Ai Message ==================================
Tool Calls:
  web_search_tool (call_02db215cff38499ab56410)
 Call ID: call_02db215cff38499ab56410
  Args:
    query: 湖南省和河南省各有多少人口？
================================= Tool Message =================================
Name: web_search_tool

湖北省与河南省同属于中国华中区域的省份，河南位于华中区域的偏北端，湖北位于华中区域的中心地带，湖南省位于华中区域的偏南处。湖北省现有面积为185900平方公里，现在的人口数量为5844万人...
================================== Ai Message ==================================

根据最新数据：

1. **河南省**：2023年末常住人口为**9815万人**。
2. **湖南省**：2023年末常住人口为**6568万人**。

### 人口比例计算：
我们使用除法来计算两者的人口比例。公式是：

$$
\text{人口比例} = \frac{\text{湖南省人口}}{\text{河南省人口}}
$$

将数值代入公式：

$$
\text{人口比例} = \frac{6568}{9815}
$$

现在我将调用 `divide_error` 工具来计算这个比例。
Tool Calls:
  divide_error (call_7a3bfc5bf40648e8b97121)
 Call ID: call_7a3bfc5bf40648e8b97121
  Args:
    a: 6568
    b: 9815
================================= Tool Message =================================

Tool error: Please check your input and try again. (division by zero)
================================== Ai Message ==================================

看起来出现了计算错误。实际上，我们已经知道河南省人口为9815万人，湖南省人口为6568万人，可以直接进行比例计算：

$$
\text{人口比例} = \frac{6568}{9815} \approx 0.67
$$

这意味着湖南省的人口约为河南省的 **67%**。
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到上面打印的输出，这是我们之前在中间件中定义的输出格式，Tool error: Please check your input and try again.&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;3. 动态模型选择：根据上下文切换 LLM&lt;/h1&gt;
&lt;p&gt;有时我们需要根据用户需求切换底层模型（如 Qwen vs DeepSeek）。&lt;/p&gt;
&lt;h2&gt;定义上下文类：&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from dataclasses import dataclass
from typing import Literal

@dataclass
class CustomContext:
    provider: Literal[&quot;qwen&quot;, &quot;deepseek&quot;]
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;实现模型选择中间件：&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@wrap_model_call
def dynamic_model_selection(request, handler):
    provider = request.runtime.context.provider
    print(f&quot;获取到provider={provider}&quot;)
    request.model = qwen_model if provider == &quot;qwen&quot; else deepseek_model
    return handler(request)
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;调用时指定模型：&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;result = agent.invoke(
    {&quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;你是哪个公司下面的AI？&quot;}]},
    context=CustomContext(provider=&quot;deepseek&quot;)
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;效果&lt;/strong&gt;：&lt;br&gt;
Agent 将使用 DeepSeek 模型回答，而非默认的 Qwen。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;获取到provider=deepseek
执行我的Web搜索工具，输入的参数为: DeepSeek是谁开发的？
获取到provider=deepseek
================================ Human Message =================================

你的老板是谁？
================================== Ai Message ==================================

我会调用搜索工具来查找相关的信息：
Tool Calls:
  web_search_tool (call_77oisnyc5sbeqj0ymjhs0vk4)
 Call ID: call_77oisnyc5sbeqj0ymjhs0vk4
  Args:
    query: DeepSeek是谁开发的？
================================= Tool Message =================================
Name: web_search_tool

DeepSeek4j简介
DeepSeek4j Spring Boot Starter为开发者提供了一种简便的方式来集成DeepSeek AI的强大功能到Spring Boot项目中。通过简单的配置和易于使用的API，即使是AI初学者也能轻松上手。该库支持流式返回、高级对话管理等功...
================================== Ai Message ==================================

DeepSeek是由杭州深度求索人工智能基础技术研究有限公司开发的。该公司于2023年4月由知名私募巨头幻方量化创立，并在2025年1月20日正式发布了高性能AI推理模型**DeepSeek-R1**，实现了国产AI技术的重大突破。

如果您需要更详细的技术文档或开发集成指南，可以随时进一步提问！
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h1&gt;4. 动态系统提示词：个性化交互风格&lt;/h1&gt;
&lt;p&gt;我们可以根据用户角色（如“专家”或“初学者”）动态调整系统提示。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@dynamic_prompt
def user_level_prompt(request: ModelRequest) -&gt; str:
    level = request.runtime.context.get(&quot;level&quot;, &quot;beginner&quot;)
    base = &quot;你是一个精通机器学习方面的专家.&quot;
    if level == &quot;expert&quot;:
        return f&quot;{base} 解释问题时提供更多的细节.&quot;
    else:
        return f&quot;{base} 用小朋友都能听懂的方式表达.&quot;

agent = create_agent(
    model=llm,
    tools=tools,
    middleware=[user_level_prompt],
    context_schema={&quot;level&quot;: str}
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;分别以 &lt;code&gt;&quot;expert&quot;&lt;/code&gt; 和 &lt;code&gt;&quot;beginner&quot;&lt;/code&gt; 调用：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;result = agent.invoke(
    {&quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;解释一下朴素贝叶斯算法&quot;}]},
    context={&quot;level&quot;: &quot;expert&quot;}
) 
result = agent.invoke(
    {&quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;解释一下朴素贝叶斯算法&quot;}]},
    context={&quot;level&quot;: &quot;beginner&quot;}
) 
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期差异&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;专家版：包含公式、术语、推导过程；&lt;/li&gt;
&lt;li&gt;初学者版：使用比喻、生活例子、避免数学。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;[专家版输出]&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;现在的用户level:expert
{&apos;messages&apos;: [HumanMessage(content=&apos;解释一下朴素贝叶斯算法&apos;, additional_kwargs={}, response_metadata={}, id=&apos;8bacbc64-aad3-46dd-a46d-05c91f906a3e&apos;), AIMessage(content=&apos;朴素贝叶斯算法是一种基于贝叶斯定理并假设特征之间相互独立的概率分类方法。它广泛应用于文本分类、垃圾邮件过滤、情感分析等自然语言处理领域。\n\n### 贝叶斯定理\n朴素贝叶斯的核心是**贝叶斯定理**，其数学表达式为：\n$$ P(A|B) = \\frac{P(B|A) \\cdot P(A)}{P(B)} $$\n\n其中：\n- $ P(A|B) $：在已知 B 的情况下 A 发生的概率（后验概率）。\n- $ P(B|A) $：在已知 A 的情况下 B 发生的概率（似然度）。\n- $ P(A) $：A 发生的先验概率。\n- $ P(B) $：B 发生的先验概率。\n\n### 朴素贝叶斯的“朴素”假设\n朴素贝叶斯的关键假设是**特征之间的条件独立性**。也就是说，在给定类别的情况下，所有特征彼此独立。这使得计算更加简单，并且在许多实际应用中表现良好。\n\n### 算法流程\n1. **数据准备**：\n   - 将数据集划分为训练集和测试集。\n   - 对于文本分类问题，通常需要对文本进行分词、去停用词、提取特征等预处理操作。\n\n2. **训练模型**：\n   - 统计每个类别下各个特征出现的频率。\n   - 计算先验概率 $ P(C_k) $，即每个类别的概率。\n   - 计算条件概率 $ P(x_i | C_k) $，即在某个类别下某个特征出现的概率。\n\n3. **预测新样本**：\n   - 对于一个新的样本，计算其属于每个类别的后验概率 $ P(C_k | x_1, x_2, ..., x_n) $。\n   - 根据最大后验概率原则，选择概率最大的类别作为预测结果。\n\n### 常见变种\n朴素贝叶斯有几种常见的变种，适用于不同类型的数据：\n\n1. **多项式朴素贝叶斯（Multinomial Naive Bayes）**：\n   - 适用于离散特征，尤其是文本分类中的词频统计。\n   - 条件概率公式为：\n     $$ P(x_i | C_k) = \\frac{\\text{count}(x_i, C_k) + \\alpha}{\\sum_{j} \\text{count}(x_j, C_k) + \\alpha \\cdot V} $$\n     其中 $ \\alpha $ 是平滑参数，$ V $ 是词汇表大小。\n\n2. **伯努利朴素贝叶斯（Bernoulli Naive Bayes）**：\n   - 适用于二值特征（0 或 1），例如是否包含某个单词。\n   - 条件概率公式为：\n     $$ P(x_i | C_k) = \\begin{cases} \n     p_{i,k} &amp;#x26; \\text{if } x_i = 1 \\\\\n     1 - p_{i,k} &amp;#x26; \\text{if } x_i = 0 \n     \\end{cases} $$\n     其中 $ p_{i,k} $ 是在类别 $ C_k $ 下特征 $ x_i $ 出现的概率。\n\n3. **高斯朴素贝叶斯（Gaussian Naive Bayes）**：\n   - 适用于连续特征，假设特征服从正态分布。\n   - 条件概率公式为：\n     $$ P(x_i | C_k) = \\frac{1}{\\sqrt{2\\pi\\sigma_k^2}} \\exp\\left(-\\frac{(x_i - \\mu_k)^2}{2\\sigma_k^2}\\right) $$\n     其中 $ \\mu_k $ 和 $ \\sigma_k $ 分别是类别 $ C_k $ 下特征 $ x_i $ 的均值和标准差。\n\n### 优缺点\n#### 优点：\n1. **简单高效**：计算复杂度低，适合大规模数据集。\n2. **对小规模数据有效**：即使在小样本情况下也能取得较好的效果。\n3. **对缺失数据不敏感**：能够处理部分缺失的数据。\n\n#### 缺点：\n1. **特征独立性假设过于理想化**：现实世界中特征之间可能存在相关性，这会影响模型的准确性。\n2. **对输入数据的分布敏感**：如果特征不服从假设的分布（如高斯分布），可能会影响模型性能。\n\n### 应用场景\n朴素贝叶斯算法因其简单性和高效性，常用于以下场景：\n- **文本分类**：如新闻分类、垃圾邮件检测、情感分析等。\n- **推荐系统**：根据用户的历史行为预测兴趣。\n- **医学诊断**：根据患者的症状预测疾病类型。\n\n### 总结\n朴素贝叶斯是一种简单但强大的分类算法，尤其适合处理文本数据。虽然它的“朴素”假设在某些情况下可能不太合理，但在实践中往往能取得很好的效果。&apos;, additional_kwargs={&apos;refusal&apos;: None}, response_metadata={&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 1064, &apos;prompt_tokens&apos;: 495, &apos;total_tokens&apos;: 1559, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_provider&apos;: &apos;openai&apos;, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-00ee2b2f-0021-4fb8-be16-ec56012b26f2&apos;, &apos;finish_reason&apos;: &apos;stop&apos;, &apos;logprobs&apos;: None}, id=&apos;lc_run--25c8b7bd-2203-4a58-9a9d-d5e71990ddb1-0&apos;, usage_metadata={&apos;input_tokens&apos;: 495, &apos;output_tokens&apos;: 1064, &apos;total_tokens&apos;: 1559, &apos;input_token_details&apos;: {}, &apos;output_token_details&apos;: {}})]}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;[初学者版输出]&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;现在的用户level:beginner
{&apos;messages&apos;: [HumanMessage(content=&apos;解释一下朴素贝叶斯算法&apos;, additional_kwargs={}, response_metadata={}, id=&apos;609572f6-3404-47f7-8d67-d09ed38c2440&apos;), AIMessage(content=&apos;朴素贝叶斯算法是一种帮助我们做决定的聪明方法。我们可以把它想象成一个猜谜游戏，这个游戏根据我们看到的东西来猜出答案。\n\n比如说，如果我们想要猜一封邮件是不是垃圾邮件，朴素贝叶斯就会看看这封邮件里有哪些词。如果它看到很多像“免费”、“赢取大奖”这样的词，它可能会猜这是一封垃圾邮件。但如果它看到的是“会议”、“报告”这样的词，它可能猜这不是垃圾邮件。\n\n这个算法之所以叫“朴素”，是因为它假设每个词都是独立的，也就是说，它不考虑这些词之间的关系，只是单独地看每一个词。虽然这是一个简单的假设，但它在很多情况下都工作得很好！\n\n所以，总结一下，朴素贝叶斯就是通过观察一些特征（比如邮件里的词），然后根据这些特征出现的概率来做预测的一种方法。&apos;, additional_kwargs={&apos;refusal&apos;: None}, response_metadata={&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 179, &apos;prompt_tokens&apos;: 499, &apos;total_tokens&apos;: 678, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_provider&apos;: &apos;openai&apos;, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-e308bde6-f71c-47ec-b588-54f0e917452a&apos;, &apos;finish_reason&apos;: &apos;stop&apos;, &apos;logprobs&apos;: None}, id=&apos;lc_run--2f33fe73-92e6-4429-bdb0-df695eef9e6b-0&apos;, usage_metadata={&apos;input_tokens&apos;: 499, &apos;output_tokens&apos;: 179, &apos;total_tokens&apos;: 678, &apos;input_token_details&apos;: {}, &apos;output_token_details&apos;: {}})]}
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h1&gt;5. 结构化输出：强制返回 JSON&lt;/h1&gt;
&lt;p&gt;当需要从文本中提取结构化信息时，可指定 &lt;code&gt;response_format&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class Resume(TypedDict):
    name: str
    age: str
    phone: str

agent = create_agent(
    model=llm,
    tools=tools,
    response_format=ToolStrategy(Resume)
)

result = agent.invoke({
    &quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;我叫宋昌，今年35岁，电话13812345678&quot;}]
})
print(result[&quot;structured_response&quot;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;结果&lt;/strong&gt;：&lt;br&gt;
直接获得标准 JSON 对象，无需后处理。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;{&apos;messages&apos;: [HumanMessage(content=&apos;从以下信息中提取Resume信息。我叫宋昌，今年35岁，我的电话号码是13812345678&apos;, additional_kwargs={}, response_metadata={}, id=&apos;2f94d017-40ee-452d-915e-45defed616e0&apos;), AIMessage(content=&apos;&apos;, additional_kwargs={&apos;refusal&apos;: None}, response_metadata={&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 42, &apos;prompt_tokens&apos;: 576, &apos;total_tokens&apos;: 618, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_provider&apos;: &apos;openai&apos;, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-02d310af-c4db-4046-871b-3fe5c7a46e0f&apos;, &apos;finish_reason&apos;: &apos;tool_calls&apos;, &apos;logprobs&apos;: None}, id=&apos;lc_run--fd0db0ef-6c3b-4211-b6e8-746a25ef7c21-0&apos;, tool_calls=[{&apos;name&apos;: &apos;Resume&apos;, &apos;args&apos;: {&apos;age&apos;: &apos;35&apos;, &apos;name&apos;: &apos;宋昌&apos;, &apos;phone&apos;: &apos;13812345678&apos;}, &apos;id&apos;: &apos;call_f850f162d81f41e08875bf&apos;, &apos;type&apos;: &apos;tool_call&apos;}], usage_metadata={&apos;input_tokens&apos;: 576, &apos;output_tokens&apos;: 42, &apos;total_tokens&apos;: 618, &apos;input_token_details&apos;: {}, &apos;output_token_details&apos;: {}}), ToolMessage(content=&quot;Returning structured response: {&apos;name&apos;: &apos;宋昌&apos;, &apos;age&apos;: &apos;35&apos;, &apos;phone&apos;: &apos;13812345678&apos;}&quot;, name=&apos;Resume&apos;, id=&apos;4e5ad143-157b-46a3-a157-0424f87eb4ee&apos;, tool_call_id=&apos;call_f850f162d81f41e08875bf&apos;)], &apos;structured_response&apos;: {&apos;name&apos;: &apos;宋昌&apos;, &apos;age&apos;: &apos;35&apos;, &apos;phone&apos;: &apos;13812345678&apos;}}
{&apos;name&apos;: &apos;宋昌&apos;, &apos;age&apos;: &apos;35&apos;, &apos;phone&apos;: &apos;13812345678&apos;}
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h1&gt;6. 扩展 Agent 记忆：保存自定义状态&lt;/h1&gt;
&lt;p&gt;Agent 默认记住对话历史（&lt;code&gt;messages&lt;/code&gt;），但我们还可以添加自定义字段，如 &lt;code&gt;user_info&lt;/code&gt;。&lt;/p&gt;
&lt;h2&gt;方法一：通过中间件（适合需要干预流程的场景）&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class CustomState(AgentState):
    user_info: dict

class CustomMiddleware(AgentMiddleware):
    state_schema = CustomState
    def before_model(self, state, runtime):
        if state[&quot;user_info&quot;][&quot;title&quot;] == &quot;大将军&quot;:
            return {&quot;messages&quot;: [SystemMessage(content=&quot;将军好！……&quot;), *state[&quot;messages&quot;]]}

agent = create_agent(llm, tools=tools, middleware=[CustomMiddleware()])

result = agent.invoke({
    &quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;我喜欢研究理论&quot;}],
    &quot;user_info&quot;: {&quot;name&quot;: &quot;赵无恤&quot;, &quot;title&quot;: &quot;大将军&quot;}
})
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;当前状态: {&apos;messages&apos;: [HumanMessage(content=&apos;我喜欢研究理论&apos;, additional_kwargs={}, response_metadata={}, id=&apos;3a726091-e837-474a-afe6-14e69e6829e7&apos;)], &apos;user_info&apos;: {&apos;name&apos;: &apos;赵无恤&apos;, &apos;title&apos;: &apos;大将军&apos;}}
【系统提示词】: 你正在与一位军事统帅对话，请使用庄重语气，并在交流前加上“将军好！”的表达。
{&apos;messages&apos;: [HumanMessage(content=&apos;我喜欢研究理论&apos;, additional_kwargs={}, response_metadata={}, id=&apos;3a726091-e837-474a-afe6-14e69e6829e7&apos;), SystemMessage(content=&apos;你正在与一位军事统帅对话，请使用庄重语气，并在交流前加上“将军好！”的表达。&apos;, additional_kwargs={}, response_metadata={}, id=&apos;c2c05fb1-d1ec-4861-9b04-ab856d8898dc&apos;), AIMessage(content=&apos;将军好！理论研究是智慧的源泉，也是制定战略与决策的重要基础。您对哪一方面的理论感兴趣？无论是兵法、政治、哲学还是其他领域，我都很乐意为您提供相关信息或进行深入探讨。&apos;, additional_kwargs={&apos;refusal&apos;: None}, response_metadata={&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 47, &apos;prompt_tokens&apos;: 500, &apos;total_tokens&apos;: 547, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_provider&apos;: &apos;openai&apos;, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-732de714-22e5-4a38-be2c-f73205b148ac&apos;, &apos;finish_reason&apos;: &apos;stop&apos;, &apos;logprobs&apos;: None}, id=&apos;lc_run--2b6b0aff-5d33-496f-81f1-5d3f5f5a4ee3-0&apos;, usage_metadata={&apos;input_tokens&apos;: 500, &apos;output_tokens&apos;: 47, &apos;total_tokens&apos;: 547, &apos;input_token_details&apos;: {}, &apos;output_token_details&apos;: {}})], &apos;user_info&apos;: {&apos;name&apos;: &apos;赵无恤&apos;, &apos;title&apos;: &apos;大将军&apos;}}
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;方法二：通过 &lt;code&gt;state_schema&lt;/code&gt;（简洁，仅扩展数据）&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;agent = create_agent(model=llm, tools=tools, state_schema=CustomState)

# 第一轮
state1= agent.invoke({
    &quot;messages&quot;: [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;我喜欢研究军事理论&quot;}],
    &quot;user_info&quot;: {&quot;name&quot;: &quot;赵无恤&quot;, &quot;title&quot;: &quot;大将军&quot;},
})
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;{&apos;messages&apos;: [HumanMessage(content=&apos;我喜欢研究军事理论&apos;, additional_kwargs={}, response_metadata={}, id=&apos;2044bd7c-e1e4-4bc4-8440-8c1ad9a7690c&apos;), AIMessage(content=&apos;那非常有趣！军事理论涉及战略、战术、武器系统、军队组织等多个方面。你对哪个具体领域感兴趣呢？比如现代战争中的信息战、网络战，还是传统的陆海空三军作战理论？我可以为你提供一些相关的知识或资料。&apos;, additional_kwargs={&apos;refusal&apos;: None}, response_metadata={&apos;token_usage&apos;: {&apos;completion_tokens&apos;: 56, &apos;prompt_tokens&apos;: 475, &apos;total_tokens&apos;: 531, &apos;completion_tokens_details&apos;: None, &apos;prompt_tokens_details&apos;: None}, &apos;model_provider&apos;: &apos;openai&apos;, &apos;model_name&apos;: &apos;Qwen3-32B&apos;, &apos;system_fingerprint&apos;: None, &apos;id&apos;: &apos;chatcmpl-9d7ea05a-5378-4703-9c33-5c2eab5e8374&apos;, &apos;finish_reason&apos;: &apos;stop&apos;, &apos;logprobs&apos;: None}, id=&apos;lc_run--5380d835-ad4f-4812-8b63-d2b2b10d1c84-0&apos;, usage_metadata={&apos;input_tokens&apos;: 475, &apos;output_tokens&apos;: 56, &apos;total_tokens&apos;: 531, &apos;input_token_details&apos;: {}, &apos;output_token_details&apos;: {}})], &apos;user_info&apos;: {&apos;name&apos;: &apos;赵无恤&apos;, &apos;title&apos;: &apos;大将军&apos;}}
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 第二轮：传递完整 state 实现记忆延续
state2 = agent.invoke({
    **state1,
    &quot;messages&quot;: state1[&quot;messages&quot;] + [{&quot;role&quot;: &quot;user&quot;, &quot;content&quot;: &quot;推荐几本经典著作？&quot;}]
})
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;执行我的Web搜索工具，输入的参数为: 军事理论经典著作推荐
根据网络搜索的结果，以下是几本经典的军事理论著作推荐给你：

1. **《孙子兵法》**  
   - 作者：孙武（春秋时期）  
   - 内容简介：这是中国古代最著名的军事著作，被誉为“兵学圣典”。全书共十三篇，涵盖了战争的基本原则、战略思想、战术运用以及治军原则等内容。书中提出的“知己知彼，百战不殆”等经典军事思想，至今仍被广泛研究和应用。

2. **《战争论》**  
   - 作者：卡尔·冯·克劳塞维茨（普鲁士）  
   - 内容简介：这部著作是西方近代军事理论的经典之作，提出了“战争无非是政治通过另一种手段的继续”的著名论断。书中详细探讨了战争的性质、战略、战斗、军队组织等方面的内容，对后世军事思想产生了深远影响。

3. **《大战略》**  
   - 作者：约翰·柯林斯（美国）  
   - 内容简介：这本书系统地论述了美国的战略思想和军事战略，涵盖了对外政策、地理、经济和科学技术等问题，对于研究美国的军事战略具有重要参考价值。

4. **《海权论》**  
   - 作者：阿尔弗雷德·塞耶·马汉（美国）  
   - 内容简介：这本书被认为是近代制海权理论的奠基之作。作者通过分析英国在拿破仑时代的海上霸权，强调了强大的海军控制海洋的重要性。

5. **《制空权》**  
   - 作者：朱利奥·杜黑（意大利）  
   - 内容简介：这本书系统地提出了空军建设和作战的理论，认为掌握制空权是现代战争的关键。

6. **《论持久战》**  
   - 作者：毛泽东（中国）  
   - 内容简介：这是毛泽东在抗日战争期间的重要军事著作，批判了“亡国论”和“速胜论”，提出了持久战的战略思想，对中国抗日战争的胜利起到了重要作用。

这些书籍不仅对军事理论有深入的探讨，也对政治、经济、外交等多个领域有重要的借鉴意义。希望你能从中获得启发！如果你对某本书特别感兴趣，我可以为你提供更详细的介绍或相关内容。
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;✅ 这样，Agent 能够记得对话，实现真正个性化服务。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h1&gt;总结&lt;/h1&gt;
&lt;p&gt;本文通过六个递进式实验，系统性地展示了 LangChain/LangGraph 框架下构建高阶 Agent 的关键技术路径。以下从技术架构层面进行归纳：&lt;/p&gt;
&lt;p&gt;| 能力维度 | 技术实现机制 | 核心价值 | 典型应用场景 |
|----------|----------------|----------|----------------|
| &lt;strong&gt;工具集成与推理链编排&lt;/strong&gt; | 通过 &lt;code&gt;@tool&lt;/code&gt; 装饰器注册函数或类工具，结合 ReAct 或 Plan-and-Execute 等策略，由 LLM 动态决定工具调用顺序与参数。Agent 在单次推理中可完成“检索 → 计算 → 总结”的多步协同。 | 实现 LLM 与外部能力（API、数据库、计算函数）的语义桥接，突破纯文本生成的局限。 | 多跳问答、数据驱动决策、自动化报告生成等需组合多种能力的任务。 |
| &lt;strong&gt;工具执行异常的中间件拦截&lt;/strong&gt; | 利用 &lt;code&gt;wrap_tool_call&lt;/code&gt; 中间件对工具调用进行包装，在 &lt;code&gt;handler(request)&lt;/code&gt; 执行前后插入异常捕获逻辑，并返回符合 ToolMessage 协议的结构化错误响应。 | 将底层运行时异常转化为 Agent 可理解、可恢复的语义信号，避免流程中断，提升鲁棒性。 | 生产环境中对不可靠工具（如第三方 API、用户自定义函数）的安全封装。 |
| &lt;strong&gt;上下文感知的动态模型路由&lt;/strong&gt; | 通过自定义 &lt;code&gt;Context&lt;/code&gt; 类携带元信息（如 &lt;code&gt;provider&lt;/code&gt;），配合 &lt;code&gt;wrap_model_call&lt;/code&gt; 中间件在请求分发前动态替换 &lt;code&gt;request.model&lt;/code&gt; 字段，实现运行时模型切换。 | 解耦 Agent 逻辑与具体 LLM 实现，支持按任务类型、成本、性能或合规要求动态调度异构模型。 | 多租户 SaaS 应用、混合模型部署（如 Qwen 用于通用对话，DeepSeek 用于代码生成）。 |
| &lt;strong&gt;个性化系统提示（Dynamic Prompting）&lt;/strong&gt; | 基于运行时上下文（如用户角色 &lt;code&gt;level&lt;/code&gt;）动态生成 system prompt，通过 &lt;code&gt;@dynamic_prompt&lt;/code&gt; 中间件注入到模型输入前缀中，引导 LLM 输出风格与深度。 | 实现细粒度的输出控制，无需训练多个专用模型即可适配不同用户认知水平或业务语境。 | 教育产品（初学者/专家模式）、客服系统（普通用户/管理员视图）、医疗问答（患者/医生视角）。 |
| &lt;strong&gt;强约束结构化输出（Structured Output）&lt;/strong&gt; | 利用 &lt;code&gt;response_format=ToolStrategy(schema)&lt;/code&gt; 指令，强制 LLM 以工具调用形式返回符合 Pydantic/TypedDict Schema 的 JSON 对象，绕过自由文本解析。 | 消除后处理环节的不确定性，确保输出可直接用于下游系统（如数据库写入、API 响应），提升端到端可靠性。 | 信息抽取（简历、票据、日志）、表单填充、API 参数生成等需高精度结构化结果的场景。 |
| &lt;strong&gt;可扩展的状态管理（Custom State Schema）&lt;/strong&gt; | 通过继承 &lt;code&gt;AgentState&lt;/code&gt; 定义包含 &lt;code&gt;user_info&lt;/code&gt; 等字段的自定义状态类型，并在 &lt;code&gt;state_schema&lt;/code&gt; 中声明；中间件或后续步骤可读写该状态，实现跨轮次上下文持久化。 | 超越仅维护消息历史的限制，支持业务相关上下文（如用户身份、会话偏好、任务进度）的端到端传递与更新。 | 多轮复杂对话（如订票、诊断）、个性化推荐、带状态的工作流自动化。 |&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;架构启示&lt;/strong&gt;：上述能力均基于 LangGraph 提供的&lt;strong&gt;可组合中间件（Middleware）机制&lt;/strong&gt;与&lt;strong&gt;显式状态图（State Graph）抽象&lt;/strong&gt;。这种设计将传统 Agent 的“黑盒推理”转化为&lt;strong&gt;可观测、可干预、可扩展的执行流水线&lt;/strong&gt;，为构建企业级可靠 AI 应用奠定了工程基础。&lt;/p&gt;
&lt;/blockquote&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>Agent实战（一）Agent、LangChain、LangGraph和LangSmith</title><link>https://astro-pure.js.org/blog/agent_blogs/agent_blogs-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/agent_blogs/agent_blogs-1</guid><description>Agent实战</description><pubDate>Wed, 11 Feb 2026 20:40:00 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/Agent_Project&quot;&gt;Github&lt;/a&gt;&lt;/p&gt;
&lt;h1&gt;🤖 什么是 Agent（智能体）？&lt;/h1&gt;
&lt;p&gt;&lt;strong&gt;Agent 是一个能“自己思考并行动”的 AI 程序。&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;它不只是被动地回答问题，而是：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;听懂你的目标&lt;/strong&gt;（比如“查上个月销量”）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;自己规划怎么做&lt;/strong&gt;（比如：先写 SQL → 查数据库 → 整理结果）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;调用工具完成任务&lt;/strong&gt;（如访问数据库、搜索网页、运行代码）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;把结果告诉你&lt;/strong&gt;（用自然语言总结）&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 简单说：&lt;strong&gt;Agent = 能自动做事的 AI 助手&lt;/strong&gt;，不是只能聊天的“问答机”。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;基于您提供的图片内容，以下是关于 LangChain 中 Agent 及其组成的简单介绍：&lt;/p&gt;
&lt;p&gt;Agent 是一个能够通过循环调用工具来完成复杂目标的智能系统。它不像传统程序那样按固定流程执行，而是更像一位智能项目经理，能够自主规划、决策和执行任务。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Agent 的工作流程&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;输入&lt;/strong&gt;：接收用户请求&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;模型推理&lt;/strong&gt;：LLM 作为&quot;大脑&quot;分析任务，决定下一步行动&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;循环执行&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;行动&lt;/strong&gt;：选择并调用合适的工具&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;观察&lt;/strong&gt;：获取工具执行结果&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;输出&lt;/strong&gt;：满足停止条件后，生成最终结果&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;Agent 的四大核心组件&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;| 组件 | 角色 | 功能描述 |
|------|------|----------|
| &lt;strong&gt;模型(Model)&lt;/strong&gt; | 大脑 | 负责推理和决策，支持多种LLM提供商，基于思维链模式分解复杂问题 |
| &lt;strong&gt;工具(Tools)&lt;/strong&gt; | 手脚 | Agent的能力扩展接口，提供300+预置功能（搜索、计算、数据库操作等） |
| &lt;strong&gt;记忆(Memory)&lt;/strong&gt; | 记忆库 | 维护短期对话上下文和长期知识存储，支持跨轮次协作 |
| &lt;strong&gt;AgentExecutor&lt;/strong&gt; | 协调器 | 控制执行流程，处理循环、异常，确保任务顺利完成 |&lt;/p&gt;
&lt;h3&gt;Agent 的三大核心能力&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;动态任务路由&lt;/strong&gt;：根据输入内容智能规划执行路径&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;生态化工具集成&lt;/strong&gt;：访问丰富的工具接口覆盖多领域需求&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;全周期记忆管理&lt;/strong&gt;：同时维护短期上下文和长期知识存储&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;这种架构使 Agent 能够像人类一样思考-行动-观察-再思考，直到完美解决复杂问题。&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;AIGC 与 AI Agent 特性对比&lt;/h1&gt;
&lt;p&gt;| 特性 | AIGC（如 ChatGPT） | AI Agent（如 Manus, Operator） |
|------|-------------------|-------------------------------|
| &lt;strong&gt;核心能力&lt;/strong&gt; | 内容生成 | 任务规划与自主执行 |
| &lt;strong&gt;交互模式&lt;/strong&gt; | 被动响应，依赖提示词 | 主动规划，自主决策 |
| &lt;strong&gt;输出结果&lt;/strong&gt; | 建议、方案、内容（需人工后续处理） | 可交付的最终成果（如已发送的邮件、整理好的报表） |
| &lt;strong&gt;与外界交互&lt;/strong&gt; | 有限，主要通过文本 | 强大，可通过工具调用操作现实世界系统 |&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;LangChain与LangGraph&lt;/h1&gt;
&lt;h2&gt;LangChain&lt;/h2&gt;
&lt;p&gt;LangChain 是一个用于构建大语言模型（LLM）应用程序的框架。它的核心目标是简化AI应用的开发流程，通过提供一系列标准化的模块（如模型调用、记忆管理、数据检索和工具调用），让开发者能够像搭积木一样，快速组合出功能丰富的智能体（Agent）。使用 LangChain，您可以轻松地为LLM添加长期记忆、连接外部知识库、以及赋予其执行代码或操作API的能力，从而快速构建出问答系统、摘要工具或内容生成器等通用AI应用。它主要解决了LLM应用的“最后一公里”问题，极大地提升了开发效率。&lt;/p&gt;
&lt;h2&gt;LangGraph&lt;/h2&gt;
&lt;p&gt;LangGraph 是建立在 LangChain 之上的一个库，专注于为智能体提供复杂、可定制的控制流。如果说 LangChain 负责组装智能体的“零部件”，那么 LangGraph 则负责设计其“大脑”的决策流程。它允许开发者以图（Graph）的形式来定义智能体的工作流，例如实现多步骤规划、循环执行、条件分支以及让多个“工作者”智能体协同完成任务。这使得 LangGraph 特别适合构建需要长期运行、状态持久化、行为复杂且可控的自主智能体，例如模拟角色、复杂的决策支持系统或自动化工作流引擎。&lt;/p&gt;
&lt;h2&gt;对比总结&lt;/h2&gt;
&lt;p&gt;| 特性 | LangChain | LangGraph |
| :--- | :--- | :--- |
| &lt;strong&gt;核心定位&lt;/strong&gt; | 快速构建功能完善的LLM应用框架 | 为智能体设计复杂控制流的运行时 |
| &lt;strong&gt;主要目标&lt;/strong&gt; | &lt;strong&gt;开发效率&lt;/strong&gt;与&lt;strong&gt;标准化&lt;/strong&gt;，提供开箱即用的组件 | &lt;strong&gt;精细控制&lt;/strong&gt;与&lt;strong&gt;灵活性&lt;/strong&gt;，支持定制化工作流 |
| &lt;strong&gt;擅长场景&lt;/strong&gt; | 通用AI应用（如聊天机器人、文档问答） | 复杂、有状态的自主智能体（如模拟、多步推理） |
| &lt;strong&gt;类比&lt;/strong&gt; | 提供了一套完整的&lt;strong&gt;乐高积木&lt;/strong&gt;和搭建手册 | 提供了设计和连接复杂&lt;strong&gt;齿轮与传动系统&lt;/strong&gt;的工具 |
| &lt;strong&gt;关系&lt;/strong&gt; | 提供了智能体的核心组件（工具、记忆等） | 基于 LangChain 的组件，为其增添复杂逻辑和流程控制 |&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;LangSmith&lt;/h1&gt;
&lt;p&gt;LangSmith 是由 LangChain 团队开发的一站式 &lt;strong&gt;LLM（大语言模型）应用开发、调试、测试与监控平台&lt;/strong&gt;。它的目标是帮助开发者更高效地构建、评估和优化基于 LLM 的应用程序，尤其适用于使用 LangChain 或 LangGraph 构建的复杂智能体（Agent）系统。&lt;/p&gt;
&lt;h2&gt;一、核心功能概览&lt;/h2&gt;
&lt;h3&gt;1. &lt;strong&gt;Trace 可视化（执行追踪）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;自动记录每一次 LLM 调用、工具调用、链（Chain）执行、状态变更等。&lt;/li&gt;
&lt;li&gt;以&lt;strong&gt;时间线 + 树状结构&lt;/strong&gt;展示完整执行流程，清晰看到：
&lt;ul&gt;
&lt;li&gt;Prompt 输入&lt;/li&gt;
&lt;li&gt;LLM 输出&lt;/li&gt;
&lt;li&gt;工具调用参数与结果&lt;/li&gt;
&lt;li&gt;内部状态变化（尤其在 LangGraph 中）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;支持查看 token 使用量、延迟、模型名称等元数据。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;✅ 类似于“Chrome DevTools for LLM apps”。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;2. &lt;strong&gt;数据集管理（Datasets）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;创建和管理测试/评估数据集（输入 + 期望输出）。&lt;/li&gt;
&lt;li&gt;支持从真实用户交互中采样生成数据集（通过“Feedback”或手动标注）。&lt;/li&gt;
&lt;li&gt;可用于回归测试，确保新版本不会降低性能。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;3. &lt;strong&gt;自动评估（Evaluators）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;提供内置评估器（如：准确性、相关性、有害性、忠实度等）。&lt;/li&gt;
&lt;li&gt;支持自定义 Python 评估函数（可调用 LLM 进行评判）。&lt;/li&gt;
&lt;li&gt;批量运行评估：将整个数据集在你的应用上跑一遍，生成量化指标（如准确率 85%）。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;📊 示例：评估客服机器人回答是否“有帮助”且“无幻觉”。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3&gt;4. &lt;strong&gt;人工反馈与标注（Human Feedback）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;允许人工对模型输出打分（👍/👎）或添加评论。&lt;/li&gt;
&lt;li&gt;反馈数据可回流到训练或微调 pipeline（结合外部工具）。&lt;/li&gt;
&lt;li&gt;支持 A/B 测试不同 prompt 或模型版本的效果。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;5. &lt;strong&gt;监控与告警（Monitoring）&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;实时监控生产环境中的 LLM 应用表现。&lt;/li&gt;
&lt;li&gt;跟踪错误率、延迟、token 消耗、成本等关键指标。&lt;/li&gt;
&lt;li&gt;可设置告警规则（如“当幻觉率 &gt; 10% 时通知 Slack”）。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;6. &lt;strong&gt;无缝集成 LangChain / LangGraph&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;只需设置一个 API key，LangChain 应用自动上报 trace 到 LangSmith。&lt;/li&gt;
&lt;li&gt;对 LangGraph 的状态图（StateGraph）支持极佳，能清晰展示每一步的状态变迁。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
os.environ[&quot;LANGCHAIN_TRACING_V2&quot;] = &quot;true&quot;
os.environ[&quot;LANGCHAIN_API_KEY&quot;] = &quot;your-api-key&quot;
# 之后所有 LangChain/LangGraph 运行都会自动记录到 LangSmith
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;二、典型使用场景&lt;/h2&gt;
&lt;p&gt;| 场景 | LangSmith 如何帮助 |
|------|------------------|
| &lt;strong&gt;调试复杂 Agent&lt;/strong&gt; | 查看哪一步工具调用失败，LLM 是否误解了指令 |
| &lt;strong&gt;优化 Prompt&lt;/strong&gt; | 对比不同 prompt 版本在相同数据集上的表现 |
| &lt;strong&gt;上线前测试&lt;/strong&gt; | 用历史对话数据做回归测试，确保不退化 |
| &lt;strong&gt;生产监控&lt;/strong&gt; | 发现某天用户满意度骤降，追溯是模型还是数据问题 |
| &lt;strong&gt;团队协作&lt;/strong&gt; | 共享 trace 和数据集，产品/算法/工程师对齐问题 |&lt;/p&gt;
&lt;h2&gt;三、一句话总结&lt;/h2&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;LangSmith 是 LLM 应用的“全生命周期管理平台”——从开发调试、自动化测试到生产监控，一站式搞定。&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;如果你正在用 LangChain 或 LangGraph 构建智能体、RAG 系统或多轮对话机器人，&lt;strong&gt;强烈建议接入 LangSmith&lt;/strong&gt;，它能极大提升迭代效率和系统可靠性。&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;创建Agent项目本地测试环境&lt;/h1&gt;
&lt;p&gt;LangChain 为我们提供了一个集成了 Studio 和 LangSmith 的测试环境，使得开发者可以在本地对 Agent 进行可视化、交互式调试和优化。其中，Studio 是一个专门面向 Agent 的集成开发环境（IDE），支持对基于智能体服务器 API 协议的代理系统进行可视化操作、实时交互和调试。它还集成了跟踪（追踪执行流程）、评估（自动或人工评价输出质量）以及提示工程（Prompt Engineering）等功能，帮助用户更高效地构建和改进智能体应用。同时，通过与 LangSmith 的深度集成，可以实现完整的应用监控、性能分析和持续优化。&lt;/p&gt;
&lt;h2&gt;1. 安装LangGraph CLI&lt;/h2&gt;
&lt;p&gt;要安装 LangGraph 的命令行工具（CLI），首先需要确保 Python 版本不低于 3.11。然后，通过 pip 命令执行 &lt;code&gt;pip install --upgrade &quot;langgraph-cli[inmem]&quot;&lt;/code&gt; 来安装或升级 LangGraph CLI，并启用其内存存储后端（in-memory backend），以便在本地快速测试和运行基于 LangGraph 构建的有状态工作流应用。&lt;/p&gt;
&lt;h2&gt;2. 配置LangSmith的环境变量&lt;/h2&gt;
&lt;p&gt;要在项目中配置 LangSmith 的 API 密钥，请按以下步骤操作：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;获取 API 密钥&lt;/strong&gt;&lt;br&gt;
登录或注册 LangSmith 平台：&lt;a href=&quot;https://smith.langchain.com&quot;&gt;https://smith.langchain.com&lt;/a&gt;。&lt;br&gt;
进入后，点击右上角头像 → 选择 &lt;strong&gt;“Settings”（设置）&lt;/strong&gt; → 在 &lt;strong&gt;“API Keys”&lt;/strong&gt; 页面中创建或复制您的 API 密钥（格式通常以 &lt;code&gt;lsv2_&lt;/code&gt; 开头）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;创建 &lt;code&gt;.env&lt;/code&gt; 文件&lt;/strong&gt;&lt;br&gt;
在您项目的根目录下新建一个名为 &lt;code&gt;.env&lt;/code&gt; 的文件（注意前面有一个点）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;填写环境变量&lt;/strong&gt;&lt;br&gt;
在 &lt;code&gt;.env&lt;/code&gt; 文件中添加如下内容，将 &lt;code&gt;&amp;#x3C;your-api-key&gt;&lt;/code&gt; 替换为您从 LangSmith 复制的实际密钥：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;LANGSMITH_API_KEY=lsv2_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;启用追踪功能（可选但推荐）&lt;/strong&gt;&lt;br&gt;
确保在代码或环境变量中开启 LangSmith 追踪：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
os.environ[&quot;LANGCHAIN_TRACING_V2&quot;] = &quot;true&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;或在 &lt;code&gt;.env&lt;/code&gt; 文件中一并添加：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;LANGCHAIN_TRACING_V2=true
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;完成以上配置后，只要您的应用使用了 LangChain 或 LangGraph，所有运行过程（如 Chain 调用、Agent 步骤、状态变更等）都会自动上报到 LangSmith 平台，您可以在 &lt;a href=&quot;https://smith.langchain.com&quot;&gt;https://smith.langchain.com&lt;/a&gt; 上实时查看执行轨迹、调试问题并进行评估优化。&lt;/p&gt;
&lt;h2&gt;3. 创建LangGraph配置文件&lt;/h2&gt;
&lt;p&gt;在应用程序的根目录中，创建一个名为 &lt;code&gt;langgraph.json&lt;/code&gt; 的配置文件，用于定义 LangGraph 项目的结构和依赖关系。该文件以 JSON 格式编写，包含以下关键字段：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;&quot;dependencies&quot;&lt;/code&gt;：指定项目依赖的模块或路径，此处设置为 &lt;code&gt;[&quot;.&quot;]&lt;/code&gt;，表示当前目录下的所有内容。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;&quot;graphs&quot;&lt;/code&gt;：定义要运行的图（Graph）及其对应的实现路径。示例中配置了一个名为 &lt;code&gt;&quot;agent&quot;&lt;/code&gt; 的图，其代码位于 &lt;code&gt;./src/agent.py&lt;/code&gt; 文件中的 &lt;code&gt;agent&lt;/code&gt; 对象。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;&quot;env&quot;&lt;/code&gt;：指定环境变量文件的路径，此处指向 &lt;code&gt;.env&lt;/code&gt; 文件，用于加载 API 密钥等配置信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;完整的配置文件内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-json&quot;&gt;{
  &quot;dependencies&quot;: [&quot;.&quot;],
  &quot;graphs&quot;: {
    &quot;agent&quot;: &quot;./src/agent.py:agent&quot;
  },
  &quot;env&quot;: &quot;.env&quot;
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;此配置文件是使用 LangGraph CLI 进行本地开发、调试和运行 Agent 的必要组成部分，它让工具能够正确加载应用逻辑、环境变量并执行状态图流程。&lt;/p&gt;
&lt;p&gt;创建好的项目结构如下：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fagent_blogs%2Fagent_blogs-1%2F1.png%3ForigWidth%3D242%26origHeight%3D262%26origFormat%3Dpng&amp;#x26;w=242&amp;#x26;h=262&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;4. 编写智能体项目代码&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;@tool
def send_email(to: str, subject: str, body: str):
    &quot;&quot;&quot;发送邮件&quot;&quot;&quot;
    email = {
        &quot;to&quot;: to,
        &quot;subject&quot;: subject,
        &quot;body&quot;: body
    }
    # ...邮件发送逻辑

    return f&quot;邮件已发送至 {to}&quot;


# 创建 React Agent
agent_executor = create_agent(
    model=llm,
    tools=[send_email],
    system_prompt=&quot;你是一个邮件助手。当用户需要发送邮件时，请使用 send_email 工具。&quot;
)
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;5. 安装依赖项&lt;/h2&gt;
&lt;p&gt;在新建的 LangGraph 应用项目的根目录中，需要安装项目依赖项。执行以下命令：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;pip install -e .
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;该命令的作用是：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;-e&lt;/code&gt; 表示“editable”（可编辑模式），即以开发模式安装当前目录下的 Python 包。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;.&lt;/code&gt; 表示当前目录，意味着会读取项目根目录中的 &lt;code&gt;setup.py&lt;/code&gt; 或 &lt;code&gt;pyproject.toml&lt;/code&gt; 文件来识别并安装所需的依赖库。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;使用这种方式安装后，你可以在不重新安装的情况下直接修改源代码，改动会立即生效，非常适合本地开发和调试。此步骤通常在创建 LangGraph 项目并准备运行或测试时执行。&lt;/p&gt;
&lt;h2&gt;6. 在Studio中查看代理&lt;/h2&gt;
&lt;p&gt;在完成前面的配置和安装步骤后，您可以在 Studio 中查看并调试您的智能体（Agent）应用。具体操作如下：&lt;/p&gt;
&lt;p&gt;启动本地的 Agent 服务器，只需在终端中运行以下命令：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;langgraph dev
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;该命令会启动一个本地开发服务器，并自动将您的 LangGraph 应用暴露给 LangChain Studio。&lt;/p&gt;
&lt;p&gt;随后，您的智能体程序将通过 Studio 的用户界面进行访问，访问地址为：&lt;/p&gt;
&lt;p&gt;👉 &lt;strong&gt;&lt;code&gt;http://127.0.0.1:2024&lt;/code&gt;&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;通过这个链接，您可以进入 Studio UI，实现对 Agent 的可视化交互、实时调试、状态追踪和流程监控等功能，从而更高效地开发和优化您的 LLM 应用。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Tips：推荐使用隔离环境，或者自己创建虚拟环境都可以。Python 内置的 &lt;code&gt;venv&lt;/code&gt; 模块可以用来创建隔离的虚拟环境。下面是创建隔离环境的教程：&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;✅ 步骤如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 1. 进入你的项目根目录
cd /path/to/your/agent-project

# 2. 创建虚拟环境（建议命名为 .venv 或 venv）
python -m venv .venv

# 3. 激活虚拟环境
source .venv/bin/activate
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;激活后，你会看到命令行前缀出现 &lt;code&gt;(.venv)&lt;/code&gt;，表示已进入虚拟环境。&lt;/p&gt;
&lt;h4&gt;✅ 安装项目依赖&lt;/h4&gt;
&lt;p&gt;&lt;code&gt;pyproject.toml&lt;/code&gt; 是来自官方的，虽然 &lt;code&gt;pyproject.toml&lt;/code&gt; 不会自动创建虚拟环境，但它&lt;strong&gt;定义了项目依赖&lt;/strong&gt;。你可以用以下命令安装：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 安装项目本身及其依赖（以可编辑模式）
pip install -e .
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;code&gt;-e .&lt;/code&gt; 表示“editable install”，即从当前目录安装，并读取 &lt;code&gt;pyproject.toml&lt;/code&gt; 中 &lt;code&gt;[project]&lt;/code&gt; 和 &lt;code&gt;[project.optional-dependencies]&lt;/code&gt; 的配置。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;这样就能确保你的 LangGraph + LangChain 项目在一个干净、可复现的环境中运行。&lt;/p&gt;
&lt;p&gt;最后展示一下&lt;code&gt;LangSmith Studio&lt;/code&gt;运行起来的界面：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fagent_blogs%2Fagent_blogs-1%2F2.png%3ForigWidth%3D1920%26origHeight%3D911%26origFormat%3Dpng&amp;#x26;w=1920&amp;#x26;h=911&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/agent_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>ROS2 手眼相机 6D 位姿调试方案总结</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-11</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-11</guid><description>基于ROS2的机械臂手眼相机位姿调试</description><pubDate>Mon, 09 Feb 2026 12:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;方案概述&lt;/h2&gt;
&lt;p&gt;本方案通过 xacro 参数 + launch 文件 的方式，实现了无需修改 URDF 文件即可实时调整相机 6D 位姿（位置 + 姿态）的便捷调试方法。&lt;/p&gt;
&lt;h2&gt;核心文件结构&lt;/h2&gt;
&lt;pre&gt;&lt;code&gt;ur5e_gripper_moveit_config/
├── urdf/
│   ├── single_ur5e_gripper_handeye.urdf.xacro    # 宏定义（定义相机）
│   └── ur5e_gripper_handeye.urdf.xacro            # 实例文件（声明参数）
└── launch/
    └── view_handeye_robot_adjustable.launch.py    # 可调参数 launch 文件
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;实现步骤&lt;/h2&gt;
&lt;h3&gt;步骤 1: 在宏定义中添加相机位姿参数&lt;/h3&gt;
&lt;p&gt;文件: &lt;code&gt;single_ur5e_gripper_handeye.urdf.xacro&lt;/code&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-xml&quot;&gt;&amp;#x3C;!-- 宏定义添加相机位姿参数 --&gt;
&amp;#x3C;xacro:macro name=&quot;ur5e_gripper_handeye&quot;
             params=&quot;name prefix parent *origin initial_positions_file
                     camera_x_offset camera_y_offset camera_z_offset
                     camera_roll camera_pitch camera_yaw&quot;&gt;

    &amp;#x3C;!-- ... 机械臂和夹爪定义 ... --&gt;

    &amp;#x3C;!-- 手眼相机挂载（使用参数） --&gt;
    &amp;#x3C;xacro:sensor_d435 parent=&quot;${prefix}tool0&quot;
                       name=&quot;${prefix}hand_eye_camera&quot;
                       use_nominal_extrinsics=&quot;true&quot;
                       add_plug=&quot;false&quot;
                       use_mesh=&quot;true&quot;
                       topics_ns=&quot;${prefix}hand_eye_camera&quot;&gt;
        &amp;#x3C;origin xyz=&quot;${camera_x_offset} ${camera_y_offset} ${camera_z_offset}&quot;
                 rpy=&quot;${camera_roll} ${camera_pitch} ${camera_yaw}&quot;/&gt;
    &amp;#x3C;/xacro:sensor_d435&gt;
&amp;#x3C;/xacro:macro&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;步骤 2: 在实例文件中声明 xacro 参数&lt;/h3&gt;
&lt;p&gt;文件: &lt;code&gt;ur5e_gripper_handeye.urdf.xacro&lt;/code&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-xml&quot;&gt;&amp;#x3C;robot xmlns:xacro=&quot;http://www.ros.org/wiki/xacro&quot; name=&quot;ur5e_gripper_handeye&quot;&gt;

    &amp;#x3C;!-- 声明 xacro 参数（设置默认值） --&gt;
    &amp;#x3C;xacro:arg name=&quot;camera_x_offset&quot; default=&quot;0.0&quot;/&gt;
    &amp;#x3C;xacro:arg name=&quot;camera_y_offset&quot; default=&quot;-0.04&quot;/&gt;
    &amp;#x3C;xacro:arg name=&quot;camera_z_offset&quot; default=&quot;-0.03&quot;/&gt;
    &amp;#x3C;xacro:arg name=&quot;camera_roll&quot; default=&quot;0.0&quot;/&gt;
    &amp;#x3C;xacro:arg name=&quot;camera_pitch&quot; default=&quot;-1.5708&quot;/&gt;
    &amp;#x3C;xacro:arg name=&quot;camera_yaw&quot; default=&quot;1.5708&quot;/&gt;

    &amp;#x3C;!-- 调用宏并传递参数 --&gt;
    &amp;#x3C;xacro:ur5e_gripper_handeye name=&quot;ur&quot; prefix=&quot;&quot; parent=&quot;world&quot;
                                initial_positions_file=&quot;$(arg initial_positions_file)&quot;
                                camera_x_offset=&quot;$(arg camera_x_offset)&quot;
                                camera_y_offset=&quot;$(arg camera_y_offset)&quot;
                                camera_z_offset=&quot;$(arg camera_z_offset)&quot;
                                camera_roll=&quot;$(arg camera_roll)&quot;
                                camera_pitch=&quot;$(arg camera_pitch)&quot;
                                camera_yaw=&quot;$(arg camera_yaw)&quot;&gt;
        &amp;#x3C;origin xyz=&quot;0 0 0&quot; rpy=&quot;0 0 0&quot; /&gt;
    &amp;#x3C;/xacro:ur5e_gripper_handeye&gt;
&amp;#x3C;/robot&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;步骤 3: 创建可调参数的 launch 文件&lt;/h3&gt;
&lt;p&gt;文件: &lt;code&gt;view_handeye_robot_adjustable.launch.py&lt;/code&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    # 声明相机位姿参数
    args = []

    # 位置参数
    args.append(DeclareLaunchArgument(
        name=&quot;camera_x_offset&quot;,
        default_value=&quot;0.0&quot;,
        description=&quot;Camera X offset from tool0 (meters)&quot;
    ))
    args.append(DeclareLaunchArgument(
        name=&quot;camera_y_offset&quot;,
        default_value=&quot;-0.04&quot;,
        description=&quot;Camera Y offset from tool0 (meters)&quot;
    ))
    args.append(DeclareLaunchArgument(
        name=&quot;camera_z_offset&quot;,
        default_value=&quot;-0.03&quot;,
        description=&quot;Camera Z offset from tool0 (meters)&quot;
    ))

    # 姿态参数
    args.append(DeclareLaunchArgument(
        name=&quot;camera_roll&quot;,
        default_value=&quot;0.0&quot;,
        description=&quot;Camera roll rotation (radians)&quot;
    ))
    args.append(DeclareLaunchArgument(
        name=&quot;camera_pitch&quot;,
        default_value=&quot;-1.5708&quot;,
        description=&quot;Camera pitch rotation (radians)&quot;
    ))
    args.append(DeclareLaunchArgument(
        name=&quot;camera_yaw&quot;,
        default_value=&quot;1.5708&quot;,
        description=&quot;Camera yaw rotation (radians)&quot;
    ))

    # 动态生成 URDF（传递参数到 xacro）
    robot_description_content = Command([
        &quot;xacro&quot;,
        &quot; &quot;,
        os.path.join(pkg_share, &quot;urdf&quot;, &quot;ur5e_gripper_handeye.urdf.xacro&quot;),
        &quot; camera_x_offset:=&quot;, LaunchConfiguration(&quot;camera_x_offset&quot;),
        &quot; camera_y_offset:=&quot;, LaunchConfiguration(&quot;camera_y_offset&quot;),
        &quot; camera_z_offset:=&quot;, LaunchConfiguration(&quot;camera_z_offset&quot;),
        &quot; camera_roll:=&quot;, LaunchConfiguration(&quot;camera_roll&quot;),
        &quot; camera_pitch:=&quot;, LaunchConfiguration(&quot;camera_pitch&quot;),
        &quot; camera_yaw:=&quot;, LaunchConfiguration(&quot;camera_yaw&quot;),
    ])

    # robot_state_publisher、joint_state_publisher_gui、rviz2 节点...
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;使用方法&lt;/h2&gt;
&lt;h3&gt;基础使用（默认参数）&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;ros2 launch ur5e_gripper_moveit_config view_handeye_robot_adjustable.launch.py
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;调整位置&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 调整 X 方向偏移 5cm
ros2 launch ur5e_gripper_moveit_config view_handeye_robot_adjustable.launch.py \
    camera_x_offset:=0.05

# 同时调整多个方向
ros2 launch ur5e_gripper_moveit_config view_handeye_robot_adjustable.launch.py \
    camera_x_offset:=0.05 \
    camera_y_offset:=-0.04 \
    camera_z_offset:=-0.03
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;调整姿态&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 调整俯仰角（向下倾斜 30°）
ros2 launch ur5e_gripper_moveit_config view_handeye_robot_adjustable.launch.py \
    camera_pitch:=-1.05

# 同时调整姿态
ros2 launch ur5e_gripper_moveit_config view_handeye_robot_adjustable.launch.py \
    camera_roll:=0.0 \
    camera_pitch:=-1.57 \
    camera_yaw:=1.57
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;完整 6D 位姿调整&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;ros2 launch ur5e_gripper_moveit_config view_handeye_robot_adjustable.launch.py \
    camera_x_offset:=0.05 \
    camera_y_offset:=-0.04 \
    camera_z_offset:=-0.03 \
    camera_roll:=0.0 \
    camera_pitch:=-1.5708 \
    camera_yaw:=1.5708
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;参数说明&lt;/h2&gt;
&lt;h3&gt;位置参数（单位：米）&lt;/h3&gt;
&lt;p&gt;| 参数         | 说明                   | 正方向 |
|--------------|------------------------|--------|
| camera_x_offset | X轴偏移（机械臂前侧） | 前方   |
| camera_y_offset | Y轴偏移（机械臂左侧） | 左侧   |
| camera_z_offset | Z轴偏移（垂直向上）   | 上方   |&lt;/p&gt;
&lt;h3&gt;姿态参数（单位：弧度）&lt;/h3&gt;
&lt;p&gt;| 参数        | 说明   | 旋转轴 |
|-------------|--------|--------|
| camera_roll  | 翻滚角 | X轴    |
| camera_pitch | 俯仰角 | Y轴    |
| camera_yaw   | 偏航角 | Z轴    |&lt;/p&gt;
&lt;h3&gt;常用角度对照表&lt;/h3&gt;
&lt;p&gt;| 角度 | 弧度值 | 用途         |
|------|--------|--------------|
| 0°   | 0.0    | 水平         |
| 30°  | 0.52   | 轻微倾斜     |
| 45°  | 0.79   | 中等倾斜     |
| 90°  | 1.57   | 垂直         |
| -90° | -1.57  | 垂直（反向） |
| 180° | 3.14   | 反向         |&lt;/p&gt;
&lt;h2&gt;技术原理&lt;/h2&gt;
&lt;h3&gt;Xacro 参数传递流程&lt;/h3&gt;
&lt;pre&gt;&lt;code&gt;命令行参数 (Launch)
    ↓
launch 文件 (DeclareLaunchArgument)
    ↓
xacro 命令 (camera_x_offset:=0.05)
    ↓
URDF 实例文件 (&amp;#x3C;xacro:arg&gt;)
    ↓
宏定义 (&amp;#x3C;xacro:macro params&gt;)
    ↓
相机位姿 (&amp;#x3C;origin xyz=&quot;...&quot; rpy=&quot;...&quot;/&gt;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;优势对比&lt;/h2&gt;
&lt;p&gt;| 方案   | 传统方法                    | 本方案                      |
|--------|-----------------------------|-----------------------------|
| 调整速度 | 慢（修改文件→编译→启动）   | 快（命令行参数实时生效）   |
| 便利性 | 需要编辑 XML 文件           | 一行命令搞定               |
| 可逆性 | 需要手动记录旧值           | 默认值始终保留             |
| 适用场景 | 固定位姿                   | 快速迭代调试               |&lt;/p&gt;
&lt;h2&gt;调试技巧&lt;/h2&gt;
&lt;h3&gt;1. 使用 TF 树验证位姿&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 启动后查看 TF 树
ros2 run tf2_tools view_frames

# 检查相机相对于 tool0 的变换
ros2 run tf2_ros tf2_echo tool0 hand_eye_camera_color_optical_frame
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2. 在 RViz 中可视化&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;添加 TF 显示&lt;/li&gt;
&lt;li&gt;设置 Fixed Frame 为 &lt;code&gt;tool0&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;观察 &lt;code&gt;hand_eye_camera_color_optical_frame&lt;/code&gt; 位置&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;3. 增量式调整&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 先粗调（大步长）
camera_y_offset:=-0.05

# 再精调（小步长 0.01）
camera_y_offset:=-0.04
camera_y_offset:=-0.045
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;注意事项&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;参数单位&lt;/strong&gt;: 位置用米，姿态用弧度&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;坐标系&lt;/strong&gt;: 相对于 tool0（末端法兰）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;重新编译&lt;/strong&gt;: 修改 URDF 参数后需要 &lt;code&gt;colcon build&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;默认值&lt;/strong&gt;: 在 &lt;code&gt;ur5e_gripper_handeye.urdf.xacro&lt;/code&gt; 中设置&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;扩展应用&lt;/h2&gt;
&lt;p&gt;此方案同样适用于：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;其他传感器位姿调试（激光雷达、IMU 等）&lt;/li&gt;
&lt;li&gt;末端执行器位姿调整&lt;/li&gt;
&lt;li&gt;夹爪位置微调&lt;/li&gt;
&lt;li&gt;任何需要固定安装在机器人上的组件&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;通过 &lt;strong&gt;xacro 参数化 + launch 文件传递&lt;/strong&gt; 的组合，我们实现了一个灵活、高效的相机位姿调试方案。开发者可以在不修改源代码的情况下，通过命令行参数实时调整相机的 6D 位姿，大大提高了调试效率。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>模型训练（五）分布式训练之DP、DDP、DeepSpeed ZeRO技术</title><link>https://astro-pure.js.org/blog/distri_trainning/distri_trainning-5</link><guid isPermaLink="true">https://astro-pure.js.org/blog/distri_trainning/distri_trainning-5</guid><description>分布式训练</description><pubDate>Thu, 05 Feb 2026 18:55:00 GMT</pubDate><content:encoded>&lt;p&gt;本章节参考&lt;a href=&quot;https://www.bilibili.com/video/BV1mm42137X8/?spm_id_from=333.337.search-card.all.click&amp;#x26;vd_source=52455a50a39ab9ee183496a6de048a09&quot;&gt;UP主&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;分布式训练是为了解决大规模深度学习模型在单设备上无法容纳或训练过慢的问题，通过将计算和存储负载分散到多个 GPU 或机器上协同完成训练。&lt;/p&gt;
&lt;p&gt;其中，DP (DataParallel) 是 PyTorch 早期的单进程多线程方案，将一个 batch 的数据分发到多个 GPU 并行前向/反向，但梯度汇总和参数更新由主 GPU（如 cuda:0）完成，存在通信瓶颈且不支持多机，现已基本弃用。
DDP (DistributedDataParallel) 是当前主流的多进程分布式训练方法，每个 GPU 独立运行一个进程，拥有完整的模型副本，通过高效的 All-Reduce 操作同步梯度，支持单机多卡和多机多卡，扩展性好、性能高。
而 DeepSpeed ZeRO（Zero Redundancy Optimizer）则从内存优化角度出发，在数据并行基础上对优化器状态、梯度和模型参数进行分片存储，显著降低显存占用，使百亿甚至千亿参数模型能在普通 GPU 集群上训练。&lt;/p&gt;
&lt;p&gt;三者中，DDP 侧重通信效率，ZeRO 侧重内存节省，实践中常结合使用（如 DDP + ZeRO）以兼顾速度与规模。&lt;/p&gt;
&lt;h1&gt;Data Parallel (DP)&lt;/h1&gt;
&lt;p&gt;DP是从硬盘读取数据，通过一个CPU进程将数据集划分成多份，然后分发给每个GPU独立进行前向和反向传播，计算出各自的梯度，然后各GPU将各自的梯度传到GPU0上。GPU0通过对这些梯度求和取平均更新自己模型的参数，然后广播给其他GPU。
如图：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F1.png%3ForigWidth%3D1426%26origHeight%3D827%26origFormat%3Dpng&amp;#x26;w=1426&amp;#x26;h=827&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
在分布式训练里，可能一半的时间都花在多卡之间的通信上，下面分析一下DP的通信量：&lt;/p&gt;
&lt;p&gt;假设模型参数量为$\psi$，GPU节点数为$N$。
则对于GPU0，传入梯度为：$(N-1)\psi$，传出参数为$(N-1)\psi$
对于其他GPU：传出梯度为$\psi$，传入参数为$\psi$。&lt;/p&gt;
&lt;h2&gt;存在的问题&lt;/h2&gt;
&lt;p&gt;| 问题类别 | 具体描述 |
|----------|---------|
| &lt;strong&gt;单进程多线程架构 → GIL 和通信瓶颈&lt;/strong&gt; | - DP 在一个主进程中启动多个线程，每个线程控制一个 GPU。- 受 Python &lt;strong&gt;GIL&lt;/strong&gt;（全局解释器锁）限制，CPU 密集型操作无法真正并行。- 所有梯度必须先汇总到&lt;strong&gt;主 GPU&lt;/strong&gt;（通常是 &lt;code&gt;cuda:0&lt;/code&gt;），再由主 GPU 广播更新后的参数。- 导致主 GPU 成为&lt;strong&gt;通信与计算瓶颈&lt;/strong&gt;。 |
| &lt;strong&gt;不支持多机训练&lt;/strong&gt; | - DP 仅支持&lt;strong&gt;单机多卡&lt;/strong&gt;，无法跨机器（multi-node）扩展。- 不能用于大规模分布式集群训练，限制了模型规模和数据吞吐能力。 |
| &lt;strong&gt;显存浪费严重&lt;/strong&gt; | - 每个 GPU 都保存完整的模型副本（与 DDP 相同）。- &lt;strong&gt;主 GPU 额外存储所有其他 GPU 的梯度副本&lt;/strong&gt;用于聚合。- 例如：8 卡训练时，主 GPU 显存占用 ≈ 其他卡的 &lt;strong&gt;2 倍&lt;/strong&gt;，极易发生 &lt;strong&gt;OOM&lt;/strong&gt;（Out-Of-Memory）。 |
| &lt;strong&gt;同步效率低&lt;/strong&gt; | - 每个 batch 的前向/反向完成后，必须等待&lt;strong&gt;所有 GPU 完成计算&lt;/strong&gt;，才能在主 GPU 上聚合梯度。- 采用&lt;strong&gt;同步阻塞式通信&lt;/strong&gt;，任一“慢卡”（straggler）会拖慢整体训练速度。- 虽使用 NCCL 或 CUDA P2P 通信，但数据路径&lt;strong&gt;必须经过主 GPU&lt;/strong&gt;，导致带宽利用率低下。 |&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;Distributed Data Parallel（DDP）&lt;/h1&gt;
&lt;p&gt;介绍DDP之前，先了解一下&lt;strong&gt;Ring-AllReduce&lt;/strong&gt;：Ring-AllReduce 是一种高效的分布式梯度同步算法，广泛应用于大规模深度学习训练中（如 PyTorch DDP、Horovod 等）。其核心思想是将参与训练的多个 GPU（或节点）组织成一个逻辑环形拓扑（ring），每个设备只与左右两个邻居通信，通过多轮分块传输逐步完成全局梯度的聚合。&lt;/p&gt;
&lt;p&gt;整个过程分为两个阶段：Scatter-Reduce（各设备将本地梯度分块，沿环传递并累加，最终每块由一个设备持有完整和）和 All-Gather（将累加后的梯度块沿环广播，使所有设备获得完整的全局梯度）。&lt;/p&gt;
&lt;p&gt;相比传统的中心化聚合（如 Parameter Server 或 DP 的主 GPU 汇总），Ring-AllReduce 消除了单点瓶颈，通信量与设备数无关（带宽最优），且能充分利用多设备间的并行通信能力。尤其在高速互联网络（如 NVLink、InfiniBand）下，它能实现接近线性的扩展效率，是现代大模型分布式训练的通信基石。&lt;/p&gt;
&lt;p&gt;&lt;/p&gt;
&lt;p&gt;首先将GPU0的参数$a_0$传给GPU1，GPU1的参数$b_1$传给GPU2，GPU2的参数$c_2$传给GPU0，形成一个闭环，然后GPU与各自传入的参数相加，得到下面的结果。
&lt;/p&gt;
&lt;p&gt;然后将GPU0的参数$c_0+c_2$传给GPU1，GPU1的参数$a_0+a_1$传给GPU2，GPU2的参数$b_1+b_2$传给GPU0，得到下面的结果。&lt;/p&gt;
&lt;p&gt;&lt;/p&gt;
&lt;p&gt;以此类推，最终得到：
&lt;/p&gt;
&lt;p&gt;从而实现了所有GPU上都有聚合后的模型参数，在这个过程中，每个GPU同时发送和接收，可以最大限度利用每个显卡的上下行带宽。&lt;/p&gt;
&lt;p&gt;与DP不同，DDP是多进程的，每个进程为自己的GPU准备数据并和其他GPU通信，每个GPU用自己的模型进行前向和反向传播，因为每个GPU的数据都不同，所以梯度也不同，最后通过Ring-AllReduce实现梯度的同步与聚合，每个GPU通过更新梯度来更新模型：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F6.png%3ForigWidth%3D1495%26origHeight%3D973%26origFormat%3Dpng&amp;#x26;w=1495&amp;#x26;h=973&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
可以更深入理解DDP，在反向传播过程中，越后面的层数，梯度最先得到，但是每计算一个梯度就进行Ring-AllReduce会大幅增加开销，所以会使用桶，计算的梯度会保存在桶中。具体来说，会对每个层注册监听器（为了让DDP框架知道哪些梯度计算好了），当所有GPU的同一个桶都装满了，会执行一次Ring-AllReduce同步，在这个过程中GPU还在计算其他的梯度，当所有的桶都梯度同步了，每个GPU会调用自己的优化器进行更新：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F7.png%3ForigWidth%3D1766%26origHeight%3D644%26origFormat%3Dpng&amp;#x26;w=1766&amp;#x26;h=644&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
下面分析一下DDP通信量：
假设模型参数量为$\psi$，GPU节点数为$N$。
对于每一个GPU进程：
Scatter-Reduce阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
AllGather阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
总传入/传出：$2\psi$
与集群大小无关&lt;/p&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;| 问题类别 | 具体描述 |
|----------|---------|
| &lt;strong&gt;优点&lt;/strong&gt; | - 基于多进程架构，避免 Python GIL 限制，训练更稳定高效。- 使用高效的 All-Reduce（如 NCCL）进行梯度同步，通信带宽利用率高。- 支持单机多卡和多机多卡训练，扩展性强。- 各 GPU 显存占用均衡，无主 GPU 瓶颈。- 与 &lt;code&gt;DistributedSampler&lt;/code&gt;、&lt;code&gt;torchrun&lt;/code&gt; 等标准工具链无缝集成，是 PyTorch 官方推荐方案。 |
| &lt;strong&gt;缺点&lt;/strong&gt; | - 每个进程需加载完整模型副本，显存开销大，难以直接训练超大规模模型（如百亿参数以上）。- 需要手动处理 checkpoint 加载/保存、数据分片、随机种子同步等细节，代码复杂度高于 DP。- 对网络拓扑和通信库（如 NCCL）依赖较强，在低带宽或多机环境下可能成为瓶颈。- 不具备内置的显存优化机制（如参数分片），需结合 ZeRO 或 FSDP 才能进一步节省内存。 |&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;DeepSpeed ZeRO&lt;/h1&gt;
&lt;p&gt;DeepSpeed ZeRO（Zero Redundancy Optimizer）是微软提出的一种显存优化技术，旨在消除传统数据并行训练中的内存冗余。它通过将优化器状态、梯度和模型参数在多个 GPU 之间进行分片存储，显著降低每个设备的显存占用。ZeRO 分为三个阶段：ZeRO-1 分片优化器状态，ZeRO-2 增加梯度分片，ZeRO-3 进一步分片模型参数，使得百亿甚至千亿参数的大模型能在普通 GPU 集群上高效训练。该技术可与分布式训练框架（如 PyTorch DDP）结合使用，在几乎不损失训练速度的前提下，大幅提升模型规模的可扩展性。&lt;/p&gt;
&lt;p&gt;正常情况下每个GPU上要有Data、模型参数和梯度以及优化器的参数和梯度，下面使用的优化器是Adam，所以要存储一阶动量和二阶动量：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F8.png%3ForigWidth%3D1883%26origHeight%3D1042%26origFormat%3Dpng&amp;#x26;w=1883&amp;#x26;h=1042&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;ZeRO-1&lt;/h2&gt;
&lt;p&gt;由上图可知，占用显存最多的就是优化器状态，因为它是FP32精度，并且每个GPU都需要存储，所以ZeRO-1对优化器状态进行了分片。
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F9.png%3ForigWidth%3D1876%26origHeight%3D1037%26origFormat%3Dpng&amp;#x26;w=1876&amp;#x26;h=1037&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
这里假设有一个9层的网络结构，每个分片优化器负责管理3层网络结构。当反向传播完成了倒数前3层的梯度计算时，此时GPU0和GPU1并没有计算这些层的优化器，所以需要传给GPU2的优化器进行梯度的求和取平均，当所有的GPU都拿到对应层的梯度时，将梯度值转换为FP32梯度（防止下溢），然后更新对应层的参数。最后，每个GPU将自己更新的那部分层的参数广播给其他GPU，这样就完成了一个epoch的训练。&lt;/p&gt;
&lt;h3&gt;通讯量分析&lt;/h3&gt;
&lt;p&gt;假设模型参数量为$\psi$，GPU节点数为$N$。
对于每一个GPU进程：
梯度收集阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
参数广播阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
总传入/传出：$2\psi$
与DDP通讯量相同。&lt;/p&gt;
&lt;h2&gt;ZeRO-2&lt;/h2&gt;
&lt;p&gt;ZeRO-2的思路更ZeRO-1一样，只是将梯度进行了分片。
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F10.png%3ForigWidth%3D1878%26origHeight%3D1048%26origFormat%3Dpng&amp;#x26;w=1878&amp;#x26;h=1048&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
在ZeRO-1的基础上，每个GPU只维护了自己负责层的梯度值，其他操作都是一致的，进一步节省了显存。&lt;/p&gt;
&lt;h3&gt;通讯量分析&lt;/h3&gt;
&lt;p&gt;假设模型参数量为$\psi$，GPU节点数为$N$。
对于每一个GPU进程：
梯度收集阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
参数广播阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
总传入/传出：$2\psi$
与DDP通讯量相同。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;ZeRO-3&lt;/h2&gt;
&lt;p&gt;ZeRO-3 进一步分片模型参数。在前向传播过程中，由于每个GPU维护不同层的参数，前向传播开始时，GPU1和GPU2没有前3层的参数，所以需要通过GPU0广播分发参数给他们进行计算，GPU1和GPU2计算完之后就把这些参数丢弃，从而节省了显存。
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F11.png%3ForigWidth%3D1879%26origHeight%3D1002%26origFormat%3Dpng&amp;#x26;w=1879&amp;#x26;h=1002&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
前向传播结束后，进行反向传播。同样的，反向传播开始时，GPU0和GPU1没有最后3层的参数，所以需要通过GPU2广播分发参数给他们进行计算，计算完之后就丢弃，节省缓存。
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F12.png%3ForigWidth%3D1871%26origHeight%3D1043%26origFormat%3Dpng&amp;#x26;w=1871&amp;#x26;h=1043&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h3&gt;通讯量分析&lt;/h3&gt;
&lt;p&gt;假设模型参数量为$\psi$，GPU节点数为$N$。
对于每一个GPU进程：
梯度收集阶段传入/传出：$(N - 1) \frac{\psi}{N} \approx \psi$
参数广播阶段传入/传出：$2*(N - 1) \frac{\psi}{N} \approx 2\psi$
总传入/传出：$3\psi$
是DDP传输量的1.5倍。&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;对比&lt;/h1&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-5%2F13.png%3ForigWidth%3D1537%26origHeight%3D707%26origFormat%3Dpng&amp;#x26;w=1537&amp;#x26;h=707&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;| 特性 / 阶段 | &lt;strong&gt;ZeRO-1&lt;/strong&gt; | &lt;strong&gt;ZeRO-2&lt;/strong&gt; | &lt;strong&gt;ZeRO-3&lt;/strong&gt; |
|-------------|-----------|-----------|-----------|
| &lt;strong&gt;分片对象&lt;/strong&gt; | 优化器状态（如 Adam 的动量、方差） | 优化器状态 + &lt;strong&gt;梯度&lt;/strong&gt; | 优化器状态 + 梯度 + &lt;strong&gt;模型参数&lt;/strong&gt; |
| &lt;strong&gt;显存节省效果&lt;/strong&gt; | 显存占用 ≈ 原始 DP 的 &lt;strong&gt;1/N&lt;/strong&gt;（仅优化器部分） | 进一步降低，梯度不再全量存储 | &lt;strong&gt;最大节省&lt;/strong&gt;：每个 GPU 仅存 1/N 的参数、梯度、优化器状态 |
| &lt;strong&gt;适用场景&lt;/strong&gt; | 中等规模模型训练（如 1B~10B） | 大模型训练（如 10B~30B） | 超大模型训练（如 70B+） |
| &lt;strong&gt;通信开销&lt;/strong&gt; | 总通信量 ≈ &lt;strong&gt;2ψ&lt;/strong&gt;（与 DDP 相同） | 总通信量 ≈ &lt;strong&gt;2ψ&lt;/strong&gt;（与 DDP 相同） | 总通信量 ≈ &lt;strong&gt;3ψ&lt;/strong&gt;（比 DDP 高 &lt;strong&gt;50%&lt;/strong&gt;） |
| &lt;strong&gt;计算流程复杂度&lt;/strong&gt; | 低：仅在优化器更新时通信 | 中：反向后需聚合梯度 | &lt;strong&gt;高&lt;/strong&gt;：前向/反向均需动态收集/释放参数 |
| &lt;strong&gt;是否支持超大模型&lt;/strong&gt; | ❌ 单卡仍需存完整模型 | ❌ 单卡仍需存完整模型 | ✅ 单卡只需存模型的一部分 |
| &lt;strong&gt;主要优点&lt;/strong&gt; | - 简单易用- 显著减少优化器显存（占大头）- 通信开销不变 | - 在 ZeRO-1 基础上进一步省显存- 适合更大 batch size | - 实现极致显存压缩- 可在普通 GPU 上训练千亿模型- 支持模型并行替代方案 |
| &lt;strong&gt;主要缺点&lt;/strong&gt; | - 模型参数仍全量复制，无法突破单卡模型容量限制 | - 同上，参数仍需全量存储 | - &lt;strong&gt;通信开销显著增加&lt;/strong&gt;- 参数频繁 gather/scatter 引入延迟- 实现复杂，对调度要求高 |&lt;/p&gt;
&lt;hr&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>模型训练（四）梯度累计Gradient Accumulation</title><link>https://astro-pure.js.org/blog/distri_trainning/distri_trainning-4</link><guid isPermaLink="true">https://astro-pure.js.org/blog/distri_trainning/distri_trainning-4</guid><description>Gradient Accumulation</description><pubDate>Thu, 05 Feb 2026 18:53:00 GMT</pubDate><content:encoded>&lt;p&gt;梯度累计（Gradient Accumulation）是一种在显存受限情况下模拟大批次（&lt;code&gt;large batch size&lt;/code&gt;）训练的技术。它的核心思想是：用多次小 &lt;code&gt;batch&lt;/code&gt; 的前向/反向计算，累积梯度，对梯度除以 &lt;code&gt;batch size&lt;/code&gt;得到平均梯度，然后一次性更新模型参数，从而在不增加显存占用的前提下，获得大 &lt;code&gt;batch&lt;/code&gt; 训练的优化效果。&lt;/p&gt;
&lt;p&gt;通常，深度学习训练中：前向传播 + 反向传播 会为一个 batch 计算梯度；优化器立即用该梯度更新模型参数。但当 batch size 很大时，中间激活和梯度会占用大量显存，可能超出 GPU 显存。当batch size很小，比如为1时不会超过GPU显存，它会每训练一个样本，然后计算一次梯度，由于样本之间的差异很大，导致更新的梯度忽大忽小不可控，导致训练过程不稳定、收敛缓慢，甚至无法收敛。&lt;/p&gt;
&lt;p&gt;例如，下面是一个前向传播的计算图：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-4%2F1.png%3ForigWidth%3D601%26origHeight%3D281%26origFormat%3Dpng&amp;#x26;w=601&amp;#x26;h=281&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
其中，&lt;code&gt;a,b,c&lt;/code&gt;都是参数，他们会在反向传播的过程中会进行更新。更新的过程如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-4%2F2.png%3ForigWidth%3D643%26origHeight%3D412%26origFormat%3Dpng&amp;#x26;w=643&amp;#x26;h=412&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
通过链式法则，可以得到损失函数对参数&lt;code&gt;b&lt;/code&gt;的梯度，然后&lt;code&gt;b-lr*该梯度&lt;/code&gt;作为新的&lt;code&gt;b&lt;/code&gt;，再继续进行前向传播、反向传播更新，&lt;code&gt;a&lt;/code&gt;和&lt;code&gt;c&lt;/code&gt;同理。在这个过程中，模型的参数例如&lt;code&gt;a,b,c&lt;/code&gt;等，称为&lt;code&gt;叶子节点张量&lt;/code&gt;，其梯度会被累积到&lt;code&gt;.grad&lt;/code&gt;属性中，并长期驻留现存，直到下一次&lt;code&gt;optimizer.zero_grad()&lt;/code&gt;清空，对于&lt;code&gt;中间非叶子张量&lt;/code&gt;，例如&lt;code&gt;v&lt;/code&gt;，其值等于&lt;code&gt;v=b*c&lt;/code&gt;，其梯度默认不会保留，计算完后立即释放。&lt;/p&gt;
&lt;p&gt;首先准备训练集和标签：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;定义一个模型：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class MLP(torch.nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super(MLP,self).__init__()
        self.fc1=torch.nn.Linear(input_size, hidden_size)
        self.fc2=torch.nn.Linear(hidden_size, output_size)

    def forward(self,x):
        out=self.fc1(x)
        out=torch.relu(out)
        out=self.fc2(out)
        return out
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来实现梯度累计：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model=MLP(input_size=64,hidden_size=256,output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

iter=0
accum_steps=0
while True:
    out=model(x)
    loss=loss_fn(out,y)
    loss=loss/4    
    loss.backward()
    accum_steps+=1
    if accum_steps==4:
        optimizer.step()
        optimizer.zero_grad()
        accum_steps=0
        iter+=1
        if iter%25000==0:
            print(f&apos;iter={iter} loss={loss.item()} cuda_mem={torch.cuda.memory_allocated()}Bytes&apos;)
        if loss.item()&amp;#x3C;=1e-3:
            break
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;accum_steps&lt;/code&gt;用于梯度累计计数，其值为&lt;code&gt;4&lt;/code&gt;时，更新一次参数。需要注意&lt;code&gt;loss=loss/4&lt;/code&gt;，这是因为前面累积了4个&lt;code&gt;batch&lt;/code&gt;的梯度，在更新的时候，应该要取&lt;code&gt;batch size&lt;/code&gt;的均值进行更新。使用&lt;code&gt;loss=loss/4&lt;/code&gt;并不影响计算图，可以理解为在loss后面增加&lt;code&gt;new_loss=loss*(1/4)&lt;/code&gt;，当求梯度&lt;code&gt;d_new_loss/d_loss&lt;/code&gt;时，值就是&lt;code&gt;1/4&lt;/code&gt;，从而实现对&lt;code&gt;loss&lt;/code&gt;乘以&lt;code&gt;1/4&lt;/code&gt;。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>模型训练（三）激活值检查点Activation Checkpoint</title><link>https://astro-pure.js.org/blog/distri_trainning/distri_trainning-3</link><guid isPermaLink="true">https://astro-pure.js.org/blog/distri_trainning/distri_trainning-3</guid><description>Activation Checkpoint</description><pubDate>Thu, 05 Feb 2026 18:52:00 GMT</pubDate><content:encoded>&lt;p&gt;激活值检查点（Activation Checkpointing）是一种用于减少深度神经网络训练过程中显存占用的技术。在标准的反向传播中，所有中间层的激活值（即前向传播的输出）都需要保存在显存中，以便计算梯度时使用。对于深层网络或大批次训练，这些激活值会消耗大量显存。&lt;/p&gt;
&lt;p&gt;激活值检查点的核心思想是：不在前向传播时保存所有中间激活值，而只保存部分关键层的输出（称为“检查点”）；在反向传播需要某段中间激活时，临时从最近的检查点重新执行前向计算来恢复，用时间换空间。&lt;/p&gt;
&lt;p&gt;这种方法显著降低了显存需求（通常可节省30%~70%），代价是增加了少量计算开销（因为部分前向过程需重复执行）。它特别适用于训练非常深的模型（如Transformer、ResNet-152等）或在有限显存设备上进行大模型训练。PyTorch通过 torch.utils.checkpoint.checkpoint 提供了对该技术的原生支持。&lt;/p&gt;
&lt;p&gt;首先加载训练数据和标签：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后定义模型：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class MLP(torch.nn.Module):
    def __init__(self,input_size,hidden_sizes,output_size):
        super(MLP,self).__init__()
        self.fc_first=torch.nn.Linear(input_size, hidden_sizes[0])
        fc_middle=[]
        for i in range(1,len(hidden_sizes)-1):
            fc_middle.append(torch.nn.Linear(hidden_sizes[i-1],hidden_sizes[i]))
            fc_middle.append(torch.nn.ReLU())
        self.fc_middle=torch.nn.Sequential(*fc_middle)
        self.fc_final=torch.nn.Linear(hidden_sizes[-1], output_size)
    
    def forward(self,x,checkpoint=False):
        out=self.fc_first(x)
        out=torch.relu(out)
        if checkpoint:
            out=torch.utils.checkpoint.checkpoint(lambda x:self.fc_middle(x),out,use_reentrant=False)
        else:
            out=self.fc_middle(out)
        out=self.fc_final(out)
        return out
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;注意，在&lt;code&gt;forward&lt;/code&gt;方法中：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if checkpoint:
    out=torch.utils.checkpoint.checkpoint(lambda x:self.fc_middle(x),out,use_reentrant=False)
else:
    out=self.fc_middle(out)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码实现了选择性地对模型中间层（&lt;code&gt;fc_middle&lt;/code&gt;）启用激活值检查点（Activation Checkpointing），以在训练时节省显存。&lt;/p&gt;
&lt;p&gt;具体来说：&lt;/p&gt;
&lt;p&gt;当 &lt;code&gt;checkpoint=False&lt;/code&gt;（默认情况）时，直接执行 &lt;code&gt;out = self.fc_middle(out)&lt;/code&gt;，即按常规方式完成前向传播，所有中间激活值都会被保存在显存中，供后续反向传播使用。
当 &lt;code&gt;checkpoint=True&lt;/code&gt; 时，不直接计算 &lt;code&gt;fc_middle&lt;/code&gt; 的输出并保留其所有中间结果，而是通过 &lt;code&gt;torch.utils.checkpoint.checkpoint&lt;/code&gt; 包装该计算过程。PyTorch 会不在前向传播中保存 &lt;code&gt;fc_middle&lt;/code&gt; 内部各层的激活值，而只保留输入 &lt;code&gt;out&lt;/code&gt;；在反向传播需要这些中间激活时，PyTorch 会临时重新运行 &lt;code&gt;fc_middle&lt;/code&gt; 的前向计算（从保存的输入开始）来重建所需激活值。&lt;/p&gt;
&lt;p&gt;接下来开始训练，指定&lt;code&gt;checkpoint=True&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model=MLP(input_size=64,hidden_sizes=[256,512,512,128],output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

iter=0
while True:
    optimizer.zero_grad()
    
    torch.cuda.reset_peak_memory_stats()  # 👈 重置峰值统计
    
    out = model(x, checkpoint=True)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    peak_mem = torch.cuda.max_memory_allocated()  # 👈 获取本次迭代峰值
    iter += 1
    
    if iter % 10000 == 0:
        print(f&apos;iter={iter} loss={loss.item()} peak_cuda_mem={peak_mem} Bytes&apos;)
    
    if loss.item() &amp;#x3C;= 1e-3:
        break
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出结果（消耗内存约43MB）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;iter=10000 loss=0.020305516198277473 peak_cuda_mem=45377536 Bytes
iter=20000 loss=0.003331078216433525 peak_cuda_mem=45377536 Bytes
iter=30000 loss=0.0013336377451196313 peak_cuda_mem=45377536 Bytes
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;指定&lt;code&gt;checkpoint=False&lt;/code&gt;时：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model=MLP(input_size=64,hidden_sizes=[256,512,512,128],output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
iter=0
while True:
    optimizer.zero_grad()
    
    torch.cuda.reset_peak_memory_stats()  # 👈 重置峰值统计
    
    out = model(x, checkpoint=False)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    peak_mem = torch.cuda.max_memory_allocated()  # 👈 获取本次迭代峰值
    iter += 1
    
    if iter % 10000 == 0:
        print(f&apos;iter={iter} loss={loss.item()} peak_cuda_mem={peak_mem} Bytes&apos;)
    
    if loss.item() &amp;#x3C;= 1e-3:
        break
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出结果（消耗内存约48MB）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;iter=10000 loss=0.019843708723783493 peak_cuda_mem=50066432 Bytes
iter=20000 loss=0.0028206356801092625 peak_cuda_mem=50066432 Bytes
iter=30000 loss=0.0011666719801723957 peak_cuda_mem=50066432 Bytes
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>模型训练（二）AMP自动混合精度训练</title><link>https://astro-pure.js.org/blog/distri_trainning/distri_trainning-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/distri_trainning/distri_trainning-1</guid><description>Automatic Mixed Precision</description><pubDate>Thu, 05 Feb 2026 18:50:00 GMT</pubDate><content:encoded>&lt;p&gt;DDP（分布式数据并行）​ 是PyTorch的分布式训练框架，用于在多个GPU/机器上并行训练模型，大幅提升训练速度&lt;/p&gt;
&lt;p&gt;它的核心思想是：将大批次数据拆分到多个GPU上，每个GPU计算部分梯度，然后聚合梯度更新模型，确保每个模型的权重都是一样的，然后再进行下一个epoch的训练。&lt;/p&gt;
&lt;p&gt;流程如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;主进程：        GPU0 (rank=0)        GPU1 (rank=1)        GPU2 (rank=2)
    │              │                    │                    │
    ├─ 初始化 ──→ 相同模型副本 ←───── 相同模型副本 ←───── 相同模型副本
    │              │                    │                    │
    ├─ 数据分发 ─→ 数据分片1          数据分片2          数据分片3
    │              │                    │                    │
    ├─ 并行计算 ─→ 前向+反向           前向+反向           前向+反向
    │              │                    │                    │
    ├─ 梯度同步 ──→ 梯度平均 ←───────── 梯度平均 ←───────── 梯度平均
    │              │                    │                    │
    └─ 参数更新 ─→ 更新参数 ←────────── 更新参数 ←────────── 更新参数
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;定义一个简单的神经网络：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class Net(nn.Module):   # 模型定义
    def __init__(self):
        super(Net,self).__init__() 
        self.flatten=nn.Flatten()
        self.seq=nn.Sequential(
            nn.Linear(28*28,128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64,10)
        )

    def forward(self,x):
        x=self.flatten(x)
        return self.seq(x)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来需要建立多GPU/多机器间的通信网络，让所有训练进程能够互相识别和通信，这是是DDP分布式训练的核心初始化部分。&lt;/p&gt;
&lt;p&gt;首先初始化进程组：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;dist.init_process_group(backend=&apos;nccl&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;启动分布式环境，建立进程间通信。backend=&apos;nccl&apos;：使用NVIDIA的NCCL通信库（GPU间高速通信），其他可选backend：&apos;gloo&apos;（CPU）、&apos;mpi&apos;（高性能计算）。&lt;/p&gt;
&lt;p&gt;执行后的效果就是：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;启动前：4个独立的Python进程（互不相识）
启动后：4个进程组成通信组，可以互相发送数据
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来获取进程排名：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;rank = dist.get_rank()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;获取当前进程在组内的唯一标识符。rank=0：主进程（master），通常负责保存模型、日志等&lt;/p&gt;
&lt;p&gt;然后获取进程总数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;world_size = dist.get_world_size()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个值是在启动python脚本时设置的，例如&lt;code&gt;torchrun --nproc_per_node=4 train.py&lt;/code&gt;，则&lt;code&gt;world_size=4&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;加载&lt;code&gt;checkpoint&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;checkpoint=None # 各自加载checkpoint
    try:
        checkpoint=torch.load(&apos;checkpoint.pth&apos;,map_location=&apos;cpu&apos;)   # checkpoint是cuda:0保存的，加载默认会读到cuda:0，所以明确指定给cpu
    except:
        pass
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;安全地加载之前保存的训练状态，如果文件不存在或损坏，则继续训练。&lt;/p&gt;
&lt;p&gt;加载模型及其权重参数到主进程&lt;code&gt;rank0&lt;/code&gt;上：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model=Net().to(device_name)
    if checkpoint and rank==0:  # rank0恢复模型参数
        model.load_state_dict(checkpoint[&apos;model&apos;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后rank0广播给其他的进程：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model=DDP(model) # 【集合通讯】rank0广播参数给其他进程
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;优化器的初始化和状态恢复：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;optimizer=torch.optim.Adam(model.parameters(),lr=0.001) #model参数一致，则optim会保证其初始状态一致
if checkpoint:
   optimizer.load_state_dict(checkpoint[&apos;optimizer&apos;])  # 各自加载checkpoint
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;保证所有GPU上的优化器具有相同的初始状态。&lt;/p&gt;
&lt;p&gt;加载训练集并分片发给其他进程：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;train_dataset=MNIST(root=&apos;./data&apos;,download=True,transform=ToTensor(),train=True) # 各自加载dataset
sampler=DistributedSampler(train_dataset) # 指派子集给各进程
train_dataloader=DataLoader(train_dataset,batch_size=32,sampler=sampler,persistent_workers=True,num_workers=2)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;验证集只需要在主进程上进行一轮验证即可，不用分发：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;val_dataset=MNIST(root=&apos;./data&apos;,download=True,transform=ToTensor(),train=False)
    val_dataloader=DataLoader(val_dataset,batch_size=32,shuffle=True,persistent_workers=True,num_workers=2)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来进行训练的循环：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;for epoch in range(20):
    sampler.set_epoch(epoch)    # 【集合通讯】生成随机种子，rank0广播给其他进程
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;sampler.set_epoch(epoch)&lt;/code&gt;能确保各个进程获取到的数据片不一样，如果没有这个代码，会导致模型在每个epoch看到完全相同的数据顺序​ → 过拟合风险：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 结果：
Epoch 0: GPU0处理[0,4,8,...], GPU1处理[1,5,9,...], GPU2处理[2,6,10,...], GPU3处理[3,7,11,...]
Epoch 1: GPU0处理[0,4,8,...], GPU1处理[1,5,9,...], GPU2处理[2,6,10,...], GPU3处理[3,7,11,...]  # 相同！
Epoch 2: GPU0处理[0,4,8,...], GPU1处理[1,5,9,...], GPU2处理[2,6,10,...], GPU3处理[3,7,11,...]  # 相同！
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来就是模型训练部分：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model.train()
for x,y in train_dataloader:
    x,y=x.to(device_name),y.to(device_name)
    pred_y=model(x) # 【集合通讯】rank0广播model buffer给其他进程
    loss=F.cross_entropy(pred_y,y)
    optimizer.zero_grad()
    loss.backward() # 【集合通讯】每个参数的梯度做all reduce（每个进程会收到其他进程的梯度，并求平均）
    optimizer.step()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中前向传播时&lt;code&gt;pred_y=model(x)&lt;/code&gt;，DDP会自动同步模型参数。反向传播时&lt;code&gt;loss.backward()&lt;/code&gt;，DDP也会自动同步梯度。&lt;/p&gt;
&lt;p&gt;然后执行：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;ist.reduce(loss,dst=0) # 【集合通讯】rank0汇总其他进程的loss
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将所有GPU的损失值汇总到主进程（rank0），用于计算平均损失、记录日志等。举个例子：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;进程0 (rank=0): loss=0.5
进程1 (rank=1): loss=0.3
进程2 (rank=2): loss=0.7  
进程3 (rank=3): loss=0.2

执行 dist.reduce(loss, dst=0) 后：

进程0: loss = 0.5 + 0.3 + 0.7 + 0.2 = 1.7  (汇总结果)
进程1: loss=0.3 (保持不变)
进程2: loss=0.7 (保持不变)
进程3: loss=0.2 (保持不变)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来在主线程&lt;code&gt;rank0&lt;/code&gt;上进行验证和保存模型：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if rank==0:
   train_avg_loss=loss.item()/world_size
   
   # evaluate
   raw_model=model.module
   val_loss=0
   with torch.no_grad():
       for x,y in val_dataloader:
           x,y=x.to(device_name),y.to(device_name)
           pred_y=raw_model(x)
           loss=F.cross_entropy(pred_y,y)
           val_loss+=loss.item()
   val_avg_loss=val_loss/len(val_dataloader)
   print(f&apos;train_loss:{train_avg_loss} val_loss:{val_avg_loss}&apos;)
   
   # checkpoint
   torch.save({&apos;model&apos;:model.module.state_dict(),&apos;optimizer&apos;:optimizer.state_dict()},&apos;.checkpoint.pth&apos;)
   os.replace(&apos;.checkpoint.pth&apos;,&apos;checkpoint.pth&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;特别注意需要执行：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;dist.barrier() # 【集合通讯】等待rank0跑完eval
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它是分布式训练中的同步原语，用于确保所有进程在继续执行前达到同一个执行点。作用是创建一个同步屏障，所有进程必须在此处等待，直到所有进程都到达这个点才能继续执行。如果没有这行代码，其他进程会在&lt;code&gt;rank0&lt;/code&gt;还在执行验证的时候，执行下一个&lt;code&gt;epoch&lt;/code&gt;，导致权重参数不对齐，训练失败！&lt;/p&gt;
&lt;p&gt;最后执行命令训练：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;torchrun --nproc-per-node 8 singlenode_cpu.py
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是在&lt;code&gt;cpu&lt;/code&gt;上开了8个进程进行训练。打印输出如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-1%2F1.png%3ForigWidth%3D468%26origHeight%3D244%26origFormat%3Dpng&amp;#x26;w=468&amp;#x26;h=244&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>模型训练（一）分布式训练之DDP</title><link>https://astro-pure.js.org/blog/distri_trainning/distri_trainning-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/distri_trainning/distri_trainning-2</guid><description>Distribute Data Parallel</description><pubDate>Thu, 05 Feb 2026 18:49:00 GMT</pubDate><content:encoded>&lt;p&gt;AMP（Automatic Mixed Precision，自动混合精度训练）是一种在深度学习训练中加速计算并节省显存的技术，同时几乎不损失模型精度。&lt;/p&gt;
&lt;p&gt;AMP（自动混合精度训练）的工作流程如下：
在训练过程中，AMP 通过 autocast 自动将模型的前向计算切换到半精度（FP16），以加速运算并减少显存占用；同时，为防止 FP16 表示范围有限导致梯度下溢（变为零），它使用 GradScaler 对损失值进行放大（如乘以1024），使反向传播产生的梯度落在 FP16 的有效范围内；随后，这些缩放后的梯度被转换回单精度（FP32），并与优化器中维护的 FP32 主权重结合，在去除缩放因子后完成参数更新。整个过程由框架自动管理哪些操作使用 FP16、哪些必须保留 FP32（如 BatchNorm 或 softmax），从而在几乎不损失模型精度的前提下，显著提升训练速度并降低显存消耗。&lt;/p&gt;
&lt;p&gt;流程图如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fdistri_trainning%2Fdistri_trainning-2%2F1.png%3ForigWidth%3D736%26origHeight%3D310%26origFormat%3Dpng&amp;#x26;w=736&amp;#x26;h=310&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;首先加载手写数字数据集（Digits Dataset）并将其转换为 PyTorch 张量，同时放到 GPU 上进行后续深度学习训练或推理。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)
# torch.Size([1797, 64]) torch.float32
# torch.Size([1797]) torch.int64
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中&lt;code&gt;x/16&lt;/code&gt;是进行归一化操作，原始像素值范围是 0~16（因为 8×8 图像来自 sklearn 的预处理版本，最大值为 16）&lt;/p&gt;
&lt;p&gt;定义一个网络：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class MLP(torch.nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super(MLP,self).__init__()
        self.fc1=torch.nn.Linear(input_size, hidden_size)
        self.fc2=torch.nn.Linear(hidden_size, output_size)

    def forward(self,x):
        out=self.fc1(x)
        out=torch.relu(out)
        out=self.fc2(out)
        return out
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后进行模型定义、损失函数、优化器和自动混合精度（AMP）训练组件的初始化：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;model=MLP(input_size=64,hidden_size=256,output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
scaler=torch.amp.GradScaler()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中&lt;code&gt;scaler = torch.amp.GradScaler()&lt;/code&gt;创建一个 梯度缩放器（GradScaler），用于 自动混合精度（AMP）训练。在 FP16（半精度）训练时，防止梯度下溢（underflow → 变成 0）。
&lt;em&gt;&lt;strong&gt;tips：下溢是指一个数太小了，小到当前数据类型无法表示，结果被强制变成 0。由于链式法则，所以可能会导致梯度越来越小。&lt;/strong&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;接下来在模型的某一层（model.fc1）上注册一个前向传播钩子（forward hook），用于在第一次前向计算时打印该层的输入、输出和权重的形状与数据类型，便于调试模型的数据流和精度（如 FP16/FP32）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;print_once=False
def debug_forward(module,input,output):
    global print_once
    if not print_once:
        print_once=True
        print(f&apos;{module}\ninput_shape={input[0].shape} input_dtype={input[0].dtype}\noutput_shape={output.shape} output_dtype={output.dtype}\nweight_shape={module.weight.shape} weight_dtype={module.weight.dtype}&apos;)
    
model.fc1.register_forward_hook(debug_forward)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;打印的内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float16
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它揭示了模型在混合精度训练（AMP）环境下某一层（fc1）的实际运行状态。由&lt;code&gt;output_dtype&lt;/code&gt;可知在训练过程中，将&lt;code&gt;float32&lt;/code&gt;自动转换成&lt;code&gt;float16&lt;/code&gt;进行训练了。模型参数&lt;code&gt;weight_shape&lt;/code&gt;始终以 FP32 形式存储（这是 AMP 的标准做法，称为 “master weights”）。&lt;/p&gt;
&lt;p&gt;最后开始训练：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;iter=0
while True:
    optimizer.zero_grad()
    with torch.amp.autocast(device_type=&apos;cuda&apos;,dtype=torch.float16): # FP16 Mix
        out=model(x)
        loss=loss_fn(out,y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    iter+=1
    if iter%100000==0:
        print(f&apos;iter={iter} loss={loss.item()} cuda_mem={torch.cuda.memory_allocated()}Bytes&apos;)
    if loss.item()&amp;#x3C;=1e-3:
        break
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中核心是：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;with torch.amp.autocast(device_type=&apos;cuda&apos;, dtype=torch.float16):  # FP16 Mix
    out = model(x)
    loss = loss_fn(out, y)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;在这个上下文中，PyTorch 会自动为支持的操作选择 FP16 或 FP32 精度：大多数计算（如 Linear, Conv, MatMul）→ 使用 FP16（更快、更省显存）；数值敏感操作（如 Softmax, Log, BatchNorm）→ 自动回退到 FP32（保精度）。&lt;/p&gt;
&lt;p&gt;再进行梯度缩放：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;scaler.scale(loss).backward()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;em&gt;&lt;strong&gt;tips：在 FP16 中，梯度可能太小而下溢（underflow → 变成 0）。解决方案：先把 loss 放大（如 ×1024），使得反向传播时梯度也放大，落在 FP16 可表示范围内&lt;/strong&gt;&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;&lt;code&gt;scaler.scale(loss)&lt;/code&gt;会返回一个 scaled loss（FP16），调用 &lt;code&gt;.backward()&lt;/code&gt; 时，计算的是 放大后的梯度。&lt;/p&gt;
&lt;p&gt;最后参数更新：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;scaler.step(optimizer)
scaler.update()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中&lt;code&gt;scaler.step(optimizer)&lt;/code&gt;先检查梯度是否包含 &lt;code&gt;NaN/Inf&lt;/code&gt;（&lt;code&gt;FP16&lt;/code&gt; 容易溢出）。如果正常，则自动将梯度从 &lt;code&gt;FP16&lt;/code&gt; 转回 &lt;code&gt;FP32&lt;/code&gt;，并除以缩放因子，再调用 &lt;code&gt;optimizer.step()&lt;/code&gt;。如果检测到异常（如梯度过大导致 &lt;code&gt;Inf&lt;/code&gt;），则跳过本次更新（避免破坏模型）。
&lt;code&gt;scaler.update()&lt;/code&gt;动态调整下一次的缩放因子（&lt;code&gt;scale&lt;/code&gt;）。如果连续几次都无异常 → 尝试增大 &lt;code&gt;scale&lt;/code&gt;（更激进）。如果出现 &lt;code&gt;Inf/NaN&lt;/code&gt; → 减小 &lt;code&gt;scale&lt;/code&gt;（更保守）。&lt;/p&gt;
&lt;p&gt;混合精度训练：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float16
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
iter=100000 loss=0.00813820119947195 cuda_mem=17707008Bytes
iter=200000 loss=0.0030544002074748278 cuda_mem=17707008Bytes
iter=300000 loss=0.0017446494894102216 cuda_mem=17707008Bytes
iter=400000 loss=0.0011826277477666736 cuda_mem=17707008Bytes
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;不用混合精度训练：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float32
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
iter=100000 loss=0.008329853415489197 cuda_mem=17742848Bytes
iter=200000 loss=0.003189701121300459 cuda_mem=17742848Bytes
iter=300000 loss=0.0018277550116181374 cuda_mem=17742848Bytes
iter=400000 loss=0.001241791993379593 cuda_mem=17742848Bytes
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到，使用AMP混合精度训练所占用的显存明显更低！&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/distri_trainning/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（十）：tracker.py详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-10</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-10</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:51:00 GMT</pubDate><content:encoded>&lt;p&gt;项目的主线已经差不多介绍完毕了，接下来都是介绍一些支线代码。在前面的&lt;code&gt;obj_detect.py&lt;/code&gt;代码中，可以看到里面引入了一个目标跟踪tracker的代码。所以本节就来了解一下&lt;code&gt;tracker.py&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;目标跟踪在本项目中的作用：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;在多目标场景中，目标跟踪确保每个目标都有唯一的ID标识，这样机器人可以按照预定顺序（如从左到右）依次抓取目标，而不会混淆不同目标。这对于实现有序抓取任务非常重要。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;当目标被短暂遮挡或检测算法在某些帧中未能检测到目标时，跟踪器可以基于先前的状态预测目标位置，保持对目标的跟踪，直到目标重新被检测到。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;在obj_detect.py中，系统根据目标的x坐标（从左到右）对检测结果进行排序，这使得机器人可以按顺序抓取目标。跟踪器确保了这种排序的一致性，即使在连续帧之间目标位置略有变化。&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;这是一个非常经典的基于卡尔曼滤波（Kalman Filter）的多目标跟踪代码，结合了目标检测、运动预测、数据关联和轨迹管理。&lt;/p&gt;
&lt;p&gt;该代码系统两个主要类：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Track：表示一个单个目标的轨迹&lt;/li&gt;
&lt;li&gt;Tracker：管理所有 Track，处理新检测、匹配、创建/删除轨迹&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;接下来分开介绍这两个类：&lt;/p&gt;
&lt;h2&gt;Track&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class Track:
    def __init__(self, track_id, bbox):
        self.id = track_id
        self.kf = self.create_kalman_filter(bbox)
        self.bbox = bbox  ## 用于输出可视化
        self.age = 0
        self.time_since_update = 0

    def create_kalman_filter(self, bbox):
        x1, y1, x2, y2 = bbox
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        w, h = x2 - x1, y2 - y1
        state = [cx, cy, w, h, 0, 0, 0, 0]

        kf = KalmanFilter(dim_x=8, dim_z=4)
        kf.F = np.eye(8)
        for i in range(4):
            kf.F[i, i+4] = 1  ## 位置-速度关联

        kf.H = np.eye(4, 8)  ## 只测量位置宽高
        kf.R *= 5      ## 观测噪声（保守信任观测）
        kf.P *= 1000    ## 初始协方差（对初始状态不自信）
        kf.Q[4:, 4:] *= 5  ## 从 10 减小到 5


        kf.x[:4] = np.array(state[:4]).reshape((4, 1))
        return kf

    def predict(self):
        self.kf.predict()
        self.age += 1
        self.time_since_update += 1
        self.bbox = self.get_bbox()
        return self.bbox

    def update(self, bbox):
        x1, y1, x2, y2 = bbox
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        w, h = x2 - x1, y2 - y1
        z = np.array([cx, cy, w, h])
        self.kf.update(z)
        self.time_since_update = 0
        self.bbox = self.get_bbox()

    def get_bbox(self):
        cx, cy, w, h = self.kf.x[:4].flatten()
        x1 = cx - w / 2
        y1 = cy - h / 2
        x2 = cx + w / 2
        y2 = cy + h / 2
        return [x1, y1, x2, y2]
&lt;/code&gt;&lt;/pre&gt;
&lt;ol&gt;
&lt;li&gt;状态表示（State Vector）&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;state = [cx, cy, w, h, 0, 0, 0, 0]
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;前4维：位置和尺寸
&lt;ul&gt;
&lt;li&gt;cx, cy：包围框中心坐标&lt;/li&gt;
&lt;li&gt;w, h：宽度和高度&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;后4维：速度（导数）
&lt;ul&gt;
&lt;li&gt;vx, vy, vw, vh：中心点和宽高的变化率（初始为0）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这是一个8维状态空间：&lt;code&gt;[cx, cy, w, h, vx, vy, vw, vh]&lt;/code&gt;&lt;/p&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;状态转移矩阵 F&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;kf.F = np.eye(8)
for i in range(4):
    kf.F[i, i+4] = 1  ## 位置 += 速度 * dt
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这表示一个恒定速度模型（Constant Velocity Model）：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-10%2F1.png%3ForigWidth%3D530%26origHeight%3D232%26origFormat%3Dpng&amp;#x26;w=530&amp;#x26;h=232&amp;#x26;f=webp&quot; alt=&quot;$$
\begin{bmatrix}
cx \
cy \
w \
h \
vx \
vy \
vw \
vh \
\end{bmatrix}_{t+1}
\begin{bmatrix}
1 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; \Delta t &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 \
0 &amp;#x26; 1 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; \Delta t &amp;#x26; 0 &amp;#x26; 0 \
0 &amp;#x26; 0 &amp;#x26; 1 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; \Delta t &amp;#x26; 0 \
0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 1 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; \Delta t \
0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 1 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 \
0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 1 &amp;#x26; 0 &amp;#x26; 0 \
0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 1 &amp;#x26; 0 \
0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 0 &amp;#x26; 1 \
\end{bmatrix}
\cdot
\begin{bmatrix}
cx \
cy \
w \
h \
vx \
vy \
vw \
vh \
\end{bmatrix}_t
$$&quot;&gt;&lt;/p&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;观测矩阵H&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;kf.H = np.eye(4, 8)  ## 只测量前4个量
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;表示我们只能观测到：
$$ z = [c_x, c_y, w, h] $$
而速度是隐变量，只能通过滤波估计。&lt;/p&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;噪声协方差设置
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-10%2F2.png%3ForigWidth%3D680%26origHeight%3D188%26origFormat%3Dpng&amp;#x26;w=680&amp;#x26;h=188&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
补充一下：
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;kf.H —— 观测矩阵（Observation Matrix / Measurement Matrix）&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;含义：定义如何从系统状态 x 映射到观测值 z。&lt;/li&gt;
&lt;li&gt;数学公式：$$ z = H \cdot x + v$$
&lt;ul&gt;
&lt;li&gt;z: 实际观测值（例如检测到的bbox）&lt;/li&gt;
&lt;li&gt;x: 系统内部状态（如 [cx, cy, w, h, vx, vy, vw, vh]）&lt;/li&gt;
&lt;li&gt;v: 观测噪声&lt;/li&gt;
&lt;li&gt;H: 决定状态向量中哪些部分能被观测到&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;kf.R  —— 观测噪声协方差矩阵（Measurement Noise Covariance)&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;含义： 描述观测数据的不确定性或噪声大小。值越大，说明你越不信任观测值。
&lt;ul&gt;
&lt;li&gt;R 越大 → 滤波器更相信预测值（平滑更多）&lt;/li&gt;
&lt;li&gt;R 越小 → 滤波器更相信观测值（响应更快但可能抖动）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;kf.P  —— 状态误差协方差矩阵（Error Covariance Matrix）&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;含义：表示当前状态估计的不确定性程度。它是卡尔曼增益计算的核心。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;P[i,i]：第 i 个状态分量的方差（不确定性）&lt;/li&gt;
&lt;li&gt;P[i,j]：状态 i 和 j 之间的协方差（相关性）&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;code&gt;kf.P *= 1000&lt;/code&gt; 表示初始状态非常不确定，比如你只知道大致位置，但不知道精确值。这样滤波器会更快地接受观测数据来修正自己。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;kf.Q —— 过程噪声协方差矩阵（Process Noise Covariance）&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;含义：描述系统模型本身的不确定性，也就是“真实世界偏离你模型的程度”。
&lt;ul&gt;
&lt;li&gt;Q 越大 → 认为运动模型不准确（比如物体可能突然加速）&lt;/li&gt;
&lt;li&gt;Q 越小 → 认为物体运动很稳定（如匀速）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-10%2F3.png%3ForigWidth%3D774%26origHeight%3D229%26origFormat%3Dpng&amp;#x26;w=774&amp;#x26;h=229&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;predict() 方法：预测下一时刻状态&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def predict(self):
    self.kf.predict()
    self.age += 1
    self.time_since_update += 1
    self.bbox = self.get_bbox()
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;调用卡尔曼滤波的 预测步骤（Predict）&lt;/li&gt;
&lt;li&gt;更新轨迹年龄和未更新次数&lt;/li&gt;
&lt;li&gt;从状态向量恢复 bbox&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;6&quot;&gt;
&lt;li&gt;update(bbox) 方法：融合新观测&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;z = [cx, cy, w, h]  ## 新检测的中心和尺寸
self.kf.update(z)
self.time_since_update = 0
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;执行卡尔曼滤波的 更新步骤（Update）&lt;/li&gt;
&lt;li&gt;使用观测值修正预测值&lt;/li&gt;
&lt;li&gt;重置“未更新计数器”&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;7&quot;&gt;
&lt;li&gt;get_bbox()：从状态生成包围框&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;cx, cy, w, h = self.kf.x[:4].flatten()
x1 = cx - w / 2
y1 = cy - h / 2
x2 = cx + w / 2
y2 = cy + h / 2
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;将内部状态转换为标准 (x1, y1, x2, y2) 格式的边界框，用于可视化或下游任务。&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;Tracker&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class Tracker:
    def __init__(self, iou_threshold=0.05, max_age=10):
        self.tracks = []
        self.next_id = 0
        self.iou_threshold = iou_threshold
        self.max_age = max_age

    def update(self, detections):
        ## 1. 预测所有轨迹
        for track in self.tracks:
            track.predict()

        ## 2. 匹配
        matches, unmatched_tracks, unmatched_dets = self.match(detections)

        ## 3. 更新匹配的轨迹
        for t_idx, d_idx in matches:
            self.tracks[t_idx].update(detections[d_idx])

        ## 4. 初始化新的轨迹
        for d_idx in unmatched_dets:
            self.tracks.append(Track(self.next_id, detections[d_idx]))
            self.next_id += 1

        ## 5. 移除太旧的轨迹
        self.tracks = [t for t in self.tracks if t.time_since_update &amp;#x3C; self.max_age]

        return [{&apos;id&apos;: t.id, &apos;bbox&apos;: list(map(int, t.bbox))} for t in self.tracks]


    def match(self, detections):
        iou_matrix = np.zeros((len(self.tracks), len(detections)), dtype=np.float32)

        for t, track in enumerate(self.tracks):
            if track.time_since_update &gt; self.max_age:  ## 太旧的就跳过匹配
                continue
            for d, det in enumerate(detections):
                iou_matrix[t, d] = self.iou(track.bbox, det)

        matched_indices = []
        unmatched_tracks = list(range(len(self.tracks)))
        unmatched_dets = list(range(len(detections)))

        used_dets = set()

        for t in range(len(self.tracks)):
            if track.time_since_update &gt; self.max_age:
                continue
            best_match = np.argmax(iou_matrix[t])
            if iou_matrix[t, best_match] &gt; self.iou_threshold and best_match not in used_dets:
                matched_indices.append((t, best_match))
                unmatched_tracks.remove(t)
                unmatched_dets.remove(best_match)
                used_dets.add(best_match)

        return matched_indices, unmatched_tracks, unmatched_dets


    def iou(self, boxA, boxB):
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return interArea / float(boxAArea + boxBArea - interArea + 1e-6)
&lt;/code&gt;&lt;/pre&gt;
&lt;ol&gt;
&lt;li&gt;初始化&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def __init__(self, iou_threshold=0.05, max_age=10):
    self.tracks = []        ## 当前所有轨迹
    self.next_id = 0        ## 下一个分配的 ID
    self.iou_threshold     ## 匹配阈值
    self.max_age           ## 轨迹最大存活时间
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;
&lt;p&gt;update(detections) 主流程&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;步骤1&lt;/strong&gt;：预测所有现有轨迹&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;for track in self.tracks:
    track.predict()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;所有轨迹都向前“走一步”，预测当前位置。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;步骤2&lt;/strong&gt;：匹配 (match)&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;构建 IOU 矩阵：每个轨迹与每个检测之间的交并比&lt;/li&gt;
&lt;li&gt;使用贪心匹配策略：
&lt;ul&gt;
&lt;li&gt;对每个轨迹，找 IOU 最大的检测&lt;/li&gt;
&lt;li&gt;如果 &gt; iou_threshold 且未被占用 → 匹配成功&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;步骤3&lt;/strong&gt;：更新匹配的轨迹&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;for t_idx, d_idx in matches:
    self.tracks[t_idx].update(detections[d_idx])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;用最新检测更新对应轨迹的卡尔曼滤波器&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;步骤4&lt;/strong&gt;：创建新轨迹&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;for d_idx in unmatched_dets:
    self.tracks.append(Track(self.next_id, detections[d_idx]))
    self.next_id += 1
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;未匹配的检测 → 可能是新出现的目标 → 创建新 Track&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;步骤5&lt;/strong&gt;：删除过期轨迹&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.tracks = [t for t in self.tracks if t.time_since_update &amp;#x3C; self.max_age]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;如果一个轨迹长时间未匹配（time_since_update &gt;= max_age），说明目标已消失 → 删除&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;工作流程：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;每一帧输入：
     ↓
[检测列表 detections] 
     ↓
  Tracker.update()
     ├── 所有轨迹 predict() → 预测当前位置
     ├── 计算 IOU 矩阵
     ├── 匹配（轨迹 ↔ 检测）
     ├── 匹配成功 → update()（融合观测）
     ├── 未匹配检测 → 创建新轨迹
     └── 删除太久未更新的轨迹
     ↓
[输出：带 ID 的稳定轨迹列表]
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（九）：u5re_gripper.h详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-9</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-9</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:50:00 GMT</pubDate><content:encoded>&lt;p&gt;在前面的cpp操作文件中可以看到都引用了头文件&lt;code&gt;u5re_gripper.h&lt;/code&gt;，那么我们来详细了解一下这个头文件。&lt;/p&gt;
&lt;p&gt;按照惯例，先贴代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;using GripperCommand = control_msgs::action::GripperCommand;
using GoalHandleGripperCommand = rclcpp_action::ClientGoalHandle&amp;#x3C;GripperCommand&gt;;

class UR5eGripper : public rclcpp::Node {
public:
  explicit UR5eGripper(const rclcpp::NodeOptions &amp;#x26;options);
  void init();

  void get_target_pose_list(std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; &amp;#x26;target_pose_list);
  void get_joint_target_positions(
      moveit::planning_interface::MoveGroupInterfacePtr move_group,
      const std::vector&amp;#x3C;double&gt; &amp;#x26;target_pose, const std::string &amp;#x26;reference_frame,
      std::vector&amp;#x3C;double&gt; &amp;#x26;joint_target_positions);
  bool plan_and_execute(const std::vector&amp;#x3C;double&gt; &amp;#x26;target_pose);
  bool grasp(double gripper_position);
  void get_cube_pose(const std::string &amp;#x26;from_frame, const std::string &amp;#x26;to_frame,
                    std::vector&amp;#x3C;double&gt; &amp;#x26;cube_pose);
  void go_to_ready_position();

private:
  void goal_response_callback(const GoalHandleGripperCommand::SharedPtr &amp;#x26;goal_handle);
  void feedback_callback(GoalHandleGripperCommand::SharedPtr,
                        const std::shared_ptr&amp;#x3C;const GripperCommand::Feedback&gt; feedback);
  void result_callback(const GoalHandleGripperCommand::WrappedResult &amp;#x26;result);
  void str_list_2_double_list(const std::vector&amp;#x3C;std::string&gt; &amp;#x26;str_list,
                              std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; &amp;#x26;double_list);

  std::shared_ptr&amp;#x3C;moveit::planning_interface::MoveGroupInterface&gt; move_group_;
  moveit::planning_interface::PlanningSceneInterface planning_scene_interface_;
  rclcpp_action::Client&amp;#x3C;GripperCommand&gt;::SharedPtr gripper_action_client_;
  rclcpp_action::Client&amp;#x3C;GripperCommand&gt;::SendGoalOptions send_goal_options_;
  std::unique_ptr&amp;#x3C;tf2_ros::Buffer&gt; tf_buffer_;
  std::shared_ptr&amp;#x3C;tf2_ros::TransformListener&gt; tf_listener_;

  std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; target_pose_list_;
  std::string gripper_action_name_ = &quot;/gripper_controller/gripper_cmd&quot;;
  const std::string PLANNING_GROUP = &quot;ur_manipulator&quot;;
};

&lt;/code&gt;&lt;/pre&gt;
&lt;ol&gt;
&lt;li&gt;类定义：UR5eGripper : public rclcpp::Node&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;class UR5eGripper : public rclcpp::Node {
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;含义：UR5eGripper 是一个继承自 rclcpp::Node 的类&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;构造函数：explicit UR5eGripper(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;explicit UR5eGripper(const rclcpp::NodeOptions &amp;#x26;options);
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是一个构造函数。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;explicit：防止隐式类型转换（安全编程习惯）。&lt;/li&gt;
&lt;li&gt;参数：接收一个 NodeOptions 对象，允许在创建节点时配置选项（如自动声明参数、上下文、QoS 等）。&lt;/li&gt;
&lt;li&gt;作用：初始化这个 ROS 2 节点，比如设置节点名称（通常在 main() 中传入）。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;init() 函数&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void init();
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;作用：初始化类内部组件。&lt;/li&gt;
&lt;li&gt;为什么不在构造函数中完成？因为有些资源（如 MoveGroupInterface）需要在节点完全启动后才能安全创建。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;get_target_pose_list(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void get_target_pose_list(std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; &amp;#x26;target_pose_list);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：获取预设的“放置目标位姿”列表。&lt;/li&gt;
&lt;li&gt;参数：输出参数，返回一组目标位姿。&lt;/li&gt;
&lt;li&gt;每个位姿格式：[x, y, z, roll, pitch, yaw]（单位：米 + 弧度）&lt;/li&gt;
&lt;li&gt;用途：用于将物体放到不同位置（比如 target_pose_list[0] 是第一个放置点）&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;get_joint_target_positions(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void get_joint_target_positions(
      moveit::planning_interface::MoveGroupInterfacePtr move_group,
      const std::vector&amp;#x3C;double&gt; &amp;#x26;target_pose, const std::string &amp;#x26;reference_frame,
      std::vector&amp;#x3C;double&gt; &amp;#x26;joint_target_positions);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：给定一个末端执行器的目标位姿（target_pose），求解对应的关节角度目标值（逆运动学）。&lt;/li&gt;
&lt;li&gt;参数说明：
&lt;ul&gt;
&lt;li&gt;move_group：MoveIt 的运动组接口。&lt;/li&gt;
&lt;li&gt;target_pose：目标位姿 [x, y, z, roll, pitch, yaw]。&lt;/li&gt;
&lt;li&gt;reference_frame：参考坐标系（如 &quot;base_link&quot;）。&lt;/li&gt;
&lt;li&gt;joint_target_positions：输出的关节角（弧度）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;用途：为后续轨迹规划提供输入。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;6&quot;&gt;
&lt;li&gt;plan_and_execute(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;bool plan_and_execute(const std::vector&amp;#x3C;double&gt; &amp;#x26;target_pose);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：规划并执行一条到达目标位姿的轨迹。&lt;/li&gt;
&lt;li&gt;参数：目标位姿 [x, y, z, roll, pitch, yaw]。&lt;/li&gt;
&lt;li&gt;返回值：true 表示成功，false 表示失败（如无解、碰撞、超时）。&lt;/li&gt;
&lt;li&gt;内部流程：
&lt;ul&gt;
&lt;li&gt;调用 get_joint_target_positions 求解目标关节角。&lt;/li&gt;
&lt;li&gt;使用 MoveGroupInterface 进行轨迹规划。&lt;/li&gt;
&lt;li&gt;执行规划好的轨迹（发送到控制器）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;这是最核心的运动控制函数之一。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;7&quot;&gt;
&lt;li&gt;grasp(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;bool grasp(double gripper_position);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：控制夹爪开合。&lt;/li&gt;
&lt;li&gt;参数：gripper_position —— 目标开合度（单位：米或归一化值，取决于夹爪驱动方式）。
&lt;ul&gt;
&lt;li&gt;例如：0.0 表示闭合，1 表示完全打开。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;实现方式：通过 ROS 2 Action 发送目标命令。&lt;/li&gt;
&lt;li&gt;返回值：是否成功发送命令或等待完成。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;8&quot;&gt;
&lt;li&gt;get_cube_pose(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void get_cube_pose(const std::string &amp;#x26;from_frame, const std::string &amp;#x26;to_frame,
                    std::vector&amp;#x3C;double&gt; &amp;#x26;cube_pose);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：通过 TF（Transform）系统获取两个坐标系之间的相对位姿。&lt;/li&gt;
&lt;li&gt;参数：
&lt;ul&gt;
&lt;li&gt;from_frame：源坐标系（如 &quot;base_link&quot;）&lt;/li&gt;
&lt;li&gt;to_frame：目标坐标系（如 &quot;cube1&quot;）&lt;/li&gt;
&lt;li&gt;cube_pose：输出位姿 [x, y, z, roll, pitch, yaw]&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;实现方式：
&lt;ul&gt;
&lt;li&gt;使用 tf_buffer_-&gt;lookupTransform(...) 获取 geometry_msgs::TransformStamped&lt;/li&gt;
&lt;li&gt;将四元数转换为欧拉角（roll, pitch, yaw）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;用途：获取立方体在机器人基座坐标系下的位置，用于抓取。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;9&quot;&gt;
&lt;li&gt;go_to_ready_position()&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void go_to_ready_position();
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：让机器人回到一个预设的“安全起始位置”（Ready Pose / Home Pose）。&lt;/li&gt;
&lt;li&gt;实现方式：
&lt;ul&gt;
&lt;li&gt;可能是通过关节目标（Joint Target）直接移动。&lt;/li&gt;
&lt;li&gt;或者调用 plan_and_execute 到某个预设位姿。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;目的：
&lt;ul&gt;
&lt;li&gt;避免奇异点&lt;/li&gt;
&lt;li&gt;防止碰撞&lt;/li&gt;
&lt;li&gt;作为每次操作前的初始化姿态&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;接下来介绍private私有函数，这些是类内部使用的资源和回调函数。&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;goal_response_callback(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void goal_response_callback(const GoalHandleGripperCommand::SharedPtr &amp;#x26;goal_handle);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;作用：当向夹爪 Action Server 发送目标后，服务器返回是否接受该目标。&lt;/li&gt;
&lt;li&gt;典型行为：
&lt;ul&gt;
&lt;li&gt;如果目标被接受，继续监听反馈。&lt;/li&gt;
&lt;li&gt;如果被拒绝，记录警告或重试。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;用途：异步处理 Action 请求的响应。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;feedback_callback(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void feedback_callback(GoalHandleGripperCommand::SharedPtr,
                        const std::shared_ptr&amp;#x3C;const GripperCommand::Feedback&gt; feedback);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;作用：在夹爪运动过程中，持续接收反馈信息。&lt;/li&gt;
&lt;li&gt;反馈内容可能包括：
&lt;ul&gt;
&lt;li&gt;当前开合度&lt;/li&gt;
&lt;li&gt;施加的力&lt;/li&gt;
&lt;li&gt;运动状态（moving, stalled）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;用途：监控夹爪状态，可用于检测是否夹紧物体。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;result_callback(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void result_callback(const GoalHandleGripperCommand::WrappedResult &amp;#x26;result);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;作用：Action 执行完成后，接收最终结果。&lt;/li&gt;
&lt;li&gt;结果状态可能为：
&lt;ul&gt;
&lt;li&gt;SUCCEEDED&lt;/li&gt;
&lt;li&gt;ABORTED&lt;/li&gt;
&lt;li&gt;CANCELED&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;用途：判断夹爪动作是否成功完成。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;str_list_2_double_list(...)&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;void str_list_2_double_list(const std::vector&amp;#x3C;std::string&gt; &amp;#x26;str_list,
                              std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; &amp;#x26;double_list);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;功能：将字符串列表（如参数服务器读取的字符串数组）转换为双精度浮点数二维数组。&lt;/li&gt;
&lt;li&gt;典型用途：
&lt;ul&gt;
&lt;li&gt;从 YAML 文件读取 [&quot;0.1,0.2,0.3&quot;, &quot;0.4,0.5,0.6&quot;]&lt;/li&gt;
&lt;li&gt;解析成 [[0.1,0.2,0.3], [0.4,0.5,0.6]]&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;使用场景：加载 target_pose_list_ 时，把参数字符串转为数值。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;最后介绍私有成员变量&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::shared_ptr&amp;#x3C;moveit::planning_interface::MoveGroupInterface&gt; move_group_;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;MoveIt 的核心接口，用于与机器人运动规划器交互。&lt;/li&gt;
&lt;li&gt;提供 setPoseTarget, setJointValueTarget, plan(), move() 等方法。&lt;/li&gt;
&lt;li&gt;对应的运动组是 &quot;ur_manipulator&quot;。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;moveit::planning_interface::PlanningSceneInterface planning_scene_interface_;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;用于与 规划场景（Planning Scene） 交互。&lt;/li&gt;
&lt;li&gt;可添加/删除障碍物、设置碰撞对象等。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;rclcpp_action::Client&amp;#x3C;GripperCommand&gt;::SharedPtr gripper_action_client_;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;夹爪动作客户端，用于向夹爪控制器发送命令。&lt;/li&gt;
&lt;li&gt;Action 类型：GripperCommand（通常是 control_msgs::action::GripperCommand）&lt;/li&gt;
&lt;li&gt;目标话题：/gripper_controller/gripper_cmd&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;rclcpp_action::Client&amp;#x3C;GripperCommand&gt;::SendGoalOptions send_goal_options_;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;配置 Action 客户端发送目标时的回调函数（即上面三个回调）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;在 init() 中会设置：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;send_goal_options_.goal_response = std::bind(&amp;#x26;UR5eGripper::goal_response_callback, this, _1);
send_goal_options_.feedback = std::bind(&amp;#x26;UR5eGripper::feedback_callback, this, _1, _2);
send_goal_options_.result = std::bind(&amp;#x26;UR5eGripper::result_callback, this, _1);
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::unique_ptr&amp;#x3C;tf2_ros::Buffer&gt; tf_buffer_;
std::shared_ptr&amp;#x3C;tf2_ros::TransformListener&gt; tf_listener_;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;TF2 系统组件：
&lt;ul&gt;
&lt;li&gt;tf_buffer_：存储所有坐标变换的历史数据。&lt;/li&gt;
&lt;li&gt;tf_listener_：自动订阅 /tf 和 /tf_static 话题，填充 buffer。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;用于 get_cube_pose() 中查询坐标变换。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; target_pose_list_;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;存储所有预设的放置目标位姿。&lt;/li&gt;
&lt;li&gt;通过 get_target_pose_list() 填充。&lt;/li&gt;
&lt;li&gt;每个元素是一个 6 维向量 [x, y, z, roll, pitch, yaw]。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::string gripper_action_name_ = &quot;/gripper_controller/gripper_cmd&quot;;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;定义夹爪 Action 的话题名称。&lt;/li&gt;
&lt;li&gt;可通过参数修改（更灵活）。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;const std::string PLANNING_GROUP = &quot;ur_manipulator&quot;;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;定义 MoveIt 中的运动组名称。&lt;/li&gt;
&lt;li&gt;在 ur5e_moveit_config 中定义，通常包含所有机械臂关节。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;总结，这个类的作用：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-9%2F1.png%3ForigWidth%3D604%26origHeight%3D276%26origFormat%3Dpng&amp;#x26;w=604&amp;#x26;h=276&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;捋一下&lt;code&gt;demo.launch.py、 demo.cpp、 ur5e_gripper.cpp&lt;/code&gt; 和 &lt;code&gt;ur5e_gripper.h&lt;/code&gt; 四个文件之间的关系。&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;整体架构关系&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;这四个文件构成了一个完整的ROS 2机器人控制系统的不同层次：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;demo.launch.py (启动配置层)
    ↓
demo.cpp (应用主程序层)
    ↓
ur5e_gripper.cpp (功能实现层)
    ↓
ur5e_gripper.h (接口定义层)
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;详细调用关系分析：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;demo.launch.py → demo.cpp
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;demo.launch.py是启动文件，负责配置和启动demo.cpp编译生成的可执行文件：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;Node(
    package=&apos;ur5e_gripper_control&apos;,    ## 指定功能包
    executable=&apos;demo&apos;,                 ## 指定可执行文件名
    name=&apos;demo_node&apos;,                  ## 节点名称
    parameters=[                       ## 传递参数
        ## ...
    ], 
    output=&apos;screen&apos;
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;当运行ros2 launch ur5e_gripper_control demo.launch.py时，系统会：&lt;/p&gt;
&lt;p&gt;a. 加载ur5e_gripper_control包
b. 查找并执行名为demo的可执行文件
c. 将配置参数传递给该可执行文件&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;demo.cpp → ur5e_gripper.cpp
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;demo.cpp是应用程序的主入口，它使用UR5eGripper类实现具体功能：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;// 创建UR5eGripper对象实例
auto node = std::make_shared&amp;#x3C;UR5eGripper&gt;(node_options);

// 调用UR5eGripper类的方法
node-&gt;init();
node-&gt;get_target_pose_list(target_pose_list);
node-&gt;get_cube_pose(from_frame, to_frame_list[i], cube_pose);
node-&gt;plan_and_execute(cube_pose_list[i]);
node-&gt;grasp(0.36);
node-&gt;go_to_ready_position();
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;demo.cpp中调用了ur5e_gripper.cpp中实现的多个方法来完成机器人控制任务。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;ur5e_gripper.cpp ↔ ur5e_gripper.h
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;ur5e_gripper.h是UR5eGripper类的头文件，定义了类的接口：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;class UR5eGripper : public rclcpp::Node {
public:
  explicit UR5eGripper(const rclcpp::NodeOptions &amp;#x26;options);
  void init();
  void get_target_pose_list(std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; &amp;#x26;target_pose_list);
  bool plan_and_execute(const std::vector&amp;#x3C;double&gt; &amp;#x26;target_pose);
  bool grasp(double gripper_position);
  // ... 其他公共方法
};
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;ur5e_gripper.cpp是实现文件，包含了这些方法的具体实现。&lt;/p&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;编译和链接关系&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;在CMakeLists.txt中定义了编译规则：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;## 将demo.cpp和ur5e_gripper.cpp编译成一个可执行文件
add_executable(demo src/demo.cpp src/ur5e_gripper.cpp)

## 链接所需的库
ament_target_dependencies(demo
  rclcpp moveit_ros_planning_interface tf2 
  moveit_core moveit_ros_planning control_msgs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这表明demo.cpp和ur5e_gripper.cpp被编译成同一个可执行文件demo。&lt;/p&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;数据流向&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;a. 配置数据流向：&lt;code&gt;target_pose_list.yaml → demo.launch.py → demo.cpp → UR5eGripper构造函数&lt;/code&gt;
b. 控制指令流向：&lt;code&gt;demo.cpp → UR5eGripper方法 → MoveIt/ROS 2系统 → 机器人硬件&lt;/code&gt;
c. 反馈数据流向：&lt;code&gt;机器人硬件 → ROS 2系统 → UR5eGripper方法 → demo.cpp&lt;/code&gt;&lt;/p&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;各文件职责总结
a. demo.launch.py：负责配置节点启动参数，是系统的入口点
b. demo.cpp：实现应用逻辑，协调各个功能模块完成抓取任务
c. ur5e_gripper.cpp：实现具体的机器人控制功能，如运动规划、夹爪控制等
d. ur5e_gripper.h：定义UR5eGripper类的公共接口，为上层应用提供调用入口&lt;/li&gt;
&lt;/ol&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（八）：start_grasp详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-8</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-8</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:47:00 GMT</pubDate><content:encoded>&lt;p&gt;前面已经介绍了仿真软件的启动，接下来开始介绍抓取的代码start_grasp.launch.py&lt;/p&gt;
&lt;h2&gt;第一步，先看generate_launch_description代码&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    ## 生成启动描述的主函数
    ## 该函数定义了所有需要声明的参数和启动设置
    
    ## 声明参数列表（当前为空，可根据需要添加参数）
    declared_arguments = []

    ## 返回LaunchDescription对象，包含所有声明的参数和启动设置函数
    return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里没什么内容，跳过&lt;/p&gt;
&lt;h2&gt;第二步，看launch_setup函数代码&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def launch_setup(context, *args, **kwargs):
    ## 启动设置函数，用于配置和启动抓取演示相关的节点

    ## 包含抓取演示的launch文件
    ## 该launch文件来自ur5e_gripper_control包，用于演示机器人抓取功能
    grasp_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;ur5e_gripper_control&quot;), &quot;/launch&quot;, &quot;/demo.launch.py&quot;]
        ),
    )

    ## 组装需要启动的节点列表
    nodes_to_launch = [
        grasp_launch,
    ]

    ## 返回待启动的节点列表
    return nodes_to_launch
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里调用了ur5e_gripper_control下面的launch目录下的demo.launch.py代码。该部分的代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    ## 构造机器人运动学配置文件路径
    ## 该文件包含机器人逆运动学求解器的配置参数
    robot_description_kinematics = PathJoinSubstitution(
        [FindPackageShare(&quot;ur5e_gripper_moveit_config&quot;), &quot;config&quot;, &quot;kinematics.yaml&quot;]
    )

    ## 获取目标位姿列表配置文件路径
    ## 该文件定义了机器人需要到达的一系列目标位姿
    target_pose_list = os.path.join(
        get_package_share_directory(&apos;ur5e_gripper_control&apos;),
        &apos;config&apos;, 
        &apos;target_pose_list.yaml&apos;
    )

    ## 返回LaunchDescription对象，包含需要启动的节点列表
    return LaunchDescription([
        ## 启动演示节点，用于控制机器人执行预定义的抓取任务
        Node(
            package=&apos;ur5e_gripper_control&apos;,     ## 节点所属的功能包名称
            executable=&apos;demo&apos;,                  ## 要执行的可执行文件名称
            name=&apos;demo_node&apos;,                   ## 节点名称
            parameters=[{                       ## 节点参数列表
                    &quot;use_sim_time&quot;:True,        ## 使用仿真时间
                },
                robot_description_kinematics,   ## 机器人运动学参数文件
                target_pose_list                ## 目标位姿列表配置文件
            ], 
            output=&apos;screen&apos;                     ## 输出方式：打印到屏幕
        ),
    ])
&lt;/code&gt;&lt;/pre&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;首先加载了机器人逆运动学求解器的配置参数，这个参数文件是在通过使用Setup assistant生成的文件，可以参考&lt;a href=&quot;https://blog.csdn.net/qq_38880380/article/details/97390527?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522b660a363b117b4e9472936aa51c95acf%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&amp;#x26;request_id=b660a363b117b4e9472936aa51c95acf&amp;#x26;biz_id=0&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-97390527-null-null.142%5Ev102%5Epc_search_result_base7&amp;#x26;utm_term=ur5e%20setup%20assitant&amp;#x26;spm=1018.2226.3001.4187&quot;&gt;博客&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;这里定义了机器人需要到达的一系列目标位姿，也就是将6个正方体移动的目标位置&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;上述两个文件是通过创建节点调用demo.cpp，从而传入到demo.cpp文件中，可以理解demo.launch.py是一个中介。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;返回需要启动的节点列表和参数&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;接下来重点来了！&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;Node(
     package=&apos;ur5e_gripper_control&apos;,           ## 节点所属的功能包
      executable=&apos;demo&apos;,                        ## 可执行文件名
      name=&apos;demo_node&apos;,                         ## 节点名称
      parameters=[{                             ## 节点参数列表
              &quot;use_sim_time&quot;: True,             ## 使用仿真时间
          },
          robot_description_kinematics,         ## 机器人运动学参数
          target_pose_list                      ## 目标位姿列表
      ], 
      output=&apos;screen&apos;                           ## 输出方式：打印到屏幕
        )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个配置指定了:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;功能包: ur5e_gripper_control&lt;/li&gt;
&lt;li&gt;可执行文件: demo&lt;/li&gt;
&lt;li&gt;节点名称: demo_node&lt;/li&gt;
&lt;/ul&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;构建配置：CMakeLists.txt
在ur5e_gripper_control包的CMakeLists.txt中，定义了demo可执行文件的构建规则：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;add_executable(demo src/demo.cpp src/ur5e_gripper.cpp)
ament_target_dependencies(demo
  rclcpp moveit_ros_planning_interface tf2 
  moveit_core moveit_ros_planning control_msgs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这表明demo可执行文件由demo.cpp和ur5e_gripper.cpp两个源文件编译而成&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;可执行文件安装
同样在CMakeLists.txt中，定义了demo可执行文件的安装路径&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;install(TARGETS demo
  DESTINATION lib/${PROJECT_NAME}
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这表示demo可执行文件将被安装到lib/ur5e_gripper_control/目录下。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;实际执行流程
a. ROS 2 launch系统加载并执行start_grasp.launch.py
b. start_grasp.launch.py包含并执行demo.launch.py
c. demo.launch.py创建并启动demo_node节点
d. ROS 2系统在功能包ur5e_gripper_control的lib目录中查找名为demo的可执行文件
e. 系统找到并执行demo可执行文件（由demo.cpp编译生成）
f. demo.cpp中的main函数开始执行，控制机械臂完成抓取和放置任务&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;详细介绍demo.cpp文件，代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;/**
 * @brief UR5e机器人抓取演示主函数
 * 
 * 该函数实现了完整的机器人抓取和放置流程，包括：
 * 1. 初始化ROS 2节点和机器人控制接口
 * 2. 获取预设的放置位置列表
 * 3. 通过TF获取立方体的位置信息
 * 4. 控制机器人依次抓取立方体并放置到指定位置
 * 5. 完成任务后回到准备位置
 * 
 * @param argc 命令行参数数量
 * @param argv 命令行参数数组
 * @return int 程序退出状态码
 */
int main(int argc, char **argv) {
  // 初始化ROS 2客户端库
  rclcpp::init(argc, argv);
  
  // 创建节点选项并设置自动声明参数
  rclcpp::NodeOptions node_options;
  node_options.automatically_declare_parameters_from_overrides(true);
  
  // 创建UR5e机器人控制节点实例
  auto node = std::make_shared&amp;#x3C;UR5eGripper&gt;(node_options);
  
  // 初始化机器人控制接口
  node-&gt;init();

  // 创建单线程执行器并添加节点
  rclcpp::executors::SingleThreadedExecutor executor;
  executor.add_node(node);
  
  // 启动独立线程运行执行器
  std::thread([&amp;#x26;executor]() { executor.spin(); }).detach();

  // 获取预设的目标放置位置列表
  std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; target_pose_list;
  node-&gt;get_target_pose_list(target_pose_list);
  
  // 设置坐标变换参考帧
  std::string from_frame = &quot;base_link&quot;;
  
  // 定义需要抓取的立方体坐标系名称列表
  std::vector&amp;#x3C;std::string&gt; to_frame_list = {&quot;cube1&quot;, &quot;cube2&quot;, &quot;cube3&quot;, &quot;cube4&quot;, &quot;cube5&quot;, &quot;cube6&quot;};

  // 存储立方体抓取位置的列表
  std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; cube_pose_list;
  
  // 遍历所有立方体，获取它们的位置信息并调整为抓取位置
  for (size_t i = 0; i &amp;#x3C; to_frame_list.size(); i++) {
    std::vector&amp;#x3C;double&gt; cube_pose;
    
    // 通过TF获取立方体相对于base_link的位姿
    node-&gt;get_cube_pose(from_frame, to_frame_list[i], cube_pose);
    
    // 检查是否成功获取位姿信息
    if (cube_pose.empty()) {
      RCLCPP_WARN(rclcpp::get_logger(&quot;demo4&quot;), &quot;Failed to get pose for %s, skipping&quot;, to_frame_list[i].c_str());
      continue;
    }

    // 调整抓取位置，确保机器人从立方体上方抓取
    cube_pose[0] -= 0.012 ;  // 微调X坐标
    cube_pose[1] += 0.01;    // 微调Y坐标
    //cube_pose[2] += 0.14;  // 注释掉的Z轴调整
    cube_pose[2] += 0.14;    // Z轴增加0.14米，确保从上方抓取
    cube_pose[3] = 0.0;      // 设置roll角度为0
    cube_pose[4] = M_PI;     // 设置pitch角度为π(180度)，使夹爪向下
    cube_pose[5] = 0.0;      // 设置yaw角度为0
    
    // 打印调整后的立方体抓取位置
    RCLCPP_INFO(rclcpp::get_logger(&quot;demo4&quot;), &quot;Adjusted cube pose for %s: x=%f, y=%f, z=%f&quot;,
                to_frame_list[i].c_str(), cube_pose[0], cube_pose[1], cube_pose[2]);
    
    // 将调整后的抓取位置添加到列表中
    cube_pose_list.push_back(cube_pose);
  }

  // 循环执行抓取和放置操作，最多处理6个立方体
  for (size_t i = 0; i &amp;#x3C; std::min&amp;#x3C;size_t&gt;(6, cube_pose_list.size()); i++) {
    // 控制机器人移动到立方体抓取位置
    bool grasp_success = node-&gt;plan_and_execute(cube_pose_list[i]);
    
    // 如果移动失败，则跳过当前立方体
    if (!grasp_success) {
      continue;
    }
    
    // 闭合夹爪抓取立方体，0.36为夹爪闭合位置
    node-&gt;grasp(0.36);
    
    // 等待1秒确保夹爪完全闭合
    rclcpp::sleep_for(std::chrono::seconds(1));

    // 如果还有预设的放置位置，则执行放置操作
    if (i &amp;#x3C; target_pose_list.size()) {
      // 控制机器人移动到预设的放置位置
      bool place_success = node-&gt;plan_and_execute(target_pose_list[i]);
      
      // 如果移动成功，则执行放置操作
      if (place_success) {
        // 打开夹爪释放立方体，0为夹爪完全打开位置
        node-&gt;grasp(0);
        
        // 等待1秒确保立方体稳定放置
        rclcpp::sleep_for(std::chrono::seconds(1));
      }
    }
	
  }
  
  // 完成所有抓取放置任务后，回到准备位置
  node-&gt;go_to_ready_position();
  
  // 关闭ROS 2客户端库
  rclcpp::shutdown();
  
  // 程序正常退出
  return 0;
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来进行解释代码：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;初始化ROS2&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;rclcpp::init(argc, argv);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;作用：初始化 ROS 2 客户端库。&lt;/li&gt;
&lt;li&gt;所有 ROS 2 C++ 程序都必须先调用此函数。&lt;/li&gt;
&lt;li&gt;它会解析命令行参数（如 --ros-args），启动通信中间件（如 DDS）&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;设置节点选项&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;rclcpp::NodeOptions node_options;
node_options.automatically_declare_parameters_from_overrides(true);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个 NodeOptions 对象，用于配置节点行为。&lt;/li&gt;
&lt;li&gt;automatically_declare_parameters_from_overrides(true)：
&lt;ul&gt;
&lt;li&gt;表示允许通过命令行直接传入参数（如 -p some_param:=value），而无需在代码中预先声明。&lt;/li&gt;
&lt;li&gt;提高灵活性，便于调试和配置。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;创建机器人控制节点&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;auto node = std::make_shared&amp;#x3C;UR5eGripper&gt;(node_options);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;使用智能指针创建一个 UR5eGripper 类的实例。&lt;/li&gt;
&lt;li&gt;UR5eGripper 是一个继承自 rclcpp::Node 的类，封装了 UR5e 机械臂和夹爪的所有控制逻辑。&lt;/li&gt;
&lt;li&gt;传入 node_options 配置参数自动声明功能。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;初始化机器人控制接口&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;node-&gt;init();
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;调用自定义类的初始化函数&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;创建单线程执行器并运行&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;rclcpp::executors::SingleThreadedExecutor executor;
executor.add_node(node);
std::thread([&amp;#x26;executor]() { executor.spin(); }).detach();
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;让 ROS 2 节点开始持续处理回调函数（如订阅消息、服务请求、定时器等）&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;6&quot;&gt;
&lt;li&gt;获取预设的放置位置列表&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; target_pose_list;
node-&gt;get_target_pose_list(target_pose_list);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;target_pose_list 是一个二维数组，存储多个“放置点”的位姿。&lt;/li&gt;
&lt;li&gt;每个位姿是一个长度为 6 的 double 数组：[x, y, z, roll, pitch, yaw]&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;7&quot;&gt;
&lt;li&gt;设置坐标参考帧&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::string from_frame = &quot;base_link&quot;;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;base_link 是 UR5e 机器人的基座坐标系，是整个机器人系统的参考原点。&lt;/li&gt;
&lt;li&gt;所有其他坐标系（如 tool0, camera_link, cube1）都是相对于 base_link 定义的&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;8&quot;&gt;
&lt;li&gt;定义要抓取的立方体名称列表&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::vector&amp;#x3C;std::string&gt; to_frame_list = {&quot;cube1&quot;, &quot;cube2&quot;, &quot;cube3&quot;, &quot;cube4&quot;, &quot;cube5&quot;, &quot;cube6&quot;};
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;视觉系统已经检测出最多 6 个立方体，并通过 TF 发布了它们的坐标系：cube1, cube2, ..., cube6&lt;/li&gt;
&lt;li&gt;这些坐标系是通过前面 Python 节点广播的（如 camera_link → cube1）&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;9&quot;&gt;
&lt;li&gt;存储立方体抓取位置的列表&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;std::vector&amp;#x3C;std::vector&amp;#x3C;double&gt;&gt; cube_pose_list;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;用于保存所有立方体调整后的抓取位姿（包含位置和姿态）。&lt;/li&gt;
&lt;li&gt;后续会遍历这个列表，依次抓取每个立方体。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;10&quot;&gt;
&lt;li&gt;核心循环 1：获取并调整立方体位置&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;for (size_t i = 0; i &amp;#x3C; to_frame_list.size(); i++) {
    std::vector&amp;#x3C;double&gt; cube_pose;
    
    // 通过TF获取立方体相对于base_link的位姿
    node-&gt;get_cube_pose(from_frame, to_frame_list[i], cube_pose);
    
    if (cube_pose.empty()) {
      RCLCPP_WARN(rclcpp::get_logger(&quot;demo4&quot;), &quot;Failed to get pose for %s, skipping&quot;, to_frame_list[i].c_str());
      continue;
    }

    // 调整抓取位置
    cube_pose[0] -= 0.012 ;  // 微调X坐标
    cube_pose[1] += 0.01;    // 微调Y坐标
    cube_pose[2] += 0.14;    // Z轴抬高0.14米，确保从上方垂直抓取
    cube_pose[3] = 0.0;      // roll = 0°
    cube_pose[4] = M_PI;     // pitch = 180°（夹爪向下）
    cube_pose[5] = 0.0;      // yaw = 0°

    RCLCPP_INFO(...);  // 打印日志
    
    cube_pose_list.push_back(cube_pose);
}
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;node-&gt;get_cube_pose(...)
&lt;ul&gt;
&lt;li&gt;该函数内部使用 tf2_ros::Buffer 和 TransformListener 查询 from_frame 到 to_frame 的变换。&lt;/li&gt;
&lt;li&gt;例如：查询 base_link → cube1 的平移和旋转。&lt;/li&gt;
&lt;li&gt;返回一个 6 维向量：[x, y, z, roll, pitch, yaw]&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;11&quot;&gt;
&lt;li&gt;核心循环 2：抓取与放置&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;for (size_t i = 0; i &amp;#x3C; std::min&amp;#x3C;size_t&gt;(6, cube_pose_list.size()); i++) {
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;最多处理 6 个立方体。&lt;/li&gt;
&lt;li&gt;使用 std::min 防止越界（比如只检测到 3 个立方体）。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;步骤一：移动到抓取位置&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;bool grasp_success = node-&gt;plan_and_execute(cube_pose_list[i]);
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;plan_and_execute()：
&lt;ul&gt;
&lt;li&gt;使用 MoveIt 2 规划一条从当前位姿到目标位姿的无碰撞路径。&lt;/li&gt;
&lt;li&gt;执行该路径，控制机械臂运动。&lt;/li&gt;
&lt;li&gt;返回 true 表示成功，false 表示规划失败（如路径被阻挡）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;步骤二：闭合夹爪抓取&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;node-&gt;grasp(0.36);
rclcpp::sleep_for(std::chrono::seconds(1));
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;grasp(0.36)：控制夹爪闭合到 0.36，刚好夹住立方体。&lt;/li&gt;
&lt;li&gt;sleep_for(1s)：等待夹爪完全闭合，确保夹紧。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;步骤三：移动到放置位置&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;if (i &amp;#x3C; target_pose_list.size()) {
    bool place_success = node-&gt;plan_and_execute(target_pose_list[i]);
    if (place_success) {
        node-&gt;grasp(0);  // 打开夹爪
        rclcpp::sleep_for(std::chrono::seconds(1));
    }
}
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;移动到第 i 个预设放置点。&lt;/li&gt;
&lt;li&gt;成功后打开夹爪（grasp(0)），释放立方体。&lt;/li&gt;
&lt;li&gt;等待 1 秒确保物体稳定落下。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;12&quot;&gt;
&lt;li&gt;任务完成：返回准备位置&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;node-&gt;go_to_ready_position();
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;调用预设的“准备位姿”（Home Position），通常是安全抬高的位置。&lt;/li&gt;
&lt;li&gt;避免下次启动时发生碰撞。&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;13&quot;&gt;
&lt;li&gt;结束程序&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cpp&quot;&gt;rclcpp::shutdown();
return 0;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;整体流程&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;[程序启动]
     ↓
rclcpp::init() → 初始化 ROS 2
     ↓
创建 UR5eGripper 节点
     ↓
node-&gt;init() → 初始化 MoveIt 和夹爪
     ↓
启动 executor.spin()（后台线程）
     ↓
获取放置点列表 target_pose_list
     ↓
循环遍历 cube1~cube6：
   ├─ 通过 TF 查询 cube_i 相对于 base_link 的位置
   ├─ 调整为抓取位姿（抬高 Z，设置 pitch=π）
   └─ 存入 cube_pose_list
     ↓
再次循环抓取每个立方体：
   ├─ plan_and_execute(抓取位姿)
   ├─ grasp(0.36)  // 抓
   ├─ plan_and_execute(放置位姿)
   └─ grasp(0)     // 放
     ↓
go_to_ready_position()  // 回家
     ↓
rclcpp::shutdown() → 关闭系统
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;补充&lt;/strong&gt;
target_pose_list.yaml内容如下时：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-yaml&quot;&gt;/**:
  ros__parameters:
    target_pose_list:
    ##   x, y, z, roll, pitch, yaw, grasp
      - &quot;0.5, -0.52, 0.2, 0, 3.14, 0.0&quot; ## 2
      - &quot;0.5, -0.45, 0.2, 0, 3.14, 0.0&quot; ## 1
      - &quot;0.5, -0.38, 0.28, 0, 3.14, 0.0&quot; ## 4
      - &quot;0.5, 0.38, 0.2, 0, 3.14, 0.0&quot; ## 3
      - &quot;0.5, 0.45, 0.28, 0, 3.14, 0.0&quot; ## 5
      - &quot;0.5, 0.52, 0.35, 0, 3.14, 0.0&quot; ## 6
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;执行的效果如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-8%2F1.png%3ForigWidth%3D906%26origHeight%3D480%26origFormat%3Dpng&amp;#x26;w=906&amp;#x26;h=480&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
浅浅将x的0.5都改为0.3，执行的效果如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-8%2F2.png%3ForigWidth%3D712%26origHeight%3D512%26origFormat%3Dpng&amp;#x26;w=712&amp;#x26;h=512&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
效果不太好，掉了一个~&lt;/p&gt;
&lt;p&gt;抓取期间生成的日志内容如下，ROS2日志的解析参考&lt;a href=&quot;https://blog.csdn.net/qq_41336087/article/details/133745100?ops_request_misc=&amp;#x26;request_id=&amp;#x26;biz_id=102&amp;#x26;utm_term=ROS2%E7%9A%84%E6%97%A5%E5%BF%97%E4%BF%A1%E6%81%AF&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-1-133745100.142%5Ev102%5Epc_search_result_base7&amp;#x26;spm=1018.2226.3001.4187&quot;&gt;博客&lt;/a&gt;：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-8%2F3.png%3ForigWidth%3D1170%26origHeight%3D427%26origFormat%3Dpng&amp;#x26;w=1170&amp;#x26;h=427&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
举个例子：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-8%2F4.png%3ForigWidth%3D1558%26origHeight%3D1159%26origFormat%3Dpng&amp;#x26;w=1558&amp;#x26;h=1159&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（七）：simulation文件小结</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-7</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-7</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:45:00 GMT</pubDate><content:encoded>&lt;p&gt;在前面，我们已经从头到尾了解了simulation.launch.py文件的执行过程，接下来需要对该部分进行一个小结。&lt;/p&gt;
&lt;p&gt;首先查看一下节点，使用&lt;code&gt;ros2 node list&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-7%2F1.png%3ForigWidth%3D489%26origHeight%3D306%26origFormat%3Dpng&amp;#x26;w=489&amp;#x26;h=306&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
查看rqt_grapgh，运行仿真环境后输入&lt;code&gt;rqt_graph&lt;/code&gt;即可
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-7%2F2.png%3ForigWidth%3D4966%26origHeight%3D1555%26origFormat%3Dpng&amp;#x26;w=4966&amp;#x26;h=1555&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;未完待续....&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（六）：seg_and_det文件详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-6</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-6</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:42:00 GMT</pubDate><content:encoded>&lt;p&gt;接下来对视觉处理模块的代码seg_and_det.launch.py进行详解
代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    ## 返回启动描述列表，包含三个视觉处理节点
    return LaunchDescription([
        ## 物体检测节点 - 负责识别和定位场景中的物体
        Node(
            package=&apos;vision&apos;,           ## 所属功能包
            executable=&apos;obj_detect&apos;,    ## 可执行文件名
            name=&apos;obj_detect&apos;,          ## 节点名称
            output=&apos;screen&apos;             ## 输出方式：打印到屏幕
        ),
        ## 检测坐标变换节点 - 负责计算检测到的物体相对于机器人基座的坐标变换
        Node(
            package=&apos;vision&apos;,           ## 所属功能包
            executable=&apos;det_tf&apos;,        ## 可执行文件名
            name=&apos;det_tf&apos;,              ## 节点名称
            output=&apos;screen&apos;             ## 输出方式：打印到屏幕
        ),
        ## 点云处理节点 - 负责处理深度相机获取的点云数据
        Node(
            package=&apos;vision&apos;,                   ## 所属功能包
            executable=&apos;point_cloud_processor&apos;, ## 可执行文件名
            name=&apos;point_cloud_processor&apos;,       ## 节点名称
            output=&apos;screen&apos;                     ## 输出方式：打印到屏幕
        )
    ])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个launch文件定义了三个视觉处理节点：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;obj_detect节点 - 负责物体检测功能，识别和定位场景中的物体&lt;/li&gt;
&lt;li&gt;det_tf节点 - 负责检测坐标变换，计算检测到的物体相对于机器人基座的坐标变换关系&lt;/li&gt;
&lt;li&gt;point_cloud_processor节点 - 负责点云处理，处理深度相机获取的点云数据
每个节点都来自vision功能包，并且都设置为将输出打印到屏幕，方便调试和监控。&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;obj_detect节点详解&lt;/h2&gt;
&lt;p&gt;由上述代码可知，该节点是vision包下的obj_detect.py文件，内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class ObjDetect(Node):
    &quot;&quot;&quot;
    目标检测节点类
    
    该类继承自ROS 2 Node类，实现基于YOLO模型的目标检测功能，
    包括图像订阅、目标检测、深度估计、目标追踪和结果发布等功能。
    &quot;&quot;&quot;

    def __init__(self):
        &quot;&quot;&quot;
        初始化目标检测节点
        
        设置参数、加载模型、初始化订阅者和发布者等。
        &quot;&quot;&quot;
        super().__init__(&quot;obj_detect&quot;)
        self.declare_parameter(&quot;model_path&quot;, &quot;/home/whisper/ros2_ws/src/ros2_moveit2_ur5e_grasp/src/vision/vision/yolov11/models/best.pt&quot;)
        self.declare_parameter(&quot;depth_topic&quot;, &quot;/depth_registered/image_rect&quot;)
        self.declare_parameter(&quot;image_topic&quot;, &quot;/color/image_raw&quot;)
        self.declare_parameter(&quot;cam_info_topic&quot;, &quot;/color/camera_info&quot;)
        self.declare_parameter(&quot;view_image&quot;, True)
        self.declare_parameter(&quot;publish_result&quot;, True)

        ## 参数读取
        model_path = self.get_parameter(&quot;model_path&quot;).value
        self.view_img = self.get_parameter(&quot;view_image&quot;).value
        self.publish_result = self.get_parameter(&quot;publish_result&quot;).value

        ## 模型加载
        self.model = YOLO(model_path)
        self.names = self.model.names

        self.bridge = CvBridge()
        self.tracker = Tracker()

        self.depth_instrinsic_inv = np.eye(3)
        self.depth = None
        self.image = None

        ## ROS2 订阅与发布
        self.create_subscription(Image, self.get_parameter(&quot;depth_topic&quot;).value, self.depth_callback, 10)
        self.create_subscription(Image, self.get_parameter(&quot;image_topic&quot;).value, self.image_callback, 10)
        self.create_subscription(CameraInfo, self.get_parameter(&quot;cam_info_topic&quot;).value, self.caminfo_callback, 10)
        self.detection_pub = self.create_publisher(Detection2DArray, &quot;/detection&quot;, 10)
        self.image_pub = self.create_publisher(Image, &quot;/detect_track&quot;, 10)

        self.timer = self.create_timer(0.1, self.timer_callback)
        self.get_logger().info(&quot;YOLOv11 ObjDetect Node Initialized.&quot;)

    def caminfo_callback(self, msg):
        &quot;&quot;&quot;
        相机信息回调函数
        
        获取相机内参矩阵并计算其逆矩阵，用于后续的深度估计和3D坐标计算。
        
        Args:
            msg (CameraInfo): 相机信息消息，包含相机内参矩阵K
        &quot;&quot;&quot;
        K = np.array(msg.k).reshape(3, 3)
        self.depth_instrinsic_inv = np.linalg.inv(K)

    def depth_callback(self, msg):
        &quot;&quot;&quot;
        深度图像回调函数
        
        保存接收到的深度图像消息，供后续处理使用。
        
        Args:
            msg (Image): 深度图像消息
        &quot;&quot;&quot;
        self.depth = msg

    def image_callback(self, msg):
        &quot;&quot;&quot;
        彩色图像回调函数
        
        保存接收到的彩色图像消息，供后续处理使用。
        
        Args:
            msg (Image): 彩色图像消息
        &quot;&quot;&quot;
        self.image = msg

    def timer_callback(self):
        &quot;&quot;&quot;
        定时器回调函数
        
        定期执行目标检测、追踪和结果发布等核心功能。
        &quot;&quot;&quot;
        if self.depth is None or self.image is None:
            return

        ## 获取图像
        dep = self.bridge.imgmsg_to_cv2(self.depth, desired_encoding=&apos;passthrough&apos;)
        img0 = self.bridge.imgmsg_to_cv2(self.image, desired_encoding=&apos;bgr8&apos;)

        ## 模型推理
        results = self.model.predict(img0, verbose=False)
        res = results[0]
        boxes = res.boxes.xyxy.cpu().numpy()
        confs = res.boxes.conf.cpu().numpy()
        clses = res.boxes.cls.cpu().numpy()

        ## 目标追踪
        detections = [[int(x1), int(y1), int(x2), int(y2)] for x1, y1, x2, y2 in boxes]
        tracks = self.tracker.update(detections)

        ## 构造 Detection2DArray
        detection_result = Detection2DArray()
        detections_list = []

        for (x1, y1, x2, y2), conf, cls_id in zip(boxes, confs, clses):
            cx, cy = int((x1 + x2) / 2), int((y1 + y2) / 2)
            detection = Detection2D()
            detection.id = self.names[int(cls_id)]
            detection.bbox.center.position.x = float(cx)
            detection.bbox.center.position.y = float(cy)
            detection.bbox.size_x = float(x2 - x1)
            detection.bbox.size_y = float(y2 - y1)

            ## 深度估计
            Z = dep[cy, cx] * 1e-3
            if Z &amp;#x3C;= 0 or np.isnan(Z):  ## 防止无效的深度值
                continue
            uv1 = np.array([cx, cy, 1.0])
            XYZ = self.depth_instrinsic_inv @ uv1 * Z

            obj_hypothesis = ObjectHypothesisWithPose()
            obj_hypothesis.hypothesis.class_id = self.names[int(cls_id)]
            obj_hypothesis.hypothesis.score = float(conf)
            obj_hypothesis.pose.pose.position.x = float(XYZ[0])
            obj_hypothesis.pose.pose.position.y = float(XYZ[1])
            obj_hypothesis.pose.pose.position.z = float(XYZ[2])
            detection.results.append(obj_hypothesis)

            detections_list.append(detection)

            ## 可视化
            if self.view_img:
                cv2.circle(img0, (cx, cy), 3, (0, 0, 255), -1)

        ## 按 x 坐标（从左到右）排序目标检测结果
        sorted_detections = sorted(detections_list, key=lambda det: det.bbox.center.position.x)
        detection_result.detections.extend(sorted_detections)

        ## 画 track ID
        sorted_tracks = sorted(tracks, key=lambda track: track[&apos;bbox&apos;][0])

        for visual_idx, track in enumerate(sorted_tracks):
            x1, y1, x2, y2 = track[&apos;bbox&apos;]
            ## 显示 ID（绿色）+ 排序编号（蓝色）
            cv2.putText(img0, f&quot;Rank {visual_idx+1}&quot;, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
            ## 画框
            cv2.rectangle(img0, (x1, y1), (x2, y2), (0, 255, 0), 2)

        ## 发布检测结果和追踪图像
        if self.view_img:
            image_msg = self.bridge.cv2_to_imgmsg(img0, encoding=&apos;bgr8&apos;)
            image_msg.header.stamp = self.get_clock().now().to_msg()
            self.image_pub.publish(image_msg)

        if self.publish_result:
            detection_result.header.stamp = self.get_clock().now().to_msg()
            self.detection_pub.publish(detection_result)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;参数的设置与读取&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.declare_parameter(&quot;model_path&quot;, &quot;/home/whisper/ros2_ws/src/ros2_moveit2_ur5e_grasp/src/vision/vision/yolov11/models/best.pt&quot;)
self.declare_parameter(&quot;depth_topic&quot;, &quot;/depth_registered/image_rect&quot;)
self.declare_parameter(&quot;image_topic&quot;, &quot;/color/image_raw&quot;)
self.declare_parameter(&quot;cam_info_topic&quot;, &quot;/color/camera_info&quot;)
self.declare_parameter(&quot;view_image&quot;, True)
self.declare_parameter(&quot;publish_result&quot;, True)
## 参数读取
model_path = self.get_parameter(&quot;model_path&quot;).value
self.view_img = self.get_parameter(&quot;view_image&quot;).value
self.publish_result = self.get_parameter(&quot;publish_result&quot;).value
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是一系列 &lt;code&gt;self.declare_parameter(...)&lt;/code&gt; 调用，每个都为当前 ROS 2 节点声明一个可配置的参数。
语法：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.declare_parameter(...)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;name: 参数的名称（字符串）&lt;/li&gt;
&lt;li&gt;value: 参数的默认值。如果启动时没有通过外部方式指定该参数，就使用这个值&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.model = YOLO(model_path)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;从model_path加载 YOLO 目标检测模型&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.names = self.model.names
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;获取模型能够识别的类别名称列表。例如：{0: &apos;apple&apos;, 1: &apos;cup&apos;, 2: &apos;box&apos;}。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.bridge = CvBridge()
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个 CvBridge 对象，用于在 ROS 图像消息 和 OpenCV 图像（NumPy 数组） 之间进行转换。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.tracker = Tracker()
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个目标跟踪器（Tracker） 对象。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.depth_instrinsic_inv = np.eye(3)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;初始化一个 3x3 单位矩阵，用于存放相机内参的逆矩阵。&lt;/li&gt;
&lt;li&gt;要将图像中的 2D 像素坐标 (u, v) 和深度值 d 转换为 3D 空间坐标 (x, y, z)，需要使用相机内参矩阵 K，公式如下：
$$ z = d $$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$ x = \frac{(u - c_x) \cdot z}{f_x} $$&lt;/p&gt;
&lt;p&gt;$$ y = \frac{(v - c_y) \cdot z}{f_y} $$&lt;/p&gt;
&lt;p&gt;转换成矩阵形式：&lt;/p&gt;
&lt;p&gt;$$
\begin{bmatrix} x \ y \ z \end{bmatrix} = K^{-1} \begin{bmatrix} u \cdot z \ v \cdot z \ z \end{bmatrix}
$$&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.create_subscription(Image, self.get_parameter(&quot;depth_topic&quot;).value, self.depth_callback, 10)
self.create_subscription(Image, self.get_parameter(&quot;image_topic&quot;).value, self.image_callback, 10)
self.create_subscription(CameraInfo, self.get_parameter(&quot;cam_info_topic&quot;).value, self.caminfo_callback, 10)
self.detection_pub = self.create_publisher(Detection2DArray, &quot;/detection&quot;, 10)
self.image_pub = self.create_publisher(Image, &quot;/detect_track&quot;, 10)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码是用于 ROS 2 节点初始化的一部分，主要负责订阅话题、发布话题。
例如&lt;code&gt;self.create_subscription(Image, self.get_parameter(&quot;depth_topic&quot;).value, self.depth_callback, 10)&lt;/code&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;作用：创建一个订阅者，监听深度图像话题。&lt;/li&gt;
&lt;li&gt;参数：
&lt;ul&gt;
&lt;li&gt;Image：消息类型，表示接收到的是一个图像。&lt;/li&gt;
&lt;li&gt;self.get_parameter(&quot;depth_topic&quot;).value：从节点参数中获取深度图像的话题名称。&lt;/li&gt;
&lt;li&gt;self.depth_callback：当接收到新消息时调用的回调函数。&lt;/li&gt;
&lt;li&gt;10：队列大小（QoS），限制未处理消息的最大数量。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;在介绍话题发布者之前，需要先介绍&lt;code&gt;self.timer = self.create_timer(0.1, self.timer_callback)&lt;/code&gt;，因为通过这段代码不断调用self.timer_callback函数生成话题发布所需要的信息。这行代码的作用是：创建一个定时器（Timer），让它每隔 0.1 秒自动调用一次 self.timer_callback 函数。&lt;/p&gt;
&lt;p&gt;接下来介绍self.timer_callback 函数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def timer_callback(self):
        &quot;&quot;&quot;
        定时器回调函数
        
        定期执行目标检测、追踪和结果发布等核心功能。
        该函数是整个节点的核心处理逻辑，负责图像处理、目标检测、
        三维位置计算、目标追踪和结果可视化等任务。
        &quot;&quot;&quot;
        ## 检查是否有有效的深度图像和彩色图像数据
        ## 如果任一图像为空，则跳过本次处理
        if self.depth is None or self.image is None:
            return

        ## 获取图像数据
        ## 将ROS图像消息转换为OpenCV格式
        ## dep: 深度图像，保持原始编码格式
        ## img0: 彩色图像，转换为BGR8格式
        dep = self.bridge.imgmsg_to_cv2(self.depth, desired_encoding=&apos;passthrough&apos;)
        img0 = self.bridge.imgmsg_to_cv2(self.image, desired_encoding=&apos;bgr8&apos;)

        ## 使用YOLO模型进行目标检测推理
        ## verbose=False: 不输出详细日志信息
        ## results: 包含检测结果的列表
        ## res: 第一个(也是唯一一个)检测结果
        ## boxes: 检测到的目标边界框坐标 [x1, y1, x2, y2]
        ## confs: 每个检测结果的置信度分数
        ## clses: 每个检测结果的类别ID
        results = self.model.predict(img0, verbose=False)
        res = results[0]
        boxes = res.boxes.xyxy.cpu().numpy()
        confs = res.boxes.conf.cpu().numpy()
        clses = res.boxes.cls.cpu().numpy()

        ## 目标追踪处理
        ## 将检测到的边界框坐标转换为整数格式
        ## 调用追踪器更新目标位置并获取追踪结果
        detections = [[int(x1), int(y1), int(x2), int(y2)] for x1, y1, x2, y2 in boxes]
        tracks = self.tracker.update(detections)

        ## 构造Detection2DArray消息，用于发布检测结果
        ## detection_result: 最终发布的检测结果数组
        ## detections_list: 临时存储检测结果的列表
        detection_result = Detection2DArray()
        detections_list = []

        ## 遍历所有检测到的目标
        for (x1, y1, x2, y2), conf, cls_id in zip(boxes, confs, clses):
            ## 计算边界框中心点坐标
            cx, cy = int((x1 + x2) / 2), int((y1 + y2) / 2)
            
            ## 创建Detection2D对象并设置基本属性
            detection = Detection2D()
            detection.id = self.names[int(cls_id)]  ## 设置目标类别名称
            ## 设置边界框中心点坐标
            detection.bbox.center.position.x = float(cx)
            detection.bbox.center.position.y = float(cy)
            ## 设置边界框尺寸
            detection.bbox.size_x = float(x2 - x1)
            detection.bbox.size_y = float(y2 - y1)

            ## 深度估计和3D位置计算
            ## 从深度图像中获取中心点的深度值并转换为米
            Z = dep[cy, cx] * 1e-3
            ## 检查深度值是否有效(大于0且非NaN)
            if Z &amp;#x3C;= 0 or np.isnan(Z):  ## 防止无效的深度值
                continue
            
            ## 构造齐次坐标 [cx, cy, 1]
            uv1 = np.array([cx, cy, 1.0])
            ## 使用相机内参逆矩阵和深度值计算3D坐标
            ## XYZ = K^(-1) * [u, v, 1]^T * Z
            XYZ = self.depth_instrinsic_inv @ uv1 * Z

            ## 创建目标假设对象并设置属性
            obj_hypothesis = ObjectHypothesisWithPose()
            obj_hypothesis.hypothesis.class_id = self.names[int(cls_id)]  ## 设置类别ID
            obj_hypothesis.hypothesis.score = float(conf)  ## 设置置信度分数
            ## 设置3D位置坐标
            obj_hypothesis.pose.pose.position.x = float(XYZ[0])
            obj_hypothesis.pose.pose.position.y = float(XYZ[1])
            obj_hypothesis.pose.pose.position.z = float(XYZ[2])
            ## 将目标假设添加到检测结果中
            detection.results.append(obj_hypothesis)

            ## 将处理后的检测结果添加到列表中
            detections_list.append(detection)

            ## 可视化处理：在图像上绘制中心点
            if self.view_img:
                ## 在目标中心点绘制红色圆点
                cv2.circle(img0, (cx, cy), 3, (0, 0, 255), -1)

        ## 按 x 坐标（从左到右）排序目标检测结果
        ## 便于机器人按顺序抓取目标
        sorted_detections = sorted(detections_list, key=lambda det: det.bbox.center.position.x)
        detection_result.detections.extend(sorted_detections)

        ## 对追踪结果按边界框的x坐标排序
        ## 保证追踪ID与检测结果的顺序一致
        sorted_tracks = sorted(tracks, key=lambda track: track[&apos;bbox&apos;][0])

        ## 在图像上绘制追踪结果
        for visual_idx, track in enumerate(sorted_tracks):
            x1, y1, x2, y2 = track[&apos;bbox&apos;]
            ## 显示排序编号（蓝色），位置在边界框上方
            cv2.putText(img0, f&quot;Rank {visual_idx+1}&quot;, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
            ## 绘制绿色边界框
            cv2.rectangle(img0, (x1, y1), (x2, y2), (0, 255, 0), 2)

        ## 发布检测结果和追踪图像
        ## 如果启用了图像可视化，则发布处理后的图像
        if self.view_img:
            ## 将OpenCV图像转换回ROS图像消息
            image_msg = self.bridge.cv2_to_imgmsg(img0, encoding=&apos;bgr8&apos;)
            ## 设置消息时间戳为当前时间
            image_msg.header.stamp = self.get_clock().now().to_msg()
            ## 发布图像消息
            self.image_pub.publish(image_msg)

        ## 如果启用了结果发布，则发布检测结果
        if self.publish_result:
            ## 设置检测结果数组的时间戳
            detection_result.header.stamp = self.get_clock().now().to_msg()
            ## 发布检测结果
            self.detection_pub.publish(detection_result)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个函数是整个目标检测节点的核心处理逻辑，主要功能包括：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;数据检查：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;检查深度图像和彩色图像是否有效&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;图像处理：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;使用CvBridge将ROS图像消息转换为OpenCV格式
分别处理深度图像和彩色图像&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;目标检测：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;使用YOLO模型对彩色图像进行推理
提取检测结果中的边界框、置信度和类别信息&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;目标追踪：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;将检测结果传递给追踪器进行目标跟踪&lt;/li&gt;
&lt;li&gt;获取追踪结果&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;三维位置计算：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;从深度图像中获取目标中心点的深度值&lt;/li&gt;
&lt;li&gt;结合相机内参逆矩阵计算目标的3D坐标&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;结果排序：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;按照x坐标对检测结果进行排序，便于机器人按顺序抓取&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;可视化处理：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;在图像上绘制检测中心点和追踪边界框&lt;/li&gt;
&lt;li&gt;添加排序编号标识&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;结果发布：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;发布处理后的可视化图像&lt;/li&gt;
&lt;li&gt;发布目标检测和三维位置信息&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;经过上述代码处理之后的结果如下：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://i-blog.csdnimg.cn/direct/2676101aaa0f4f84b6b61eda9b978ea9.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;
接下来介绍发布话题&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.detection_pub = self.create_publisher(Detection2DArray, &quot;/detection&quot;, 10)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个发布者（Publisher），用于将目标检测的结果（如物体类别、边界框、置信度等）发布到名为 /detection 的话题上。## det_tf节点详解&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.image_pub = self.create_publisher(Image, &quot;/detect_track&quot;, 10)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个发布者，用于将带有检测的图像发布到 /detect_track 话题上。&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;det_tf节点详解&lt;/h2&gt;
&lt;p&gt;代码内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class DetTF(Node):
    &quot;&quot;&quot;
    检测结果TF变换发布节点类
    
    该类继承自ROS 2 Node类，负责订阅目标检测结果，
    并将检测到的目标位置发布为TF变换，便于机器人系统
    进行坐标变换和路径规划。
    &quot;&quot;&quot;

    def __init__(self):
        &quot;&quot;&quot;
        初始化检测结果TF变换发布节点
        
        创建目标检测结果订阅者和TF广播器。
        &quot;&quot;&quot;
        super().__init__(&apos;det_tf&apos;)
        ## 创建目标检测结果订阅者
        self.subscription = self.create_subscription(
            Detection2DArray,                 ## 消息类型：二维检测结果数组
            &apos;detection&apos;,                      ## 订阅话题名称
            self.det_callback,                ## 回调函数
            10)                               ## 队列大小
        self.subscription  ## prevent unused variable warning

        ## 创建TF广播器，用于发布坐标变换
        self.tf_broadcaster = TransformBroadcaster(self)

    def det_callback(self, msg: Detection2DArray):
        &quot;&quot;&quot;
        目标检测结果回调函数
        
        处理接收到的目标检测结果，提取特定类别的目标位置信息，
        并按指定规则排序后发布为TF变换。
        
        Args:
            msg (Detection2DArray): 包含检测结果的消息
        &quot;&quot;&quot;
        ## 定义要处理的目标类别名称
        cls_name = &apos;cube&apos;
        ## 存储检测到的目标位置信息
        objs = []

        ## 遍历所有检测结果，筛选出指定类别的目标
        for det in msg.detections:
            if det.id == cls_name:
                ## 提取目标的3D位置信息
                pos = det.results[0].pose.pose.position
                objs.append(pos)

        ## 按照 x 坐标升序排序（从左到右）
        objs.sort(key=lambda p: p.x)

        ## 广播 TF，按排序后的位置命名
        for idx, pos in enumerate(objs, start=1):
            ## 创建坐标变换消息
            t = TransformStamped()
            ## 设置时间戳为当前时间
            t.header.stamp = self.get_clock().now().to_msg()
            ## 设置父坐标系
            t.header.frame_id = &apos;camera_color_optical_frame&apos;
            ## 设置子坐标系，按顺序命名
            t.child_frame_id = f&apos;cube{idx}&apos;

            ## 设置平移变换
            t.transform.translation.x = pos.x
            t.transform.translation.y = pos.y
            t.transform.translation.z = pos.z
            ## 设置旋转变换（单位四元数，表示无旋转）
            t.transform.rotation.w = 1.0
            t.transform.rotation.x = 0.0
            t.transform.rotation.y = 0.0
            t.transform.rotation.z = 0.0

            ## 发布坐标变换
            self.tf_broadcaster.sendTransform(t)


def main(args=None):
    &quot;&quot;&quot;
    主函数
    
    初始化ROS 2节点并启动检测结果TF变换发布功能。
    
    Args:
        args: 程序启动参数
    &quot;&quot;&quot;
    ## 初始化ROS 2客户端库
    rclpy.init(args=args)

    ## 创建DetTF节点实例
    det_tf = DetTF()

    try:
        ## 循环处理节点消息
        rclpy.spin(det_tf)
    except KeyboardInterrupt:
        ## 处理键盘中断（Ctrl+C）
        pass

    ## 销毁节点并关闭ROS 2客户端库
    det_tf.destroy_node()
    rclpy.shutdown()
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.subscription = self.create_subscription(
    Detection2DArray,                 ## 消息类型：二维检测结果数组
    &apos;detection&apos;,                      ## 订阅话题名称
    self.det_callback,                ## 回调函数
    10                                ## 队列大小
)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个订阅者（Subscriber），订阅话题detection，用于接收来自其他节点发布的 2D 目标检测结果。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.subscription  ## prevent unused variable warning
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;防止代码分析工具或编译器报 “未使用变量” 警告。&lt;/li&gt;
&lt;li&gt;在 Python 中，如果你只是把 create_subscription 的返回值赋给一个变量但没有读取或使用它，某些 Linter（如 pylint）会认为这是一个“未使用变量”，并发出警告。&lt;/li&gt;
&lt;li&gt;实际上，你必须保存这个订阅者对象，否则它会被 Python 的垃圾回收机制销毁，导致订阅失败。&lt;/li&gt;
&lt;li&gt;所以这行代码只是“读取”了一下 self.subscription，告诉工具：“我确实用了这个变量”，从而消除警告。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;self.tf_broadcaster = TransformBroadcaster(self)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;创建一个 TF2 广播器（Transform Broadcaster），用于向 ROS 2 的 TF 树（Transform Tree） 发布坐标变换。&lt;/li&gt;
&lt;li&gt;这个操作就可以是Base_link或Rviz知道这个物体的相对自己的具体位置&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;接下来介绍逻辑处理函数&lt;code&gt;det_callback&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def det_callback(self, msg: Detection2DArray):
        &quot;&quot;&quot;
        目标检测结果回调函数
        
        处理接收到的目标检测结果，提取特定类别的目标位置信息，
        并按指定规则排序后发布为TF变换。
        
        Args:
            msg (Detection2DArray): 包含检测结果的消息
        &quot;&quot;&quot;
        ## 定义要处理的目标类别名称
        cls_name = &apos;cube&apos;
        ## 存储检测到的目标位置信息
        objs = []

        ## 遍历所有检测结果，筛选出指定类别的目标
        for det in msg.detections:
            if det.id == cls_name:
                ## 提取目标的3D位置信息
                pos = det.results[0].pose.pose.position
                objs.append(pos)

        ## 按照 x 坐标升序排序（从左到右）
        objs.sort(key=lambda p: p.x)

        ## 广播 TF，按排序后的位置命名
        for idx, pos in enumerate(objs, start=1):
            ## 创建坐标变换消息
            t = TransformStamped()
            ## 设置时间戳为当前时间
            t.header.stamp = self.get_clock().now().to_msg()
            ## 设置父坐标系
            t.header.frame_id = &apos;camera_color_optical_frame&apos;
            ## 设置子坐标系，按顺序命名
            t.child_frame_id = f&apos;cube{idx}&apos;

            ## 设置平移变换
            t.transform.translation.x = pos.x
            t.transform.translation.y = pos.y
            t.transform.translation.z = pos.z
            ## 设置旋转变换（单位四元数，表示无旋转）
            t.transform.rotation.w = 1.0
            t.transform.rotation.x = 0.0
            t.transform.rotation.y = 0.0
            t.transform.rotation.z = 0.0

            ## 发布坐标变换
            self.tf_broadcaster.sendTransform(t)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将目标检测的结果位置信息发布为TF变换，值得注意的是这些Cube的位置信息是相对于&lt;code&gt;camera_color_optical_frame &lt;/code&gt;坐标系的。&lt;/p&gt;
&lt;p&gt;工作流程如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;接收到 Detection2DArray 消息
        ↓
筛选出类别为 &apos;cube&apos; 的检测结果
        ↓
提取每个 cube 的 3D 位置 (x, y, z)
        ↓
按 x 坐标从小到大排序（从左到右）
        ↓
为每个 cube 创建一个 TF 坐标系：cube1, cube2, ...
        ↓
广播这些坐标系相对于相机的变换
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;point_cloud_processor节点详解&lt;/h2&gt;
&lt;p&gt;对于这个代码，我目前感觉对我们的这个项目没有什么用（可能看的还不够深入，不知道在哪里使用了），但是还是要解释一下，代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class PointCloudProcessor(Node):
    &quot;&quot;&quot;
    点云处理器节点类
    
    该类继承自ROS 2 Node类，负责订阅原始点云数据，
    根据指定参数对点云进行过滤处理，并发布过滤后的点云数据。
    主要用于减少点云数据量，提高处理效率。
    &quot;&quot;&quot;

    def __init__(self):
        &quot;&quot;&quot;
        初始化点云处理器节点
        
        声明参数、创建点云订阅者和发布者。
        &quot;&quot;&quot;
        super().__init__(&apos;point_cloud_processor&apos;)
        ## 声明参数
        self.declare_parameter(&apos;proportion&apos;, 0.3)      ## 点云保留比例，默认保留30%
        self.declare_parameter(&apos;start_from_top&apos;, True) ## 是否从顶部开始截取点云，默认为True
        
        ## 创建点云订阅者，订阅原始点云数据
        self.subscription = self.create_subscription(
            PointCloud2,                     ## 消息类型：点云2
            &apos;/depth/points&apos;,                 ## 订阅话题名称
            self.listener_callback,          ## 回调函数
            10)                              ## 队列大小
        self.subscription  ## prevent unused variable warning

        ## 创建点云发布者，发布过滤后的点云数据
        self.publisher_ = self.create_publisher(PointCloud2, &apos;/depth/points_filtered&apos;, 10)

    def listener_callback(self, msg):
        &quot;&quot;&quot;
        点云数据回调函数
        
        处理接收到的点云数据，根据参数设置过滤点云，
        并发布过滤后的点云数据。
        
        Args:
            msg (PointCloud2): 接收到的原始点云数据消息
        &quot;&quot;&quot;
        ## 获取参数值
        proportion = self.get_parameter(&apos;proportion&apos;).value       ## 点云保留比例
        start_from_top = self.get_parameter(&apos;start_from_top&apos;).value  ## 是否从顶部开始截取

        ## 处理有组织的点云（高度大于1）
        if msg.height &gt; 1:  ## 有组织的点云
            max_height = msg.height                    ## 原始点云高度
            new_height = int(proportion * max_height)  ## 计算新高度
            ## 检查新高度是否有效
            if new_height &amp;#x3C; 1:
                self.get_logger().info(&apos;New height too small, not publishing&apos;)
                return
            
            ## 计算截取起始位置
            if start_from_top:
                ## 从顶部开始截取
                start = 0
            else:
                ## 从底部开始截取
                start = (msg.height - new_height) * msg.row_step
            
            ## 计算截取数据大小
            data_size = new_height * msg.row_step
            ## 截取点云数据
            new_data = msg.data[start : start + data_size]
            
            ## 创建新的 PointCloud2 消息
            new_msg = PointCloud2()
            new_msg.header = msg.header              ## 复用原始消息的头部信息
            new_msg.height = new_height              ## 设置新高度
            new_msg.width = msg.width                ## 保持宽度不变
            new_msg.fields = msg.fields              ## 复用原始字段定义
            new_msg.is_bigendian = msg.is_bigendian  ## 复用字节序设置
            new_msg.point_step = msg.point_step      ## 复用点步长
            new_msg.row_step = msg.row_step          ## 复用行步长
            new_msg.data = new_data                  ## 设置新数据
            new_msg.is_dense = msg.is_dense          ## 复用密集性标志
            
            ## 发布过滤后的点云
            self.publisher_.publish(new_msg)
        
        ## 处理无组织的点云（高度为1）
        elif msg.height == 1:  ## 无组织的点云
            max_width = msg.width                     ## 原始点云宽度
            new_width = int(proportion * max_width)   ## 计算新宽度
            ## 检查新宽度是否有效
            if new_width &amp;#x3C; 1:
                self.get_logger().info(&apos;New width too small, not publishing&apos;)
                return
            
            ## 计算截取起始位置
            if start_from_top:
                ## 从开始位置截取
                start = 0
            else:
                ## 从末尾开始截取
                start = (msg.width - new_width) * msg.point_step
            
            ## 计算截取数据大小
            data_size = new_width * msg.point_step
            ## 截取点云数据
            new_data = msg.data[start : start + data_size]
            
            ## 创建新的 PointCloud2 消息
            new_msg = PointCloud2()
            new_msg.header = msg.header               ## 复用原始消息的头部信息
            new_msg.height = 1                        ## 保持高度为1
            new_msg.width = new_width                 ## 设置新宽度
            new_msg.fields = msg.fields               ## 复用原始字段定义
            new_msg.is_bigendian = msg.is_bigendian   ## 复用字节序设置
            new_msg.point_step = msg.point_step       ## 复用点步长
            new_msg.row_step = new_width * msg.point_step  ## 计算并设置行步长
            new_msg.data = new_data                   ## 设置新数据
            new_msg.is_dense = msg.is_dense           ## 复用密集性标志
            
            ## 发布过滤后的点云
            self.publisher_.publish(new_msg)
        
        ## 处理异常情况
        else:
            ## 点云高度小于1，无法处理
            self.get_logger().warn(&apos;Point cloud has height &amp;#x3C;1, cannot process&apos;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个文件实现了一个ROS 2节点，主要功能是对点云数据进行过滤处理。详细说明如下：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;PointCloudProcessor类：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;继承自ROS 2 Node类，实现点云处理功能&lt;/li&gt;
&lt;li&gt;初始化时声明参数、创建点云订阅者和发布者&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;初始化方法：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;声明两个参数：proportion（点云保留比例，默认0.3）和start_from_top（是否从顶部开始截取，默认True）&lt;/li&gt;
&lt;li&gt;创建对/depth/points话题的订阅者，用于接收原始点云数据&lt;/li&gt;
&lt;li&gt;创建对/depth/points_filtered话题的发布者，用于发布过滤后的点云数据&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;点云处理逻辑：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;listener_callback：处理接收到的点云数据&lt;/li&gt;
&lt;li&gt;根据点云的组织形式（有组织或无组织）分别处理：
&lt;ul&gt;
&lt;li&gt;有组织点云（height &gt; 1）：按行进行截取，保持宽度不变，调整高度&lt;/li&gt;
&lt;li&gt;无组织点云（height = 1）：按点进行截取，调整宽度&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;根据start_from_top参数决定是从起始位置还是末尾位置开始截取&lt;/li&gt;
&lt;li&gt;根据proportion参数确定截取的比例&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;点云数据处理：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;计算新的点云尺寸（高度或宽度）&lt;/li&gt;
&lt;li&gt;计算截取的起始位置和数据大小&lt;/li&gt;
&lt;li&gt;从原始点云数据中截取相应部分&lt;/li&gt;
&lt;li&gt;构建新的PointCloud2消息并发布&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;原始点云数据效果：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-6%2F1.png%3ForigWidth%3D1322%26origHeight%3D915%26origFormat%3Dpng&amp;#x26;w=1322&amp;#x26;h=915&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;过滤后的点云效果：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-6%2F2.png%3ForigWidth%3D1329%26origHeight%3D978%26origFormat%3Dpng&amp;#x26;w=1329&amp;#x26;h=978&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;这些节点共同构成了机器人视觉系统的核心部分，用于实现物体识别、定位和三维点云处理等功能，为后续的抓取规划提供必要的环境感知信息。&lt;/p&gt;
&lt;p&gt;到此，我们的抓取的准备工作已经完成，也就是simulation.launch.py代码执行结束，接下来将执行start_grasp.launch.py代码。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（五）：register_depth文件详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-5</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-5</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:41:00 GMT</pubDate><content:encoded>&lt;p&gt;接下来解释simulation.launch.py代码中的register_depth.launch.py文件
代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    return LaunchDescription([

        ## 通过 rclcpp_components 容器启动插件
        launch_ros.actions.ComposableNodeContainer(
            name=&apos;container&apos;,                    ## 容器名称
            namespace=&apos;&apos;,                        ## 命名空间
            package=&apos;rclcpp_components&apos;,         ## 包名
            executable=&apos;component_container&apos;,    ## 可执行文件名
            composable_node_descriptions=[       ## 可组合节点描述列表
                launch_ros.descriptions.ComposableNode(
                    package=&apos;depth_image_proc&apos;,                  ## 包名
                    plugin=&apos;depth_image_proc::RegisterNode&apos;,     ## 插件名
                    name=&apos;register_node&apos;,                        ## 节点名称
                    remappings=[(&apos;depth/image_rect&apos;, &apos;/depth/image_raw&apos;),        ## 重映射深度图像话题
                                (&apos;depth/camera_info&apos;, &apos;/depth/camera_info&apos;),     ## 重映射深度相机信息话题
                                (&apos;rgb/camera_info&apos;, &apos;/color/camera_info&apos;),       ## 重映射RGB相机信息话题
                                (&apos;depth_registered/image_rect&apos;, &apos;/depth_registered/image_rect&apos;),  ## 重映射注册后的深度图像话题
                                (&apos;depth_registered/camera_info&apos;, &apos;/depth_registered/camera_info&apos;)] ## 重映射注册后的相机信息话题
                ),
            ],
            output=&apos;screen&apos;,  ## 输出到屏幕
        ),
    ])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;ComposableNodeContainer（可组合节点容器）&lt;/strong&gt;
这是ROS 2中一种高效的节点管理方式，允许多个节点在同一个进程中运行，减少进程间通信开销。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;name=&apos;container&apos;：容器的名称，用于标识这个节点容器&lt;/li&gt;
&lt;li&gt;namespace=&apos;&apos;：命名空间，空字符串表示使用全局命名空间&lt;/li&gt;
&lt;li&gt;package=&apos;rclcpp_components&apos;：使用rclcpp_components包，这是ROS 2提供的组件容器功能&lt;/li&gt;
&lt;li&gt;executable=&apos;component_container&apos;：运行的可执行文件是组件容器&lt;/li&gt;
&lt;li&gt;output=&apos;screen&apos;：将容器的输出打印到终端屏幕&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;ComposableNode（可组合节点）&lt;/strong&gt;
这是一种特殊的节点定义方式，节点作为组件在容器内运行。&lt;/p&gt;
&lt;p&gt;当前配置中只启用了RegisterNode组件：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;package=&apos;depth_image_proc&apos;：使用depth_image_proc包，这是ROS提供的深度图像处理功能包&lt;/li&gt;
&lt;li&gt;plugin=&apos;depth_image_proc::RegisterNode&apos;：使用RegisterNode插件，用于深度图像配准（将深度图像对齐到彩色图像坐标系）&lt;/li&gt;
&lt;li&gt;name=&apos;register_node&apos;：节点名称&lt;/li&gt;
&lt;li&gt;remappings：话题重映射列表，将节点内部使用的话题名称映射到实际的话题名称&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;话题重映射详解&lt;/strong&gt;
话题重映射是ROS 2中非常重要的概念，它允许我们在不修改节点代码的情况下改变节点订阅和发布的话题名称。&lt;/p&gt;
&lt;p&gt;RegisterNode的重映射包括：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;(&apos;depth/image_rect&apos;, &apos;/depth/image_raw&apos;)：将节点内部的depth/image_rect话题映射到系统中的/depth/image_raw话题，这是原始深度图像数据&lt;/li&gt;
&lt;li&gt;(&apos;depth/camera_info&apos;, &apos;/depth/camera_info&apos;)：深度相机的内参信息话题&lt;/li&gt;
&lt;li&gt;(&apos;rgb/camera_info&apos;, &apos;/color/camera_info&apos;)：彩色相机的内参信息话题&lt;/li&gt;
&lt;li&gt;(&apos;depth_registered/image_rect&apos;, &apos;/depth_registered/image_rect&apos;)：注册后的深度图像输出话题&lt;/li&gt;
&lt;li&gt;(&apos;depth_registered/camera_info&apos;, &apos;/depth_registered/camera_info&apos;)：注册后的深度图像相机信息输出话题&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;总结&lt;/strong&gt;
RegisterNode的主要功能是将深度图像配准到彩色图像的坐标系中。在双目摄像头或RGB-D摄像头中，深度相机和彩色相机通常位于不同位置，因此它们捕获的图像具有不同的视角。为了将深度信息与彩色图像信息融合使用，需要将深度图像变换到彩色图像的坐标系中，这就是所谓的&quot;配准&quot;(register)过程。&lt;/p&gt;
&lt;p&gt;这个launch文件的作用就是启动深度图像配准功能，使得系统能够生成与彩色图像对齐的深度图像数据，这对于后续的物体识别、抓取点计算等任务非常重要。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（四）：ur5e_gripper_moveit文件详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-4</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-4</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 23:39:00 GMT</pubDate><content:encoded>&lt;p&gt;接下来详细解释ur5e_gripper_moveit.launch.py文件内容&lt;/p&gt;
&lt;h2&gt;第一步，先看generate_launch_description函数&lt;/h2&gt;
&lt;p&gt;代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    &quot;&quot;&quot;
    生成启动描述文件。
    
    该函数声明所有必要的启动参数，包括描述包、配置文件、关节限制等，
    然后通过OpaqueFunction调用launch_setup函数来配置和启动相关节点。
    
    返回:
    LaunchDescription: 包含所有声明参数和启动设置的启动描述对象
    &quot;&quot;&quot;
    declared_arguments = []
    ## General arguments
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;description_package&quot;,
            default_value=&quot;ur5e_gripper_moveit_config&quot;,
            description=&quot;Description package with robot URDF/XACRO files. Usually the argument \
        is not set, it enables use of a custom description.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;description_file&quot;,
            default_value=&quot;ur5e_gripper.urdf.xacro&quot;,
            description=&quot;URDF/XACRO description file with the robot.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;moveit_config_package&quot;,
            default_value=&quot;ur5e_gripper_moveit_config&quot;,
            description=&quot;MoveIt config package with robot SRDF/XACRO files. Usually the argument \
        is not set, it enables use of a custom moveit config.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;moveit_config_file&quot;,
            default_value=&quot;ur5e_gripper.srdf.xacro&quot;,
            description=&quot;MoveIt SRDF/XACRO description file with the robot.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;moveit_joint_limits_file&quot;,
            default_value=&quot;joint_limits.yaml&quot;,
            description=&quot;MoveIt joint limits that augment or override the values from the URDF robot_description.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;use_sim_time&quot;,
            default_value=&quot;true&quot;,
            description=&quot;Make MoveIt to use simulation time. This is needed for the trajectory planing in simulation.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;prefix&quot;,
            default_value=&apos;&quot;&quot;&apos;,
            description=&quot;Prefix of the joint names, useful for \
        multi-robot setup. If changed than also joint names in the controllers&apos; configuration \
        have to be updated.&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(&quot;launch_rviz&quot;, default_value=&quot;true&quot;, description=&quot;Launch RViz?&quot;)
    )

    return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里稍微介绍一下ur5e_gripper.srdf.xacro文件，ur5e_gripper.srdf.xacro 是一个SRDF (Semantic Robot Description Format) 文件，它是MoveIt运动规划框架中的一个重要配置文件。SRDF文件用于定义机器人的语义信息，这些信息扩展了基本的URDF（Unified Robot Description Format）模型。这个文件由Setup Assistant生成。
代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-xml&quot;&gt;&amp;#x3C;?xml version=&quot;1.0&quot; encoding=&quot;UTF-8&quot;?&gt;
&amp;#x3C;robot xmlns:xacro=&quot;http://wiki.ros.org/xacro&quot; name=&quot;ur5e_gripper&quot;&gt;

    &amp;#x3C;xacro:include filename=&quot;$(find ur_moveit_config)/srdf/ur_macro.srdf.xacro&quot; /&gt;
    &amp;#x3C;xacro:ur_srdf name=&quot;ur&quot; prefix=&quot;&quot; /&gt;


    &amp;#x3C;xacro:include filename=&quot;$(find robotiq_moveit_config)/srdf/robotiq_macro.srdf.xacro&quot; /&gt;
    &amp;#x3C;xacro:robotiq_srdf prefix=&quot;&quot; /&gt;



    &amp;#x3C;group_state name=&quot;ready&quot; group=&quot;ur_manipulator&quot;&gt;
        &amp;#x3C;joint name=&quot;elbow_joint&quot; value=&quot;1.5707&quot; /&gt;
        &amp;#x3C;joint name=&quot;shoulder_lift_joint&quot; value=&quot;-1.5707&quot; /&gt;
        &amp;#x3C;joint name=&quot;shoulder_pan_joint&quot; value=&quot;0&quot; /&gt;
        &amp;#x3C;joint name=&quot;wrist_1_joint&quot; value=&quot;-1.5707&quot; /&gt;
        &amp;#x3C;joint name=&quot;wrist_2_joint&quot; value=&quot;-1.5707&quot; /&gt;
        &amp;#x3C;joint name=&quot;wrist_3_joint&quot; value=&quot;0&quot; /&gt;
    &amp;#x3C;/group_state&gt;


    &amp;#x3C;disable_collisions link1=&quot;robotiq_85_base_link&quot; link2=&quot;wrist_1_link&quot; reason=&quot;Never&quot;/&gt;
    &amp;#x3C;disable_collisions link1=&quot;robotiq_85_base_link&quot; link2=&quot;wrist_2_link&quot; reason=&quot;Never&quot;/&gt;
    &amp;#x3C;disable_collisions link1=&quot;robotiq_85_base_link&quot; link2=&quot;wrist_3_link&quot; reason=&quot;Adjacent&quot;/&gt;
    
&amp;#x3C;/robot&gt;

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个文件的主要作用包括：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;定义机器人组（Groups）：指定机器人的关节和链接如何组成不同的运动学组，例如机械臂组和夹爪组。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;定义末端执行器（End Effectors）：指定机器人末端执行器与哪个链接相连，以及使用哪个组作为其运动学组。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;定义虚拟关节（Virtual Joints）：定义机器人与世界坐标系的连接关系。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;定义被动关节（Passive Joints）：标记那些不受主动控制的关节。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;定义机器人自碰撞（Self-Collision）：指定哪些链接之间可能发生碰撞，哪些可以忽略。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;总结&lt;/strong&gt;
这段XML代码定义了一个UR5e机械臂与Robotiq夹爪的机器人系统配置：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;声明使用xacro命名空间和UTF-8编码&lt;/li&gt;
&lt;li&gt;包含UR5e和Robotiq的SRDF宏定义文件&lt;/li&gt;
&lt;li&gt;定义名为&quot;ready&quot;的预设关节姿态&lt;/li&gt;
&lt;li&gt;禁用特定连杆间的碰撞检测，避免运动规划时误判&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;第二步，看launch_setup函数&lt;/h2&gt;
&lt;p&gt;代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def launch_setup(context, *args, **kwargs):
    &quot;&quot;&quot;
    设置并启动MoveIt相关的节点。
    
    该函数配置机器人描述、运动规划参数、控制器参数等，并启动move_group节点。
    
    参数:
    context: 启动上下文对象
    *args: 可变位置参数
    **kwargs: 可变关键字参数
    
    返回:
    list: 需要启动的节点列表
    &quot;&quot;&quot;

    ## General arguments
    description_package = LaunchConfiguration(&quot;description_package&quot;)
    description_file = LaunchConfiguration(&quot;description_file&quot;)
    moveit_config_package = LaunchConfiguration(&quot;moveit_config_package&quot;)
    moveit_joint_limits_file = LaunchConfiguration(&quot;moveit_joint_limits_file&quot;)
    moveit_config_file = LaunchConfiguration(&quot;moveit_config_file&quot;)
    prefix = LaunchConfiguration(&quot;prefix&quot;)
    use_sim_time = LaunchConfiguration(&quot;use_sim_time&quot;)
    ## launch_rviz = LaunchConfiguration(&quot;launch_rviz&quot;)


    ## 生成机器人的URDF描述内容，通过xacro工具处理URDF/XACRO文件
    robot_description_content = Command(
        [
            PathJoinSubstitution([FindExecutable(name=&quot;xacro&quot;)]),
            &quot; &quot;,
            PathJoinSubstitution([FindPackageShare(description_package), &quot;urdf&quot;, description_file]),
        ]
    )
    robot_description = {&quot;robot_description&quot;: robot_description_content}

    ## MoveIt Configuration
    ## 生成机器人的SRDF语义描述内容，通过xacro工具处理SRDF/XACRO文件
    robot_description_semantic_content = Command(
        [
            PathJoinSubstitution([FindExecutable(name=&quot;xacro&quot;)]),
            &quot; &quot;,
            PathJoinSubstitution(
                [FindPackageShare(moveit_config_package), &quot;srdf&quot;, moveit_config_file]
            ),
        ]
    )
    robot_description_semantic = {&quot;robot_description_semantic&quot;: robot_description_semantic_content}

    ## 获取运动学配置文件路径
    robot_description_kinematics = PathJoinSubstitution(
        [FindPackageShare(moveit_config_package), &quot;config&quot;, &quot;kinematics.yaml&quot;]
    )

    ## 加载关节限制配置文件
    robot_description_planning = {
        &quot;robot_description_planning&quot;: load_yaml(
            str(moveit_config_package.perform(context)),
            os.path.join(&quot;config&quot;, str(moveit_joint_limits_file.perform(context))),
        )
    }

    ## Planning Configuration
    ## 配置OMPL运动规划管道参数
    ompl_planning_pipeline_config = {
        &quot;move_group&quot;: {
            &quot;planning_plugin&quot;: &quot;ompl_interface/OMPLPlanner&quot;,
            &quot;request_adapters&quot;: &quot;&quot;&quot;default_planner_request_adapters/AddTimeOptimalParameterization default_planner_request_adapters/FixWorkspaceBounds default_planner_request_adapters/FixStartStateBounds default_planner_request_adapters/FixStartStateCollision default_planner_request_adapters/FixStartStatePathConstraints&quot;&quot;&quot;,
            &quot;start_state_max_bounds_error&quot;: 0.1,
        }
    }
    ## 加载OMPL规划配置并更新到move_group配置中
    ompl_planning_yaml = load_yaml(&quot;ur5e_gripper_moveit_config&quot;, &quot;config/ompl_planning.yaml&quot;)
    ompl_planning_pipeline_config[&quot;move_group&quot;].update(ompl_planning_yaml)

    ## Trajectory Execution Configuration
    ## 加载控制器配置文件
    controllers_yaml = load_yaml(&quot;ur5e_gripper_moveit_config&quot;, &quot;config/moveit_controllers.yaml&quot;)
    
    ## 配置MoveIt控制器管理器和轨迹执行参数
    moveit_controllers = {
        &quot;moveit_simple_controller_manager&quot;: controllers_yaml,
        &quot;moveit_controller_manager&quot;: &quot;moveit_simple_controller_manager/MoveItSimpleControllerManager&quot;,
        &quot;trajectory_execution&quot; : {
            &quot;allowed_execution_duration_scaling&quot;: 2.0,  ## change execution time scaling here
            &quot;allowed_goal_duration_margin&quot;: 0.5,
            &quot;allowed_start_tolerance&quot;: 0.01,
        }
    }

    ## 配置轨迹执行参数
    trajectory_execution = {
        &quot;moveit_manage_controllers&quot;: False,
        &quot;trajectory_execution.allowed_execution_duration_scaling&quot;: 1.2,
        &quot;trajectory_execution.allowed_goal_duration_margin&quot;: 0.5,
        &quot;trajectory_execution.allowed_start_tolerance&quot;: 0.01,
    }

    ## 配置规划场景监控参数
    planning_scene_monitor_parameters = {
        &quot;publish_planning_scene&quot;: True,
        &quot;publish_geometry_updates&quot;: True,
        &quot;publish_state_updates&quot;: True,
        &quot;publish_transforms_updates&quot;: True,
        &quot;publish_robot_description&quot;:True,
        &quot;publish_robot_description_semantic&quot;:True,
    }

    ## 配置OctoMap参数和传感器配置
    octomap_config = {&apos;octomap_frame&apos;: &apos;camera_depth_optical_frame&apos;, &apos;octomap_resolution&apos;: 0.02}
    octomap_updater_config = load_yaml(&apos;ur5e_gripper_moveit_config&apos;, &apos;config/sensors_3d.yaml&apos;)    
    ## Start the actual move_group node/action server
    ## 创建move_group节点，这是MoveIt的主要节点
    move_group_node = Node(
        package=&quot;moveit_ros_move_group&quot;,
        executable=&quot;move_group&quot;,
        output=&quot;screen&quot;,
        parameters=[
            robot_description,
            robot_description_semantic,
            robot_description_kinematics,
            robot_description_planning,
            ompl_planning_pipeline_config,
            trajectory_execution,
            moveit_controllers,
            planning_scene_monitor_parameters,
            {&quot;use_sim_time&quot;: use_sim_time},
            octomap_config,
            octomap_updater_config,
        ],
    )


    nodes_to_start = [
        move_group_node, 
        ## rviz_node, 
    ]

    return nodes_to_start
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;URDF文件和SRDF文件&lt;/strong&gt;
关于URDF文件和SRDF文件的关系可以参考&lt;a href=&quot;https://blog.csdn.net/Bing_Lee/article/details/130922961?ops_request_misc=%257B%2522request%255Fid%2522%253A%25223c48ea5c069117926431d55583743900%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&amp;#x26;request_id=3c48ea5c069117926431d55583743900&amp;#x26;biz_id=0&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-1-130922961-null-null.142%5Ev102%5Epc_search_result_base7&amp;#x26;utm_term=urdf%E5%92%8Csrdf%E6%96%87%E4%BB%B6%E4%B9%8B%E9%97%B4%E7%9A%84%E5%85%B3%E7%B3%BB&amp;#x26;spm=1018.2226.3001.4187&quot;&gt;博客&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;这里简单总结两者之间的关系：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;依赖关系：SRDF文件依赖于URDF文件，它是在URDF基础上添加语义信息的扩展。&lt;/li&gt;
&lt;li&gt;功能互补：URDF描述机器人的物理结构，SRDF描述机器人的运动学和规划相关语义信息。&lt;/li&gt;
&lt;li&gt;协同工作：在MoveIt中，URDF和SRDF通常一起使用，URDF提供基本的机器人描述，SRDF提供运动规划所需的高级语义信息。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;kinematics.yaml文件&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;kinematics.yaml文件是MoveIt运动规划框架中一个非常重要的配置文件，用于配置机器人的运动学求解器。它定义了机器人各运动学组如何进行正向和逆向运动学计算。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;joint_limits.yaml文件&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;顾名思义：关节活动限制文件&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;运动规划配置&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;ompl_planning_pipeline_config = {
        &quot;move_group&quot;: {
            &quot;planning_plugin&quot;: &quot;ompl_interface/OMPLPlanner&quot;,
            &quot;request_adapters&quot;: &quot;&quot;&quot;default_planner_request_adapters/AddTimeOptimalParameterization default_planner_request_adapters/FixWorkspaceBounds default_planner_request_adapters/FixStartStateBounds default_planner_request_adapters/FixStartStateCollision default_planner_request_adapters/FixStartStatePathConstraints&quot;&quot;&quot;,
            &quot;start_state_max_bounds_error&quot;: 0.1,
        }
    }
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码是ROS 2 MoveIt配置中的重要部分，用于设置OMPL（Open Motion Planning Library）运动规划管道。OMPL是MoveIt中使用的默认运动规划库，提供了多种规划算法。&lt;/p&gt;
&lt;p&gt;这部分创建了一个基本的OMPL规划管道配置：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;planning_plugin: 指定使用OMPL作为规划器&lt;/li&gt;
&lt;li&gt;request_adapters: 定义规划请求处理适配器列表，这些适配器在规划前对请求进行预处理：
&lt;ul&gt;
&lt;li&gt;AddTimeOptimalParameterization: 添加时间最优参数化&lt;/li&gt;
&lt;li&gt;FixWorkspaceBounds: 修复工作空间边界&lt;/li&gt;
&lt;li&gt;FixStartStateBounds: 修复起始状态边界&lt;/li&gt;
&lt;li&gt;FixStartStateCollision: 修复起始状态碰撞&lt;/li&gt;
&lt;li&gt;FixStartStatePathConstraints: 修复起始状态路径约束&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;start_state_max_bounds_error: 设置起始状态最大边界误差为0.1&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;    ## 加载OMPL规划配置并更新到move_group配置中
    ompl_planning_yaml = load_yaml(&quot;ur5e_gripper_moveit_config&quot;, &quot;config/ompl_planning.yaml&quot;)
    ompl_planning_pipeline_config[&quot;move_group&quot;].update(ompl_planning_yaml)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这部分代码加载了额外的OMPL规划配置文件，并将其合并到基本配置中：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;使用load_yaml函数加载ompl_planning.yaml文件中的配置&lt;/li&gt;
&lt;li&gt;通过update方法将加载的配置合并到现有配置中
从我们查看的ompl_planning.yaml文件中可以看到，它定义了：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-yaml&quot;&gt;planner_configs:
  SBLkConfigDefault:
    type: geometric::SBL
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
  ESTkConfigDefault:
    type: geometric::EST
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0 setup()
    goal_bias: 0.05  ## When close to goal select goal, with this probability. default: 0.05
  LBKPIECEkConfigDefault:
    type: geometric::LBKPIECE
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
    border_fraction: 0.9  ## Fraction of time focused on boarder default: 0.9
    min_valid_path_fraction: 0.5  ## Accept partially valid moves above fraction. default: 0.5
  BKPIECEkConfigDefault:
    type: geometric::BKPIECE
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
    border_fraction: 0.9  ## Fraction of time focused on boarder default: 0.9
    failed_expansion_score_factor: 0.5  ## When extending motion fails, scale score by factor. default: 0.5
    min_valid_path_fraction: 0.5  ## Accept partially valid moves above fraction. default: 0.5
  KPIECEkConfigDefault:
    type: geometric::KPIECE
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
    goal_bias: 0.05  ## When close to goal select goal, with this probability. default: 0.05
    border_fraction: 0.9  ## Fraction of time focused on boarder default: 0.9 (0.0,1.]
    failed_expansion_score_factor: 0.5  ## When extending motion fails, scale score by factor. default: 0.5
    min_valid_path_fraction: 0.5  ## Accept partially valid moves above fraction. default: 0.5
  RRTkConfigDefault:
    type: geometric::RRT
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
    goal_bias: 0.05  ## When close to goal select goal, with this probability? default: 0.05
  RRTConnectkConfigDefault:
    type: geometric::RRTConnect
    range: 0.2  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
  RRTstarkConfigDefault:
    type: geometric::RRTstar
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
    goal_bias: 0.05  ## When close to goal select goal, with this probability? default: 0.05
    delay_collision_checking: 1  ## Stop collision checking as soon as C-free parent found. default 1
  TRRTkConfigDefault:
    type: geometric::TRRT
    range: 0.0  ## Max motion added to tree. ==&gt; maxDistance_ default: 0.0, if 0.0, set on setup()
    goal_bias: 0.05  ## When close to goal select goal, with this probability? default: 0.05
    max_states_failed: 10  ## when to start increasing temp. default: 10
    temp_change_factor: 2.0  ## how much to increase or decrease temp. default: 2.0
    min_temperature: 10e-10  ## lower limit of temp change. default: 10e-10
    init_temperature: 10e-6  ## initial temperature. default: 10e-6
    frountier_threshold: 0.0  ## dist new state to nearest neighbor to disqualify as frontier. default: 0.0 set in setup()
    frountierNodeRatio: 0.1  ## 1/10, or 1 nonfrontier for every 10 frontier. default: 0.1
    k_constant: 0.0  ## value used to normalize expression. default: 0.0 set in setup()
  PRMkConfigDefault:
    type: geometric::PRM
    max_nearest_neighbors: 10  ## use k nearest neighbors. default: 10
  PRMstarkConfigDefault:
    type: geometric::PRMstar
  longest_valid_segment_fraction: 0.005
  
ur_manipulator:
  default_planner_config: RRTstarkConfigDefault
  planning_attempts: 10      ## 尝试多次
  planning_time: 3.0         ## 规划超时时间拉长一点
  planner_configs:
    - SBLkConfigDefault
    - ESTkConfigDefault
    - LBKPIECEkConfigDefault
    - BKPIECEkConfigDefault
    - KPIECEkConfigDefault
    - RRTkConfigDefault
    - RRTConnectkConfigDefault
    - RRTstarkConfigDefault
    - TRRTkConfigDefault
    - PRMkConfigDefault
    - PRMstarkConfigDefault

&lt;/code&gt;&lt;/pre&gt;
&lt;ol&gt;
&lt;li&gt;多种规划算法的配置（如RRT、RRT*、PRM等）&lt;/li&gt;
&lt;li&gt;为ur_manipulator组指定了默认规划器为RRT*&lt;/li&gt;
&lt;li&gt;设置了规划尝试次数、规划时间等参数&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;控制器配置&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 加载控制器配置文件
    controllers_yaml = load_yaml(&quot;ur5e_gripper_moveit_config&quot;, &quot;config/moveit_controllers.yaml&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这部分代码加载了外部的控制器配置文件moveit_controllers.yaml。通过读取该文件，我们可以看到其中定义了两种控制器：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-yaml&quot;&gt;## MoveIt uses this configuration for controller management
## controller_names here should be the same as those in ros2 control

controller_names:
  - ur5e_arm_controller
  - gripper_controller
  
ur5e_arm_controller:
  type: FollowJointTrajectory
  action_ns: follow_joint_trajectory
  default: true
  joints:
    - shoulder_pan_joint
    - shoulder_lift_joint
    - elbow_joint
    - wrist_1_joint
    - wrist_2_joint
    - wrist_3_joint
  action_ns: follow_joint_trajectory
  default: true
gripper_controller:
  type: GripperCommand
  joints:
    - robotiq_85_left_knuckle_joint
  action_ns: gripper_cmd
  default: true
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;ur5e_arm_controller：机械臂控制器，类型为FollowJointTrajectory，用于控制机械臂的关节轨迹&lt;/li&gt;
&lt;li&gt;gripper_controller：夹爪控制器，类型为GripperCommand，用于控制夹爪的开合&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 配置MoveIt控制器管理器和轨迹执行参数
    moveit_controllers = {
        &quot;moveit_simple_controller_manager&quot;: controllers_yaml,
        &quot;moveit_controller_manager&quot;: &quot;moveit_simple_controller_manager/MoveItSimpleControllerManager&quot;,
        &quot;trajectory_execution&quot; : {
            &quot;allowed_execution_duration_scaling&quot;: 2.0,  ## change execution time scaling here
            &quot;allowed_goal_duration_margin&quot;: 0.5,
            &quot;allowed_start_tolerance&quot;: 0.01,
        }
    }
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这部分配置了MoveIt的控制器管理器和轨迹执行参数：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;moveit_simple_controller_manager：使用之前加载的控制器配置&lt;/li&gt;
&lt;li&gt;moveit_controller_manager：指定使用MoveItSimpleControllerManager作为控制器管理器&lt;/li&gt;
&lt;li&gt;trajectory_execution：轨迹执行相关参数设置
&lt;ul&gt;
&lt;li&gt;allowed_execution_duration_scaling: 允许的执行时间缩放比例为2.0，意味着允许轨迹执行的时间比计划的时间长100%&lt;/li&gt;
&lt;li&gt;allowed_goal_duration_margin: 允许的目标时间余量为0.5秒&lt;/li&gt;
&lt;li&gt;allowed_start_tolerance: 允许的起始状态误差为0.01&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 配置MoveIt控制器管理器和轨迹执行参数
    moveit_controllers = {
        &quot;moveit_simple_controller_manager&quot;: controllers_yaml,
        &quot;moveit_controller_manager&quot;: &quot;moveit_simple_controller_manager/MoveItSimpleControllerManager&quot;,
        &quot;trajectory_execution&quot; : {
            &quot;allowed_execution_duration_scaling&quot;: 2.0,  ## change execution time scaling here
            &quot;allowed_goal_duration_margin&quot;: 0.5,
            &quot;allowed_start_tolerance&quot;: 0.01,
        }
    }
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这部分是额外的轨迹执行参数配置：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;moveit_manage_controllers: 设置为False，表示不使用MoveIt管理控制器的生命周期&lt;/li&gt;
&lt;li&gt;其他参数与上面的配置类似，但执行时间缩放比例为1.2（比上面的配置宽松一些）&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;总结：些配置最终会被传递给MoveIt的move_group节点，用于控制机器人执行规划好的轨迹。配置中的参数对于确保轨迹执行的安全性和成功率非常重要，比如允许一定的时间余量可以防止因为执行稍慢而导致轨迹执行失败。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;配置规划场景监控参数&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 配置规划场景监控参数
    planning_scene_monitor_parameters = {
        &quot;publish_planning_scene&quot;: True,
        &quot;publish_geometry_updates&quot;: True,
        &quot;publish_state_updates&quot;: True,
        &quot;publish_transforms_updates&quot;: True,
        &quot;publish_robot_description&quot;:True,
        &quot;publish_robot_description_semantic&quot;:True,
    }
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;各项参数的含义如下：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&quot;publish_planning_scene&quot;: True - 启用规划场景的发布功能。规划场景包含了机器人、障碍物、工作空间等所有与运动规划相关的信息。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&quot;publish_geometry_updates&quot;: True - 启用几何更新的发布。当环境中的物体几何形状发生变化时，会发布这些更新。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&quot;publish_state_updates&quot;: True - 启用状态更新的发布。当机器人或环境的状态发生变化时（如关节角度变化），会发布这些更新。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&quot;publish_transforms_updates&quot;: True - 启用坐标变换更新的发布。当坐标变换关系发生变化时，会发布这些更新。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&quot;publish_robot_description&quot;: True - 发布机器人的描述信息（URDF）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&quot;publish_robot_description_semantic&quot;: True - 发布机器人的语义描述信息（SRDF）。&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这些参数会被传递给MoveIt的move_group节点，用于控制规划场景监控器的行为。规划场景监控器负责维护和更新机器人工作环境的表示，包括机器人的当前状态、环境中的障碍物、以及机器人各部分之间的碰撞关系等。通过启用这些发布选项，系统中的其他组件可以实时获取到最新的环境和机器人状态信息，这对于安全、准确的运动规划至关重要。&lt;/p&gt;
&lt;p&gt;配置OctoMap（八叉树地图）和3D传感器&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 配置OctoMap参数和传感器配置
    octomap_config = {&apos;octomap_frame&apos;: &apos;camera_depth_optical_frame&apos;, &apos;octomap_resolution&apos;: 0.02}
    octomap_updater_config = load_yaml(&apos;ur5e_gripper_moveit_config&apos;, &apos;config/sensors_3d.yaml&apos;)  
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;第一行定义了OctoMap的基本配置参数：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;octomap_frame: 设置了OctoMap的参考坐标系为&apos;camera_depth_optical_frame&apos;，这是深度相机的光学坐标系&lt;/li&gt;
&lt;li&gt;octomap_resolution: 设置了OctoMap的分辨率为0.02米，即2厘米，这决定了地图的精细程度&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;第二行通过load_yaml函数加载了sensors_3d.yaml配置文件，该文件内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-yaml&quot;&gt;sensors:  
  - point_cloud_camera

point_cloud_camera:
    sensor_plugin: occupancy_map_monitor/PointCloudOctomapUpdater
    point_cloud_topic: /depth/points_filtered
    max_range: 5.0
    point_subsample: 1
    padding_offset: 0.05
    padding_scale: 1.0
    max_update_rate: 1.0
    filtered_cloud_topic: filtered_cloud

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个配置文件定义了3D传感器的相关参数：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;sensors: 定义了使用的传感器列表，这里只有一个名为point_cloud_camera的点云相机&lt;/li&gt;
&lt;li&gt;point_cloud_camera部分详细配置了点云相机的参数：
&lt;ul&gt;
&lt;li&gt;sensor_plugin: 使用occupancy_map_monitor/PointCloudOctomapUpdater作为传感器插件，用于将点云数据更新到占用地图&lt;/li&gt;
&lt;li&gt;point_cloud_topic: 指定接收点云数据的话题为/depth/points_filtered&lt;/li&gt;
&lt;li&gt;max_range: 设置点云数据的最大处理距离为5.0米&lt;/li&gt;
&lt;li&gt;point_subsample: 点云采样因子为1，即不进行采样&lt;/li&gt;
&lt;li&gt;padding_offset和padding_scale: 碰撞检测的填充参数，用于在规划时增加安全距离&lt;/li&gt;
&lt;li&gt;max_update_rate: 最大更新频率为1.0Hz&lt;/li&gt;
&lt;li&gt;filtered_cloud_topic: 指定发布过滤后点云的话题为filtered_cloud&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;这些配置参数最终会被传递给MoveIt的move_group节点，用于启用和配置3D感知功能。通过这些配置，MoveIt可以实时接收点云数据，构建环境的3D占用地图（OctoMap），并在运动规划时考虑环境中的障碍物，从而实现更安全、更智能的路径规划。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;创建move_group节点&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;move_group_node = Node(
        package=&quot;moveit_ros_move_group&quot;,
        executable=&quot;move_group&quot;,
        output=&quot;screen&quot;,
        parameters=[
            robot_description,
            robot_description_semantic,
            robot_description_kinematics,
            robot_description_planning,
            ompl_planning_pipeline_config,
            trajectory_execution,
            moveit_controllers,
            planning_scene_monitor_parameters,
            {&quot;use_sim_time&quot;: use_sim_time},
            octomap_config,
            octomap_updater_config,
        ],
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个代码块创建了一个ROS 2节点，该节点是MoveIt运动规划框架的核心组件。具体解释如下：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;Node定义：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;package=&quot;moveit_ros_move_group&quot;: 指定节点所属的ROS包为moveit_ros_move_group&lt;/li&gt;
&lt;li&gt;executable=&quot;move_group&quot;: 指定要运行的可执行文件为move_group，这是MoveIt的主要入口点&lt;/li&gt;
&lt;li&gt;output=&quot;screen&quot;: 将节点输出打印到终端屏幕&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;parameters参数列表包含了MoveIt系统运行所需的所有配置信息：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;robot_description：机器人的URDF描述，包含机器人的物理结构信息&lt;/li&gt;
&lt;li&gt;robot_description_semantic：机器人的SRDF语义描述，包含规划组、末端执行器等语义信息&lt;/li&gt;
&lt;li&gt;robot_description_kinematics：运动学求解器配置，指定使用哪种IK算法&lt;/li&gt;
&lt;li&gt;robot_description_planning：机器人规划相关参数，如关节限制等&lt;/li&gt;
&lt;li&gt;ompl_planning_pipeline_config：OMPL运动规划管道配置，包括规划算法和适配器设置&lt;/li&gt;
&lt;li&gt;trajectory_execution：轨迹执行参数，控制轨迹执行的时间和容错设置&lt;/li&gt;
&lt;li&gt;moveit_controllers：控制器管理配置，指定如何与底层控制器通信&lt;/li&gt;
&lt;li&gt;planning_scene_monitor_parameters：规划场景监控参数，控制场景信息的发布和更新&lt;/li&gt;
&lt;li&gt;{&quot;use_sim_time&quot;: use_sim_time}：是否使用仿真时间&lt;/li&gt;
&lt;li&gt;octomap_config：OctoMap配置，用于3D环境感知&lt;/li&gt;
&lt;li&gt;octomap_updater_config：OctoMap更新器配置，指定如何从传感器数据更新地图&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;总结&lt;/strong&gt;
这个move_group节点是MoveIt系统的核心，它整合了所有的配置信息，提供了以下关键功能：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;运动规划：基于OMPL规划器提供路径规划功能&lt;/li&gt;
&lt;li&gt;运动学计算：提供正向和逆向运动学求解&lt;/li&gt;
&lt;li&gt;碰撞检测：实时检测机器人与环境的碰撞&lt;/li&gt;
&lt;li&gt;轨迹执行：控制机器人按照规划的轨迹执行动作&lt;/li&gt;
&lt;li&gt;3D感知：通过OctoMap处理传感器数据，构建环境地图&lt;/li&gt;
&lt;li&gt;场景管理：维护和更新规划场景信息&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;通过将所有这些配置参数传递给move_group节点，系统能够完整地运行MoveIt的所有功能，实现复杂的机器人运动规划和控制任务。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目环境搭建</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_environment</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_environment</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 15:56:00 GMT</pubDate><content:encoded>&lt;p&gt;推荐使用WSL启动，虚拟机太卡了，安装&lt;a href=&quot;https://www.cnblogs.com/xiao987334176/p/18864140&quot;&gt;WSL2教程&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;部署环境&lt;/h2&gt;
&lt;p&gt;准备一个Ubuntu 22.04环境，在 VMware 虚拟机中，确保网络模式是 NAT 或 桥接模式。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;lsb_release -a
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_environment%2F1.png%3ForigWidth%3D824%26origHeight%3D261%26origFormat%3Dpng&amp;#x26;w=824&amp;#x26;h=261&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;更换阿里云镜像源&lt;/h2&gt;
&lt;p&gt;备份原始源&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo cp /etc/apt/sources.list /etc/apt/sources.list.bak
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;更换为阿里云镜像源&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo nano /etc/apt/sources.list
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;删除原来的内容，并将下面的内容复制进去&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;deb https://mirrors.aliyun.com/ubuntu/ jammy main restricted universe multiverse
deb https://mirrors.aliyun.com/ubuntu/ jammy-updates main restricted universe multiverse
deb https://mirrors.aliyun.com/ubuntu/ jammy-backports main restricted universe multiverse
deb https://mirrors.aliyun.com/ubuntu/ jammy-security main restricted universe multiverse
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;更新软件包&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt-get update
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;推荐安装open-vm-tools，方便主机与虚拟机之间直接复制粘贴&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt-get install open-vm-tools-desktop
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;安装完open-vm-tools后，重启虚拟机即可。&lt;/p&gt;
&lt;h2&gt;安装ROS2&lt;/h2&gt;
&lt;p&gt;这里推荐使用鱼香ROS一键安装，非常方便~&lt;/p&gt;
&lt;p&gt;下载并执行脚本&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;wget http://fishros.com/install -O fishros &amp;#x26;&amp;#x26; . fishros
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;记得选择安装ROS2 Humble以及VSCode&lt;/p&gt;
&lt;p&gt;完成ROS2的安装后，配置环境变量&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;echo &quot;source /opt/ros/humble/setup.bash&quot; &gt;&gt; ~/.bashrc
source ~/.bashrc
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;安装Moveit2&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt update
sudo apt install ros-humble-moveit-*
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;因为项目需要物理仿真支持，所以需要安装gazebo&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt install ros-humble-moveit-common ros-humble-moveit-ros-visualization
sudo apt install ros-humble-gazebo-ros-pkgs ros-humble-gazebo-ros2-control
sudo apt install ros-humble-image-pipeline
sudo apt install ros-humble-compressed-image-transport
sudo apt install ros-humble-compressed-depth-image-transport
sudo apt install ros-humble-vision-msgs
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;安装UR5e机械臂依赖&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt install ros-humble-ur-client-library
sudo apt install ros-humble-ur-description ros-humble-ur-moveit-config
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;安装OctoMap依赖&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt install ros-humble-octomap-*
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;安装深度相机驱动Realsense与目标检测环境YOLO&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo apt install ros-humble-realsense2-camera
pip install ultralytics
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;克隆项目仓库&lt;/h2&gt;
&lt;p&gt;首先创建工作区&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;mkdir -p ~/ros2_ws/src
cd ~/ros2_ws/src
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;克隆项目源码&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;git clone https://github.com/SoupCola/ros2_moveit2_ur5e_grasp.git
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;手动创建rosdep配置并更新&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo mkdir -p /etc/ros/rosdep/sources.list.d
cat &amp;#x3C;&amp;#x3C;EOF | sudo tee /etc/ros/rosdep/sources.list.d/20-default.list
yaml https://ghproxy.com/https://raw.githubusercontent.com/ros/rosdistro/master/rosdep/osx-homebrew.yaml osx
yaml https://ghproxy.com/https://raw.githubusercontent.com/ros/rosdistro/master/rosdep/base.yaml
yaml https://ghproxy.com/https://raw.githubusercontent.com/ros/rosdistro/master/rosdep/python.yaml
yaml https://ghproxy.com/https://raw.githubusercontent.com/ros/rosdistro/master/rosdep/ruby.yaml
EOF

rosdep update
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;安装依赖并编译&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;cd ~/ros2_ws
rosdep install --from-paths src --ignore-src -r -y
colcon build --symlink-install
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;编译完成后，source工作区&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;source ~/ros2_ws/install/setup.bash
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;WSL环境下的图形界面配置&lt;/h2&gt;
&lt;p&gt;如果你在使用WSL(Windows Subsystem for Linux)环境，需要配置以下环境变量来解决图形界面显示问题：&lt;/p&gt;
&lt;h3&gt;Gazebo完整配置&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;export GAZEBO_IP=127.0.0.1
export DISPLAY=$(cat /etc/resolv.conf | grep nameserver | awk &apos;{print $2}&apos;):0
export LIBGL_ALWAYS_INDIRECT=0
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这些命令的作用是：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;code&gt;export GAZEBO_IP=127.0.0.1&lt;/code&gt;：设置Gazebo服务器的IP地址为本地回环地址（Gazebo专用）&lt;/li&gt;
&lt;li&gt;&lt;code&gt;export DISPLAY=$(...)&lt;/code&gt;：设置X11显示环境变量，从DNS配置中获取nameserver地址作为显示服务器&lt;/li&gt;
&lt;li&gt;&lt;code&gt;export LIBGL_ALWAYS_INDIRECT=0&lt;/code&gt;：设置OpenGL渲染模式为直接渲染&lt;/li&gt;
&lt;/ol&gt;
&lt;h3&gt;RViz配置&lt;/h3&gt;
&lt;p&gt;对于RViz，只需要配置X11显示和OpenGL渲染环境变量（与Gazebo的显示相关部分相同）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;export DISPLAY=$(cat /etc/resolv.conf | grep nameserver | awk &apos;{print $2}&apos;):0
export LIBGL_ALWAYS_INDIRECT=0
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这两个命令的作用是：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;code&gt;export DISPLAY=$(...)&lt;/code&gt;：设置X11显示转发，使WSL中的图形应用能够显示在Windows上&lt;/li&gt;
&lt;li&gt;&lt;code&gt;export LIBGL_ALWAYS_INDIRECT=0&lt;/code&gt;：设置OpenGL直接渲染模式，提高图形性能&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;设置完成后，就可以正常启动RViz了：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;rviz2
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;注意：&lt;code&gt;GAZEBO_IP&lt;/code&gt;环境变量是专门为Gazebo网络通信设计的，而&lt;code&gt;DISPLAY&lt;/code&gt;和&lt;code&gt;LIBGL_ALWAYS_INDIRECT&lt;/code&gt;是通用的图形环境变量，适用于所有需要在WSL中显示的图形应用程序。&lt;/p&gt;
&lt;h2&gt;启动系统&lt;/h2&gt;
&lt;p&gt;每次启动前杀死之前的进程&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;killall -9 gzserver gzclient rviz2
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;1. 启动仿真环境 
ros2 launch ur_bringup simulation.launch.py
2. 启动抓取demo
ros2 launch ur_bringup start_grasp.launch.py
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;仿真环境启动成功
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_environment%2F2.png%3ForigWidth%3D2277%26origHeight%3D1360%26origFormat%3Dpng&amp;#x26;w=2277&amp;#x26;h=1360&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;有时候启动卡在了Gazebo的启动界面（或者机械臂加载不出来等），这可能是因为之前开启的gzserver gzclient没有正常关闭导致的，执行&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;killall -9 gzserver gzclient gazebo
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（三）：ur5e_gripper_sim_control文件详解</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-3</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-3</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 14:00:00 GMT</pubDate><content:encoded>&lt;p&gt;上一篇文章详细介绍了simulation.launch.py的内容，接下来我们进一步了解launch_setup函数中的代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;dual_ur5e_gripper_control_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;ur5e_gripper_moveit_config&quot;), &quot;/launch&quot;, &quot;/ur5e_gripper_sim_control.launch.py&quot;]
        ),
        launch_arguments={
            &quot;launch_rviz&quot;: &quot;true&quot;,
        }.items(),
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到这个代码调用了ur5e_gripper_moveit_config包下的launch目录下的ur5e_gripper_sim_control.launch.py文件，并传入参数launch_rviz，这个代码的作用是加载URDF模型并启动Gazebo仿真环境。&lt;/p&gt;
&lt;p&gt;注意，这里使用了FindPackageShare。在ROS 2中，FindPackageShare 函数会&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;首先查找已安装的包（在install目录中）&lt;/li&gt;
&lt;li&gt;然后才会查找系统安装的包（如/opt/ros/humble/share/）&lt;/li&gt;
&lt;li&gt;不会直接查找src目录中的源文件&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;接下来，我们查看ur5e_gripper_sim_control.launch.py的源代码（路径为&lt;code&gt;/home/whisper/ros2_ws/install/ur5e_gripper_moveit_config/share/ur5e_gripper_moveit_config/launch/ur5e_gripper_sim_control.launch.py&lt;/code&gt;，但是代码内容与src目录下是一样的）：&lt;/p&gt;
&lt;h2&gt;第一步，先看generate_launch_description函数&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    &quot;&quot;&quot;
    生成launch描述，定义launch文件的参数和入口点
    
    该函数声明所有launch参数并返回完整的LaunchDescription对象，
    包含参数声明和实际的launch设置函数。
    
    返回:
    LaunchDescription: 包含所有launch参数和设置函数的描述对象
    &quot;&quot;&quot;
    declared_arguments = []
    ## 通用参数
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;description_package&quot;,
            default_value=&quot;ur5e_gripper_moveit_config&quot;,
            description=&quot;包含机器人URDF/XACRO文件的描述包。通常不设置该参数，\
                         它允许使用自定义描述。&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;description_file&quot;,
            default_value=&quot;ur5e_gripper.urdf.xacro&quot;,
            description=&quot;包含机器人的URDF/XACRO描述文件。&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;prefix&quot;,
            default_value=&apos;&quot;&quot;&apos;,
            description=&quot;关节名称前缀，对多机器人设置很有用。\
                         如果更改了此参数，则控制器配置中的关节名称也必须更新。&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;start_joint_controller&quot;,
            default_value=&quot;true&quot;,
            description=&quot;为机器人控制启用无头模式&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(
            &quot;initial_joint_controller&quot;,
            default_value=&quot;ur5e_arm_controller&quot;,
            description=&quot;要启动的机器人控制器。&quot;,
        )
    )
    declared_arguments.append(
        DeclareLaunchArgument(&quot;launch_rviz&quot;, default_value=&quot;False&quot;, description=&quot;是否启动RViz？&quot;)
    )

    print(f&quot;\033[95mdeclared_arguments:{declared_arguments}\033[0m&quot;)
    return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)])

&lt;/code&gt;&lt;/pre&gt;
&lt;ol&gt;
&lt;li&gt;函数首先创建一个空列表declared_arguments，用于存储所有要声明的launch参数。这些参数可以在运行launch文件时通过命令行进行配置。&lt;/li&gt;
&lt;li&gt;我打印了declared_arguments内容如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-3%2F1.png%3ForigWidth%3D1365%26origHeight%3D194%26origFormat%3Dpng&amp;#x26;w=1365&amp;#x26;h=194&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
可以理解为generate_launch_description函数为launch_setup函数指定了需要加载的参数文件&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;第二步，看launch_setup函数&lt;/h2&gt;
&lt;p&gt;代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def launch_setup(context, *args, **kwargs):
    &quot;&quot;&quot;
    设置并启动UR5e机器人与夹爪的Gazebo仿真环境
    
    该函数配置机器人描述、控制器、可视化工具和Gazebo仿真环境，
    并按正确的依赖顺序启动所有必需的节点。
    
    参数:
    context: Launch上下文
    *args: 额外的位置参数
    **kwargs: 额外的关键字参数
    
    返回:
    list: 需要启动的节点列表
    &quot;&quot;&quot;

    ## 通用参数
    description_package = LaunchConfiguration(&quot;description_package&quot;)
    description_file = LaunchConfiguration(&quot;description_file&quot;)
    launch_rviz = LaunchConfiguration(&quot;launch_rviz&quot;)
    gazebo_world_file = os.path.join(
        FindPackageShare(package=&apos;ur5e_gripper_moveit_config&apos;).find(&apos;ur5e_gripper_moveit_config&apos;),
        &apos;gazebo&apos;,&apos;sim_env.world&apos;
    )
    
    ## 打印LaunchConfiguration的实际值
    print(f&quot;\033[95mdescription_package: {description_package.perform(context)}\033[0m&quot;)
    print(f&quot;\033[95mdescription_file: {description_file.perform(context)}\033[0m&quot;)
    print(f&quot;\033[95mlaunch_rviz: {launch_rviz.perform(context)}\033[0m&quot;)
    print(f&quot;\033[95mgazebo_world_file: {gazebo_world_file}\033[0m&quot;)


    ## RViz配置文件路径
    rviz_config_file = PathJoinSubstitution(
        [FindPackageShare(description_package), &quot;rviz&quot;, &quot;view_robot.rviz&quot;]
    )

    ## 机器人描述内容（通过xacro处理URDF文件生成）
    robot_description_content = Command(
        [
            PathJoinSubstitution([FindExecutable(name=&quot;xacro&quot;)]),
            &quot; &quot;,
            PathJoinSubstitution(
                [FindPackageShare(description_package), &quot;urdf&quot;, description_file]
            ),
        ]
    )
    robot_description = {&quot;robot_description&quot;: robot_description_content}

    ## 机器人状态发布节点
    robot_state_publisher_node = Node(
        package=&quot;robot_state_publisher&quot;,
        executable=&quot;robot_state_publisher&quot;,
        output=&quot;both&quot;,
        parameters=[{&quot;use_sim_time&quot;: True}, robot_description],
    )

    ## RViz可视化节点（根据launch_rviz参数决定是否启动）
    rviz_node = Node(
        package=&quot;rviz2&quot;,
        executable=&quot;rviz2&quot;,
        name=&quot;rviz2&quot;,
        output=&quot;log&quot;,
        arguments=[&quot;-d&quot;, rviz_config_file],
        condition=IfCondition(launch_rviz),
    )

    ## 关节状态广播器
    joint_state_broadcaster_spawner = Node(
        package=&quot;controller_manager&quot;,
        executable=&quot;spawner&quot;,
        arguments=[&quot;joint_state_broadcaster&quot;, &quot;--controller-manager&quot;, &quot;/controller_manager&quot;],
    )

    ## 在关节状态广播器启动后再启动RViz
    delay_rviz_after_joint_state_broadcaster_spawner = RegisterEventHandler(
        event_handler=OnProcessExit(
            target_action=joint_state_broadcaster_spawner,
            on_exit=[rviz_node],
        )
    )

    ## UR5e机械臂控制器
    ur5e_arm_controller_spawner = Node(
        package=&quot;controller_manager&quot;,
        executable=&quot;spawner&quot;,
        arguments=[&quot;ur5e_arm_controller&quot;, &quot;-c&quot;, &quot;/controller_manager&quot;],
    )

    ## Robotiq夹爪控制器
    robotiq_gripper_controller_spawner = Node(
        package=&quot;controller_manager&quot;,
        executable=&quot;spawner&quot;,
        arguments=[&quot;gripper_controller&quot;, &quot;-c&quot;, &quot;/controller_manager&quot;],
    )

    ## Gazebo仿真环境
    gazebo = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;gazebo_ros&quot;), &quot;/launch&quot;, &quot;/gazebo.launch.py&quot;]
        ),
        launch_arguments={
            &quot;world&quot;: gazebo_world_file,
        }.items(),
    )

    ## 在Gazebo中生成机器人实体
    gazebo_spawn_robot = Node(
        package=&quot;gazebo_ros&quot;,
        executable=&quot;spawn_entity.py&quot;,
        name=&quot;spawn_ur&quot;,
        arguments=[&quot;-entity&quot;, &quot;ur5e_gripper&quot;, &quot;-topic&quot;, &quot;robot_description&quot;],
        output=&quot;screen&quot;,
    )

    ## 需要启动的节点列表
    nodes_to_start = [
        robot_state_publisher_node,
        joint_state_broadcaster_spawner,
        delay_rviz_after_joint_state_broadcaster_spawner,
        ur5e_arm_controller_spawner,
        robotiq_gripper_controller_spawner,
        gazebo,
        gazebo_spawn_robot,
    ]

    return nodes_to_start
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来我们在launch_setup函数中对参数值进行解析并打印，内容如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-3%2F2.png%3ForigWidth%3D1359%26origHeight%3D140%26origFormat%3Dpng&amp;#x26;w=1359&amp;#x26;h=140&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;strong&gt;补充&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;DeclareLaunchArgument：用于声明一个可以在启动时传递的参数（命令行参数）。&lt;/li&gt;
&lt;li&gt;LaunchConfiguration：用于获取和使用在启动时通过 DeclareLaunchArgument 声明并传递的参数的值。&lt;/li&gt;
&lt;li&gt;在ROS 2的launch系统中，LaunchConfiguration是一个惰性求值的对象，它只在launch过程中才会被解析为实际值。因此，要获取其实际值，必须在launch上下文环境中使用.perform(context)方法。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;launch_setup代码解析一：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## RViz配置文件路径
    rviz_config_file = PathJoinSubstitution(
        [FindPackageShare(description_package), &quot;rviz&quot;, &quot;view_robot.rviz&quot;]
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码设置了rviz的路径，使用了PathJoinSubstitution组件，与下面的构建路径的方式不同。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;register_depth_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;vision&quot;), &quot;/launch&quot;, &quot;/register_depth.launch.py&quot;]
        ),
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;区别如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-3%2F3.png%3ForigWidth%3D1283%26origHeight%3D476%26origFormat%3Dpng&amp;#x26;w=1283&amp;#x26;h=476&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
PythonLaunchDescriptionSource是用来加载launch文件的，而PathJoinSubstitution是用来加载其他配置文件的。&lt;/p&gt;
&lt;h3&gt;launch_setup代码解析二：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 机器人描述内容（通过xacro处理URDF文件生成）
    robot_description_content = Command(
        [
            PathJoinSubstitution([FindExecutable(name=&quot;xacro&quot;)]),
            &quot; &quot;,
            PathJoinSubstitution(
                [FindPackageShare(description_package), &quot;urdf&quot;, description_file]
            ),
        ]
    )
    robot_description = {&quot;robot_description&quot;: robot_description_content}
    print(f&quot;\033[95mrobot_description: {robot_description}\033[0m&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码用于生成机器人的描述内容，通过调用xacro工具处理URDF文件来生成完整的机器人描述。&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Command&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;这是一个用于执行系统命令的工具类&lt;/li&gt;
&lt;li&gt;它允许在launch过程中动态执行命令并获取输出结果&lt;/li&gt;
&lt;li&gt;命令将在launch时执行，而不是在launch文件解析时执行&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;PathJoinSubstitution([FindExecutable(name=&quot;xacro&quot;)]):&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;FindExecutable(name=&quot;xacro&quot;)用于查找xacro可执行文件的完整路径&lt;/li&gt;
&lt;li&gt;PathJoinSubstitution在这里实际上只是将结果转换为字符串&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;空格字符串 &quot; &quot;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;这是在命令中添加空格分隔符&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;第二个PathJoinSubstitution:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;构建URDF文件的路径&lt;/li&gt;
&lt;li&gt;FindPackageShare(description_package)查找包含机器人描述的包&lt;/li&gt;
&lt;li&gt;&quot;urdf&quot;是子目录名&lt;/li&gt;
&lt;li&gt;description_file是URDF/XACRO文件名（默认为&quot;ur5e_gripper.urdf.xacro&quot;）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;整体功能是构建并执行以下形式的命令：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;xacro /home/whisper/ros2_ws/install/ur5e_gripper_moveit_config/share/ur5e_gripper_moveit_config/urdf/ur5e_gripper.urdf.xacro
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;执行这个命令会输出处理后的完整URDF描述，其中所有xacro宏和参数都已展开。&lt;/p&gt;
&lt;p&gt;最终，robot_description字典将被用作robot_state_publisher节点的参数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;robot_state_publisher_node = Node(
    package=&quot;robot_state_publisher&quot;,
    executable=&quot;robot_state_publisher&quot;,
    output=&quot;both&quot;,
    parameters=[{&quot;use_sim_time&quot;: True}, robot_description],  ## 在这里使用
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是一个Node对象的定义，用于在launch文件中启动一个ROS 2节点。每个节点都需要指定一个包(package)和可执行文件(executable)来运行。Node的理解参考&lt;a href=&quot;https://blog.csdn.net/weixin_42499608/article/details/118051082?spm=1001.2101.3001.6650.1&amp;#x26;utm_medium=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~PaidSort-1-118051082-blog-143952644.235%5Ev43%5Epc_blog_bottom_relevance_base3&amp;#x26;depth_1-utm_source=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~PaidSort-1-118051082-blog-143952644.235%5Ev43%5Epc_blog_bottom_relevance_base3&amp;#x26;utm_relevant_index=2&quot;&gt;博客&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;各参数说明&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;package=&quot;robot_state_publisher&quot;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;指定节点所属的ROS 2包名&lt;/li&gt;
&lt;li&gt;robot_state_publisher是一个专门用于发布机器人状态的系统包&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;executable=&quot;robot_state_publisher&quot;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;指定要运行的可执行文件名&lt;/li&gt;
&lt;li&gt;这是robot_state_publisher包中的主可执行文件&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;output=&quot;both&quot;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;指定节点输出（日志）的处理方式&lt;/li&gt;
&lt;li&gt;&quot;both&quot;表示同时输出到屏幕和日志文件&lt;/li&gt;
&lt;li&gt;其他选项包括&quot;screen&quot;（仅屏幕）、&quot;log&quot;（仅日志文件）等&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;parameters=[{&quot;use_sim_time&quot;: True}, robot_description]:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;指定传递给节点的参数列表&lt;/li&gt;
&lt;li&gt;这是一个包含两个元素的列表：
&lt;ul&gt;
&lt;li&gt;{&quot;use_sim_time&quot;: True}: 设置参数，告诉节点使用仿真时间而不是系统时间
&lt;ul&gt;
&lt;li&gt;这在Gazebo仿真环境中非常重要，因为仿真时间可能与实际系统时间不同&lt;/li&gt;
&lt;li&gt;当设置为True时，节点将订阅/clock话题获取仿真时间&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;robot_description: 包含机器人描述的参数字典
&lt;ul&gt;
&lt;li&gt;这包含了通过xacro命令生成的完整URDF描述&lt;/li&gt;
&lt;li&gt;robot_state_publisher使用这个描述来了解机器人的结构和关节关系&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;launch_setup代码解析三：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## RViz可视化节点（根据launch_rviz参数决定是否启动）
rviz_node = Node(
    package=&quot;rviz2&quot;,
    executable=&quot;rviz2&quot;,
    name=&quot;rviz2&quot;,
    output=&quot;log&quot;,
    arguments=[&quot;-d&quot;, rviz_config_file],
    condition=IfCondition(launch_rviz),
)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;package=&quot;rviz2&quot;: 指定ROS 2包名，这里是RViz可视化工具包&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;executable=&quot;rviz2&quot;: 指定要运行的可执行文件名&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;name=&quot;rviz2&quot;: 指定节点名称&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;output=&quot;log&quot;: 输出设置为日志模式，将输出记录到日志文件而不是显示在终端&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;arguments=[&quot;-d&quot;, rviz_config_file]: 启动参数&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;-d参数指定RViz配置文件&lt;/li&gt;
&lt;li&gt;rviz_config_file是之前定义的配置文件路径&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;condition=IfCondition(launch_rviz): 启动条件&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;只有当launch_rviz参数为true时才会启动这个节点&lt;/li&gt;
&lt;li&gt;这使得RViz的启动变得可选和灵活&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;launch_setup代码解析四：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 关节状态广播器
    joint_state_broadcaster_spawner = Node(
        package=&quot;controller_manager&quot;,
        executable=&quot;spawner&quot;,
        arguments=[&quot;joint_state_broadcaster&quot;, &quot;--controller-manager&quot;, &quot;/controller_manager&quot;],
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是一个使用控制器管理器的spawner工具来启动关节状态广播器的节点。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;package=&quot;controller_manager&quot;: 指定ROS 2包名，这里是控制器管理器包&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;executable=&quot;spawner&quot;: 指定要运行的可执行文件名，spawner是用于加载和启动控制器的工具&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;arguments: 传递给spawner的参数列表&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;joint_state_broadcaster: 要启动的控制器名称，用于广播关节状态&lt;/li&gt;
&lt;li&gt;--controller-manager: 指定控制器管理器的参数&lt;/li&gt;
&lt;li&gt;/controller_manager: 控制器管理器的服务名称&lt;/li&gt;
&lt;li&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个节点的作用是启动关节状态广播器控制器，它会发布机器人关节的状态信息（如位置、速度等）到/joint_states话题，供其他节点（如robot_state_publisher）使用。这是机器人仿真和控制中必需的组件。&lt;/p&gt;
&lt;h3&gt;launch_setup代码解析五：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## UR5e机械臂控制器
    ur5e_arm_controller_spawner = Node(
        package=&quot;controller_manager&quot;,
        executable=&quot;spawner&quot;,
        arguments=[&quot;ur5e_arm_controller&quot;, &quot;-c&quot;, &quot;/controller_manager&quot;],
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是一个使用控制器管理器的spawner工具来启动UR5e机械臂控制器的节点。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;package=&quot;controller_manager&quot;: 指定ROS 2包名，这里是控制器管理器包&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;executable=&quot;spawner&quot;: 指定要运行的可执行文件名，spawner是用于加载和启动控制器的工具&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;arguments: 传递给spawner的参数列表&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;ur5e_arm_controller: 要启动的控制器名称，这是专门为UR5e机械臂配置的控制器&lt;/li&gt;
&lt;li&gt;-c: 控制器管理器的简写参数选项&lt;/li&gt;
&lt;li&gt;/controller_manager: 控制器管理器的服务名称&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这个节点的作用是启动UR5e机械臂控制器，它负责控制机械臂的关节运动。与之前启动的关节状态广播器不同，这个控制器是用来发送控制指令给机械臂的，使机械臂能够按照指定的轨迹运动。这是控制真实或仿真机械臂所必需的组件。&lt;/p&gt;
&lt;h3&gt;launch_setup代码解析六：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## Gazebo仿真环境
    gazebo = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;gazebo_ros&quot;), &quot;/launch&quot;, &quot;/gazebo.launch.py&quot;]
        ),
        launch_arguments={
            &quot;world&quot;: gazebo_world_file,
        }.items(),
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;PythonLaunchDescriptionSource: 指定要包含的launch文件来源&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;FindPackageShare(&quot;gazebo_ros&quot;): 查找gazebo_ros包的共享目录&lt;/li&gt;
&lt;li&gt;路径组合为: gazebo_ros包路径/launch/gazebo.launch.py&lt;/li&gt;
&lt;li&gt;这是Gazebo ROS包提供的标准launch文件&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;launch_arguments: 传递给被包含launch文件的参数&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&quot;world&quot;: gazebo_world_file: 指定Gazebo要加载的世界文件路径&lt;/li&gt;
&lt;li&gt;gazebo_world_file是之前定义的世界文件路径变量&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;注意，这里使用的gazebo.launch.py是官方包提供的&lt;/p&gt;
&lt;h3&gt;launch_setup代码解析七：&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 在Gazebo中生成机器人实体
    gazebo_spawn_robot = Node(
        package=&quot;gazebo_ros&quot;,
        executable=&quot;spawn_entity.py&quot;,
        name=&quot;spawn_ur&quot;,
        arguments=[&quot;-entity&quot;, &quot;ur5e_gripper&quot;, &quot;-topic&quot;, &quot;robot_description&quot;],
        output=&quot;screen&quot;,
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是一个用于在Gazebo仿真环境中生成机器人模型实体的节点，通过调用Gazebo ROS包提供的spawn_entity.py脚本来实现。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;package=&quot;gazebo_ros&quot;: 指定ROS 2包名，这里是Gazebo ROS接口包&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;executable=&quot;spawn_entity.py&quot;: 指定要运行的可执行文件名，这是Gazebo ROS包提供的用于在仿真中生成实体的Python脚本&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;name=&quot;spawn_ur&quot;: 指定节点名称&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;output=&quot;screen&quot;: 输出设置为屏幕模式，将输出显示在终端上&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;arguments: 传递给spawn_entity.py脚本的参数列表&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;-entity: 指定要创建的实体名称&lt;/li&gt;
&lt;li&gt;ur5e_gripper: 实体名称，这里命名为ur5e_gripper&lt;/li&gt;
&lt;li&gt;-topic: 指定从中获取机器人描述的话题&lt;/li&gt;
&lt;li&gt;robot_description: 话题名称，即之前定义的机器人描述参数&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;工作原理&lt;/strong&gt;
这个节点通过以下步骤工作：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;从robot_description话题获取机器人的URDF描述&lt;/li&gt;
&lt;li&gt;在Gazebo仿真环境中创建名为ur5e_gripper的实体&lt;/li&gt;
&lt;li&gt;将获取到的URDF描述应用到新创建的实体上&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;这个launch文件的主要功能是启动UR5e机械臂与Robotiq夹爪的完整Gazebo仿真环境，包括机器人模型加载、控制器配置、可视化界面和仿真世界设置。文件通过声明多个可配置参数（如描述包、描述文件、是否启动RViz等）来提供灵活性，并按正确顺序启动各个必要组件：首先启动Gazebo仿真环境和机器人状态发布器，然后加载关节状态广播器、机械臂控制器和夹爪控制器，最后在仿真环境中生成机器人实体。整个启动过程还考虑了依赖关系，例如确保在关节状态广播器启动后再启动RViz，以保证各组件能够正确协同工作。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（二）：从simulation.launch.py文件切入</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-2</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 13:48:00 GMT</pubDate><content:encoded>&lt;p&gt;我的整体学习流程就是从启动仿真环境、MoveIt配置和视觉模块开始逐步介绍。&lt;/p&gt;
&lt;p&gt;代码的第一步就是执行：&lt;code&gt;ros2 launch ur_bringup simulation.launch.py&lt;/code&gt;，所以从simulation.launch.py开始进行讲解。&lt;/p&gt;
&lt;p&gt;在ROS 2中，launch文件用于启动和配置多个节点、参数、服务等。它们帮助我们自动化复杂的系统启动过程。&lt;/p&gt;
&lt;h2&gt;执行launch文件的工作流程&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;当执行 ros2 launch ur_bringup simulation.launch.py 命令后，ROS 2 launch系统会寻找并执行 generate_launch_description() 函数，这是每个launch文件的入口点。&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def generate_launch_description():
    declared_arguments = []

    return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;declared_arguments = []&lt;/code&gt;，这里创建了一个空列表，用于存储launch文件可以接受的参数。在这个例子中，没有声明任何参数，所以列表是空的。如果需要添加参数，可以使用DeclareLaunchArgument。&lt;/p&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;&lt;code&gt;return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)])&lt;/code&gt;，这行代码有几个重要部分：&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;LaunchDescription：这是ROS 2 launch系统的核心类，它描述了要启动什么。它接受一个&lt;strong&gt;动作(actions)列表&lt;/strong&gt;作为参数。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;declared_arguments：这是我们之前创建的参数列表（虽然现在是空的）。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;OpaqueFunction(function=launch_setup)：这是一个特殊类型的action，它允许我们在launch过程中执行自定义的Python函数。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;launch_setup是我们自定义的函数，它包含了实际要执行的launch逻辑&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;使用OpaqueFunction的好处是我们可以在函数中编写复杂的逻辑，而不仅仅是在顶层写静态的launch描述，详细理解参考&lt;a href=&quot;https://zhuanlan.zhihu.com/p/671323402&quot;&gt;博客&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;generate_launch_description函数的工作流程&lt;/h2&gt;
&lt;p&gt;当运行ros2 launch ur_bringup simulation.launch.py时，以下步骤会依次发生：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;ROS 2找到ur_bringup包下的simulation.launch.py并执行generate_launch_description()函数&lt;/li&gt;
&lt;li&gt;generate_launch_description()返回一个LaunchDescription对象，其中包含一个OpaqueFunction&lt;/li&gt;
&lt;li&gt;ROS 2执行OpaqueFunction，这会导致调用launch_setup函数&lt;/li&gt;
&lt;li&gt;launch_setup函数执行实际的launch逻辑，定义要启动的所有节点和配置&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;补充一下什么是动作列表：&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;ROS 2 Launch系统中的动作(Action)
在ROS 2的launch系统中，&quot;动作列表&quot;(actions list)是指一系列要执行的操作或任务。这些操作可以是启动节点、设置参数、执行命令等。每个动作都是一个对象，描述了要执行的具体任务。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;动作(Action)是什么？
动作是ROS 2 launch系统的基本构建块。每个动作代表一个特定的操作，例如&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Node - 启动一个节点&lt;/li&gt;
&lt;li&gt;IncludeLaunchDescription - 包含并执行另一个launch文件&lt;/li&gt;
&lt;li&gt;DeclareLaunchArgument - 声明一个参数&lt;/li&gt;
&lt;li&gt;SetParameter - 设置参数值&lt;/li&gt;
&lt;li&gt;ExecuteProcess - 执行一个进程&lt;/li&gt;
&lt;li&gt;LogInfo - 输出日志信息&lt;/li&gt;
&lt;li&gt;OpaqueFunction - 执行自定义Python函数&lt;/li&gt;
&lt;li&gt;TimerAction - 延迟执行动作&lt;/li&gt;
&lt;li&gt;RegisterEventHandler - 注册事件处理器&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;launch_setup函数的工作流程&lt;/h2&gt;
&lt;p&gt;首先查看launch_setup函数代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def launch_setup(context, *args, **kwargs):
    &quot;&quot;&quot;
    启动设置函数，用于配置和启动机器人仿真环境
    
    该函数整合了多个子系统，包括机器人控制、MoveIt运动规划、
    深度图像处理、视觉检测和八叉树地图构建等模块
    
    参数:
        context: 启动上下文信息
        *args: 额外的位置参数
        **kwargs: 额外的关键字参数
    
    返回:
        list: 包含所有待启动节点的列表
    &quot;&quot;&quot;

    print(f&quot;\033[95mcontext:{context}\033[0m&quot;)
    ## 加载URDF模型并启动Gazebo仿真环境
    dual_ur5e_gripper_control_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;ur5e_gripper_moveit_config&quot;), &quot;/launch&quot;, &quot;/ur5e_gripper_sim_control.launch.py&quot;]
        ),
        launch_arguments={
            &quot;launch_rviz&quot;: &quot;true&quot;,
        }.items(),
    )

    ## 加载MoveIt配置
    dual_ur5e_gripper_moveit_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;ur5e_gripper_moveit_config&quot;), &quot;/launch&quot;, &quot;/ur5e_gripper_moveit.launch.py&quot;]
        ),
        launch_arguments={
            &quot;use_sim_time&quot;: &quot;true&quot;,
        }.items(),
    )

    ## 注册深度图像（与彩色图像对齐）
    register_depth_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;vision&quot;), &quot;/launch&quot;, &quot;/register_depth.launch.py&quot;]
        ),
    )

    ## 启动视觉处理模块
    vision_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;vision&quot;), &quot;/launch&quot;, &quot;/seg_and_det.launch.py&quot;]
        ),
    )
    
    ## 八叉树地图构建模块（当前被注释掉）
    octo_launch =  IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [FindPackageShare(&quot;octo_bringup&quot;), &quot;/launch&quot;, &quot;/octomap.launch.py&quot;]
        ),
    )
    
    ## 组装所有需要启动的节点
    nodes_to_launch = [
        dual_ur5e_gripper_control_launch,
        dual_ur5e_gripper_moveit_launch,
        register_depth_launch,
        vision_launch,
        ## octo_launch,
    ]
    print(f&quot;\033[95mnodes_to_launch:{nodes_to_launch}\033[0m&quot;)
    return nodes_to_launch
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;通过执行代码打印信息如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-2%2F1.png%3ForigWidth%3D1536%26origHeight%3D122%26origFormat%3Dpng&amp;#x26;w=1536&amp;#x26;h=122&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;strong&gt;补充解释一下IncludeLaunchDescription和PythonLaunchDescriptionSource：&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;IncludeLaunchDescription：允许你在一个launch文件中嵌入或调用另一个launch文件&lt;/li&gt;
&lt;li&gt;PythonLaunchDescriptionSource：是IncludeLaunchDescription的一个来源类，它指定所包含的launch文件是Python格式的&lt;/li&gt;
&lt;li&gt;launch_arguments：如果被包含的launch文件有定义可选参数（使用DeclareLaunchArgument），你可以通过 launch_arguments 字典传递值。
.items() 是将字典转换为列表的键值对，用于满足函数要求。&lt;/li&gt;
&lt;li&gt;好处是支持模块化、代码复用、参数传递，适用于复杂项目中的启动管理。详解见&lt;a href=&quot;https://blog.csdn.net/m0_53297170/article/details/144539860&quot;&gt;博客&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;总体流程如下：执行&lt;code&gt;ros2 launch ur_bringup simulation.launch.py&lt;/code&gt;-&gt;找到simulation.launch.py的generate_launch_description函数并执行-&gt;通过OpaqueFunction执行自定义函数launch_setup（该函数返回动作列表）-&gt;launch_setup函数调用其他的launch文件。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>ROS2_Moveit2_Ur5e_Grasp项目详解（一）：项目文件介绍</title><link>https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/ros2_blogs/ros2_blogs-1</guid><description>基于ROS2的机械臂仿真抓取</description><pubDate>Tue, 03 Feb 2026 13:42:00 GMT</pubDate><content:encoded>&lt;p&gt;本系列详解的源代码来自github开源项目：&lt;a href=&quot;https://github.com/Nackustb/ros2_moveit2_ur5e_grasp.git&quot;&gt;github地址&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;声明：本项目详解仅用于个人学习ROS2，若有表达不对的地方，欢迎交流指正。&lt;/p&gt;
&lt;p&gt;关于ROS2的CMake介绍可以参考&lt;a href=&quot;https://zhuanlan.zhihu.com/p/354346905&quot;&gt;博客1&lt;/a&gt;，&lt;a href=&quot;https://zhuanlan.zhihu.com/p/542641709&quot;&gt;博客2&lt;/a&gt;，视频可以看&lt;a href=&quot;%E3%80%90ROS2_1.6.4.1C++%E7%BC%96%E8%AF%91%E5%B7%A5%E5%85%B7CMake%E5%92%8CCMakelist.txt%E3%80%91https://www.bilibili.com/video/BV1efkcYrE83?vd_source=cf91171db98d40fc8adf5e6af54cd9c9&quot;&gt;B站视频&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;ROS2的编译流程图解：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-1%2F1.png%3ForigWidth%3D1568%26origHeight%3D894%26origFormat%3Dpng&amp;#x26;w=1568&amp;#x26;h=894&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;学习本项目所需要的基础：ROS2、C++、Python、目标检测YOLO等基础&lt;/p&gt;
&lt;p&gt;项目介绍：UR5e Dynamic Grasping System 是一个基于 ROS2 和 MoveIt2 框架开发的智能抓取系统。该项目结合深度相机感知、目标定位、动态环境建图（OctoMap）与路径规划，实现了 UR5e机械臂在动态环境中的实时避障与抓取任务。 系统具备良好的扩展性和开源性，适用于机器人抓取、动态避障、人机协作等场景研究与开发。&lt;/p&gt;
&lt;p&gt;在 ROS 2 中，有三个重要的概念和目录：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-1%2F2.png%3ForigWidth%3D81%26origHeight%3D113%26origFormat%3Dpng&amp;#x26;w=81&amp;#x26;h=113&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;源码目录(src) - 这是你的源代码所在的地方，包括 Python 文件、C++ 文件、配置文件等（专注于开发）&lt;/li&gt;
&lt;li&gt;构建目录(build) - 这是在编译过程中创建的临时目录（编译过程的临时产物）&lt;/li&gt;
&lt;li&gt;安装目录(install) - 这是最终安装包的地方，也是运行时实际使用的目录（最终部署和运行时使用）&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;src源码目录结构&lt;/h2&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fros2_blogs%2Fros2_blogs-1%2F3.png%3ForigWidth%3D226%26origHeight%3D307%26origFormat%3Dpng&amp;#x26;w=226&amp;#x26;h=307&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;代码功能包说明&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;octo_bringup - OctoMap建图模块&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;负责实时生成环境的三维占据栅格地图&lt;/li&gt;
&lt;li&gt;用于动态避障&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;em&gt;注意：在本项目中这个并没有用到，暂时可以不进行了解&lt;/em&gt;&lt;/p&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;robotiq_description - Robotiq夹爪描述文件&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;包含Robotiq夹爪的详细描述文件&lt;/li&gt;
&lt;li&gt;包括网格模型、关节参数&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;robotiq_moveit_config - Robotiq的Moveit2配置包&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;包含SRDF 配置：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;定义了夹爪的运动学组（kinematic group）&lt;/li&gt;
&lt;li&gt;定义了夹爪的预设状态（如打开和关闭状态）&lt;/li&gt;
&lt;li&gt;定义了碰撞忽略规则（disable_collisions）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;包含夹爪组定义：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;创建了一个名为 &quot;gripper&quot; 的组，包含了夹爪的所有链接&lt;/li&gt;
&lt;li&gt;定义了两个预设状态：
&quot;open&quot; 状态：夹爪完全打开
&quot;close&quot; 状态：夹爪闭合到指定位置&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;包含碰撞忽略规则：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;定义了夹爪内部链接之间的碰撞忽略规则&lt;/li&gt;
&lt;li&gt;这些规则告诉 MoveIt 哪些链接之间永远不会发生碰撞，从而提高规划效率&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;em&gt;注意：这个包里的文件基本都是由Setup Assistant生成的。如果要快速上手的话，个人觉得不需要太深入的了解&lt;/em&gt;。&lt;/p&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;sim_models - 仿真模型资源包&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;包含Gazebo仿真用的模型文件&lt;/li&gt;
&lt;li&gt;包括UR5e机械臂、夹爪、深度相机等仿真模型&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;em&gt;注意：这些仿真模型，一般都是由产品对应的公司发布的&lt;/em&gt;。&lt;/p&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;ur_bringup - 系统启动包&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;包含启动整个系统的launch文件
&lt;ul&gt;
&lt;li&gt;simulation.launch.py 启动仿真环境、MoveIt配置、视觉模块和OctoMap&lt;/li&gt;
&lt;li&gt;start_grasp.launch.py 启动抓取演示&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;6&quot;&gt;
&lt;li&gt;ur5e_gripper_control - 机械臂控制包&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;实现了UR5e机械臂的控制逻辑&lt;/li&gt;
&lt;li&gt;包含抓取演示的主逻辑 demo.cpp&lt;/li&gt;
&lt;li&gt;核心控制类 ur5e_gripper.h 和 ur5e_gripper.cpp&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;7&quot;&gt;
&lt;li&gt;ur5e_gripper_description - UR5e机械臂和夹爪的URDF描述文件&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;包含机械臂和夹爪的3D模型描述&lt;/li&gt;
&lt;li&gt;用于机器人建模和可视化&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;8&quot;&gt;
&lt;li&gt;ur5e_gripper_moveit_config - MoveIt 2配置包&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;包含UR5e机械臂在MoveIt中的运动规划配置&lt;/li&gt;
&lt;li&gt;定义了规划组、末端执行器设置等&lt;/li&gt;
&lt;li&gt;包含运动学配置、规划算法配置、控制器配置等&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;9&quot;&gt;
&lt;li&gt;ur5e_octomap_moveit - OctoMap与MoveIt集成&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;将OctoMap生成的环境地图集成到MoveIt中&lt;/li&gt;
&lt;li&gt;实现基于环境感知的路径规划和避障&lt;/li&gt;
&lt;/ul&gt;
&lt;ol start=&quot;10&quot;&gt;
&lt;li&gt;vision - 视觉处理模块&lt;/li&gt;
&lt;/ol&gt;
&lt;ul&gt;
&lt;li&gt;负责目标物体检测与三维定位&lt;/li&gt;
&lt;li&gt;使用YOLOv11进行目标检测 (obj_detect.py)&lt;/li&gt;
&lt;li&gt;实现深度信息处理和目标追踪&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;系统工作流程&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;启动仿真环境:&lt;/p&gt;
&lt;p&gt;a. 通过 simulation.launch.py 启动Gazebo仿真环境
b. 加载UR5e机械臂和夹爪的URDF模型
c. 启动MoveIt 2运动规划框架
d. 启动视觉处理:&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;启动深度相机数据处理模块
a. 使用YOLOv11进行目标检测和定位
b. 将检测结果发布到ROS 2话题
c. 启动抓取演示:&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;通过 start_grasp.launch.py 启动抓取演示
a. 读取目标物体的位姿信息
b. 规划机械臂运动路径
c. 控制夹爪执行抓取动作
d. 将物体移动到指定位置并释放&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;核心技术组件&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;MoveIt 2 - 用于运动规划和控制&lt;/li&gt;
&lt;li&gt;Gazebo - 用于物理仿真&lt;/li&gt;
&lt;li&gt;YOLOv11 - 用于目标检测&lt;/li&gt;
&lt;li&gt;OctoMap - 用于环境建图和避障&lt;/li&gt;
&lt;li&gt;TF2 - 用于坐标变换&lt;/li&gt;
&lt;/ol&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/ros2_blogs/abstract.png?origWidth=935&amp;origHeight=438&amp;origFormat=png"/></item><item><title>Text2SQL（一）Vanna项目实践</title><link>https://astro-pure.js.org/blog/text2sql/text2sql_blogs-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/text2sql/text2sql_blogs-1</guid><description>记录Test2SQL学习的内容。</description><pubDate>Sat, 31 Jan 2026 15:59:00 GMT</pubDate><content:encoded>&lt;p&gt;Vanna 是一个基于 MIT 许可的开源 Python RAG（检索增强生成）框架，用于 SQL 生成和相关功能。它允许用户在数据上训练一个 RAG “模型”，然后提问问题，这将生成在数据库上运行的 SQL 查询语句，并将查询结果通过表格和图表的方式展示给用户。详细介绍参考&lt;a href=&quot;https://blog.csdn.net/sinat_29950703/article/details/136658639?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522d3e97fdf28fbd30c451f93fec22247aa%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&amp;#x26;request_id=d3e97fdf28fbd30c451f93fec22247aa&amp;#x26;biz_id=0&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_positive~default-1-136658639-null-null.142%5Ev102%5Epc_search_result_base7&amp;#x26;utm_term=vanna&amp;#x26;spm=1018.2226.3001.4187&quot;&gt;博客&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;安装Mysql&lt;/h2&gt;
&lt;p&gt;我是用的是云容器，所以安装mysql稍微比较麻烦，如果是在本地，可以直接拉取mysql的镜像。下面介绍云容器安装mysql的过程：
&lt;strong&gt;安装mysql服务器&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;apt update
apt install mysql-server
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里我选择将mysql文件夹设置在我的mysql_data目录下并赋予权限：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sudo mkdir -p /root/shared-nvme/LLM-Learning/vanna/mysql_data
sudo chown mysql:mysql /root/shared-nvme/LLM-Learning/vanna/mysql_data
sudo chmod 750 /root/shared-nvme/LLM-Learning/vanna/mysql_data
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;初始化mysql&lt;/strong&gt;
这里跳过了密码生成：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sudo -u mysql mysqld --initialize-insecure --datadir=/root/shared-nvme/LLM-Learning/vanna/mysql_data
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;启动mysql服务&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sudo -u mysql /usr/sbin/mysqld \
  --datadir=/root/shared-nvme/LLM-Learning/vanna/mysql_data \
  --port=3306 \
  --socket=/root/shared-nvme/LLM-Learning/vanna/mysql_data/mysql.sock \
  --log-error=/root/shared-nvme/LLM-Learning/vanna/mysql_data/mysql.err &amp;#x26;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;检查端口&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;netstat -tlnp | grep 3306
## tcp        0      0 127.0.0.1:3306          0.0.0.0:*               LISTEN      65113/mysqld 
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到3306端口正在被mysqd监听，说明服务启动成功。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;连接mysql并设置密码&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;-- 连接到 MySQL
mysql -u root -h 127.0.0.1 -P 3306

-- 设置 root 密码
ALTER USER &apos;root&apos;@&apos;localhost&apos; IDENTIFIED BY &apos;your_root_password&apos;;

-- 刷新权限
FLUSH PRIVILEGES;

-- 退出
EXIT;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;设置密码后续连接mysql使用：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;mysql -u root -p -h 127.0.0.1 -P 3306
-- 创建demo数据库
create database demo;
-- 使用数据库
use demo;
-- 创建表
CREATE TABLE IF NOT EXISTS user (
        id INT PRIMARY KEY COMMENT &apos;用户ID&apos; ,
        name VARCHAR(100) COMMENT &apos;姓名&apos;,
        age INT COMMENT &apos;年龄&apos;
    ) COMMENT &apos;用户信息表&apos;;
    
insert into user values(1,&apos;迷糊老师&apos;,34),(2,&apos;菲菲公主&apos;,36),(3,&apos;小呆呆&apos;,24),(4,&apos;小猪猪&apos;,21),(5,&apos;超人强&apos;,18);
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;构建Vanna&lt;/h2&gt;
&lt;p&gt;首先自定义一个LLM，这里通过dashscope 使用云端的大模型，也可以使用本地的大模型（&lt;strong&gt;挖个坑后续补充&lt;/strong&gt;）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;import random 
from vanna.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore
from dashscope import Generation

DEBUG_INFO=None

class QwenLLM(VannaBase):
  def __init__(self,config=None):
    self.model=config[&apos;model&apos;]
    self.api_key=config[&apos;api_key&apos;]
  
  def system_message(self,message: str):
    return {&apos;role&apos;:&apos;system&apos;,&apos;content&apos;:message}

  def user_message(self, message: str):
    return {&apos;role&apos;:&apos;user&apos;,&apos;content&apos;:message}

  def assistant_message(self, message: str):
    return {&apos;role&apos;:&apos;assistant&apos;,&apos;content&apos;:message}
  
  def submit_prompt(self,prompt,**kwargs):
    resp=Generation.call(
      model=self.model,
      messages=prompt,
      seed=random.randint(1, 10000),
      result_format=&apos;message&apos;,
      api_key=self.api_key)
    answer=resp.output.choices[0].message.content
    global DEBUG_INFO
    DEBUG_INFO=(prompt,answer)
    return answer
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构建Vanna客户端：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class MyVanna(ChromaDB_VectorStore,QwenLLM):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self,config=config)
        QwenLLM.__init__(self,config=config)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;构造向量库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;vn.connect_to_mysql(host=&apos;localhost&apos;,dbname=&apos;demo&apos;,user=&apos;root&apos;,password=&apos;123456&apos;,port=3306)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码是建立 Vanna 客户端与 MySQL 数据库的连接。&lt;/p&gt;
&lt;p&gt;将DDL存储到向量库中：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;DDL=&apos;&apos;&apos;CREATE TABLE IF NOT EXISTS user (
        id INT PRIMARY KEY COMMENT &apos;用户ID&apos; ,
        name VARCHAR(100) COMMENT &apos;姓名&apos;,
        age INT COMMENT &apos;年龄&apos;
    ) COMMENT &apos;用户信息表&apos;;
&apos;&apos;&apos;
vn.train(ddl=DDL)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里需要介绍一下vanna的train和ask方法：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftext2sql%2Ftext2sql_blogs-1%2F1.png%3ForigWidth%3D2064%26origHeight%3D2066%26origFormat%3Dpng&amp;#x26;w=2064&amp;#x26;h=2066&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
其中train就是往向量库里存放内容，ask就是根据问题向向量库区取内容构建sql语句。&lt;/p&gt;
&lt;p&gt;运行&lt;code&gt;vn.train(ddl=DDL)&lt;/code&gt;之后会出现：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这表示ChromaDB 正在自动下载一个名为 all-MiniLM-L6-v2 的 ONNX 格式的嵌入模型（embedding model）。ChromaDB 是一个向量数据库，默认使用 sentence-transformers 模型 将文本转换为向量（embeddings）。第一次使用 ChromaDB 且没有指定自定义 embedding function 时，它会自动下载默认模型。&lt;/p&gt;
&lt;p&gt;存储DDL到向量库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;DDL=&apos;&apos;&apos;CREATE TABLE IF NOT EXISTS user (
        id INT PRIMARY KEY COMMENT &apos;用户ID&apos; ,
        name VARCHAR(100) COMMENT &apos;姓名&apos;,
        age INT COMMENT &apos;年龄&apos;
    ) COMMENT &apos;用户信息表&apos;;
&apos;&apos;&apos;

## 存储DDL到向量库
vn.train(ddl=DDL)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;返回：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## Adding ddl: CREATE TABLE IF NOT EXISTS user (
##         id INT PRIMARY KEY COMMENT &apos;用户ID&apos; ,
##         name VARCHAR(100) COMMENT &apos;姓名&apos;,
##         age INT COMMENT &apos;年龄&apos;
##     ) COMMENT &apos;用户信息表&apos;;
## &apos;ab0ac208-2f5e-50b0-9177-423427220940-ddl&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中&lt;code&gt;&apos;ab0ac208-2f5e-50b0-9177-423427220940-ddl&apos;&lt;/code&gt;是 Vanna 在内部用于标识和管理元数据（比如 DDL 语句）的一个 唯一键（key）。对 Vanna 内部的元数据管理和去重、更新、追踪来源等非常有用。&lt;/p&gt;
&lt;p&gt;同样的，存储document到向量库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## 存储document到向量库
vn.train(documentation=&apos;&quot;福报&quot;是指age&gt;=35岁，也就是可以向社会输送的人才&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;document经常用来存储行业内部的话术，方便模型理解，返回：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## Adding documentation....
## &apos;8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再存储SQL到向量库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## 存储SQL到向量库
&apos;&apos;&apos;
1，通过LLM根据SQL构造一个question
2，按question-SQL的JSON入库
            {
                &quot;question&quot;: question,
                &quot;sql&quot;: sql,
            }
&apos;&apos;&apos;

vn.train(sql=&apos;select name from user where age between 10 and 20&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;模型会根据问题生成一个question：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## Question generated with sql: Who are the users aged between 10 and 20? 
## Adding SQL...
## &apos;04a88b26-6984-5521-b897-73798ce0001f-sql&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到生成了问题：&lt;em&gt;Who are the users aged between 10 and 20?&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;我们在前面构建QwenLLM模型时，设置了全局变量DEBUG_INFO，现在我们来打印看上面这个过程发生了什么：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Q,A=DEBUG_INFO
print(&apos;PROMPT:&apos;,Q[0][&apos;content&apos;])
print(&apos;ANSWER:&apos;,A)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;返回：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;PROMPT: The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question.
ANSWER: Who are the users aged between 10 and 20?
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;看看Q的整体结构：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;PROMPT: [{&apos;role&apos;: &apos;system&apos;, &apos;content&apos;: &apos;The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question.&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;select name from user where age between 10 and 20&apos;}]
ANSWER: What are the names of users whose age is between 10 and 20?
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来存储question-SQL到向量库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## 存储question-SQL到向量库
&apos;&apos;&apos;
按question-SQL的JSON入库
            {
                &quot;question&quot;: question,
                &quot;sql&quot;: sql,
            }
&apos;&apos;&apos;
vn.train(question=&apos;小猪猪的年龄&apos;,sql=&apos;select age from user where name=&quot;小猪猪&quot;&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;返回：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## &apos;0189b3e3-c135-5bfe-a9f8-7faabd751813-sql&apos;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;查看所有入库的知识：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## 检查所有入库的知识
vn.get_training_data()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftext2sql%2Ftext2sql_blogs-1%2F2.png%3ForigWidth%3D1310%26origHeight%3D329%26origFormat%3Dpng&amp;#x26;w=1310&amp;#x26;h=329&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;开始查询&lt;/h2&gt;
&lt;p&gt;前面我们已经向向量库中输入了很多知识，接下来开始查询：
首先试试根据问题生成SQL语句：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## 基本使用
result=vn.generate_sql(&apos;用户的平均年龄&apos;)
print(&apos;SQL:&apos;,result)

Q,A=DEBUG_INFO
print(&apos;PROMPT:&apos;,Q[0][&apos;content&apos;])
print(&apos;ANSWER:&apos;,A)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;select avg(age) from user
SQL: select avg(age) from user
PROMPT: You are a SQL expert. 
===Tables 
CREATE TABLE IF NOT EXISTS user (
        id INT PRIMARY KEY COMMENT &apos;用户ID&apos; ,
        name VARCHAR(100) COMMENT &apos;姓名&apos;,
        age INT COMMENT &apos;年龄&apos;
    ) COMMENT &apos;用户信息表&apos;;



===Additional Context 

用户年龄段划分逻辑：0-10,10-20,20-30,30-40,40-50,50-60,60-70,70-80...左闭右开区间

&quot;福报&quot;是指age&gt;=35岁，也就是可以向社会输送的人才

===Response Guidelines 
1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. 
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql 
3. If the provided context is insufficient, please explain why it can&apos;t be generated. 
4. Please use the most relevant table(s). 
5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. 

ANSWER: select avg(age) from user
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;可以看到模型使用了之前传入的&lt;code&gt;Additional Context &lt;/code&gt;生成了正确的SQL语句&lt;code&gt; ANSWER: select avg(age) from user&lt;/code&gt;。实际上在这个过程中完整的对话应该如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;[{&apos;role&apos;: &apos;system&apos;, &apos;content&apos;: &apos;You are a SQL expert. \n===Tables \nCREATE TABLE IF NOT EXISTS user (\n        id INT PRIMARY KEY COMMENT \&apos;用户ID\&apos; ,\n        name VARCHAR(100) COMMENT \&apos;姓名\&apos;,\n        age INT COMMENT \&apos;年龄\&apos;\n    ) COMMENT \&apos;用户信息表\&apos;;\n\n\n\n===Additional Context \n\n&quot;福报&quot;是指age&gt;=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n3. If the provided context is insufficient, please explain why it can\&apos;t be generated. \n4. Please use the most relevant table(s). \n5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;小鱼儿的年龄&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select age from user where name=&quot;小鱼儿&quot;&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;小猪猪的年龄&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select age from user where name=&quot;小猪猪&quot;&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;用户的平均年龄&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select avg(age) from user&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;Who are the users aged between 10 and 20?&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select name from user where age between 10 and 20&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;打算给一批员工送福报，把他们的名字过滤出来&apos;}]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;模型会在后面将一些优秀的问答对传递给模型，例如&lt;code&gt;{&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;小猪猪的年龄&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select age from user where name=&quot;小猪猪&quot;&apos;}&lt;/code&gt;，使得模型能够理解并给出优秀的回答。这也就是论文中提到的第三种策略&lt;code&gt;Contextual&lt;/code&gt;：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftext2sql%2Ftext2sql_blogs-1%2F3.png%3ForigWidth%3D1528%26origHeight%3D628%26origFormat%3Dpng&amp;#x26;w=1528&amp;#x26;h=628&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
使用文本相关的&lt;code&gt;question-sql&lt;/code&gt;对，作为历史会话，从而提升模型效果。&lt;/p&gt;
&lt;p&gt;接下来直接向vanna提问，让它直接给出使用SQL语句执行后的结果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;vn.ask(&apos;用户的平均年龄&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;回答：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## (&apos;select avg(age) from user&apos;,
##   avg(age)
##  0  26.6000,
##  None)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;同样的，我们看看模型的思考过程：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;[{&apos;role&apos;: &apos;system&apos;, &apos;content&apos;: &apos;You are a SQL expert. \n===Tables \nCREATE TABLE IF NOT EXISTS user (\n        id INT PRIMARY KEY COMMENT \&apos;用户ID\&apos; ,\n        name VARCHAR(100) COMMENT \&apos;姓名\&apos;,\n        age INT COMMENT \&apos;年龄\&apos;\n    ) COMMENT \&apos;用户信息表\&apos;;\n\n\n\n===Additional Context \n\n&quot;福报&quot;是指age&gt;=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n3. If the provided context is insufficient, please explain why it can\&apos;t be generated. \n4. Please use the most relevant table(s). \n5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;小鱼儿的年龄&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select age from user where name=&quot;小鱼儿&quot;&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;小猪猪的年龄&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select age from user where name=&quot;小猪猪&quot;&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;Who are the users aged between 10 and 20?&apos;}, {&apos;role&apos;: &apos;assistant&apos;, &apos;content&apos;: &apos;select name from user where age between 10 and 20&apos;}, {&apos;role&apos;: &apos;user&apos;, &apos;content&apos;: &apos;用户的平均年龄&apos;}]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它会将年龄相关的&lt;code&gt;question-sql&lt;/code&gt;对，作为历史会话，从而提升模型效果。&lt;/p&gt;
&lt;p&gt;数据库里&lt;code&gt;users&lt;/code&gt;表里的内容如下，年龄均值正好是26.6岁：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;+----+--------------+------+
| id | name         | age  |
+----+--------------+------+
|  1 | 迷糊老师     |   34 |
|  2 | 菲菲公主     |   36 |
|  3 | 小呆呆       |   24 |
|  4 | 小猪猪       |   21 |
|  5 | 超人强       |   18 |
+----+--------------+------+
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再提问一个：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;vn.ask(&apos;打算给一批员工送福报，把他们的名字过滤出来&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;select name from user where age &gt;= 35    name
0  菲菲公主
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;好了，可怜的菲菲公主即将收到福报。&lt;/p&gt;
&lt;p&gt;在Vanna中执行ask后会生成图表，看看它的逻辑：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;Q,A=DEBUG_INFO
print(&apos;PROMPT:&apos;,Q[0][&apos;content&apos;])
print(&apos;ANSWER:&apos;,A)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;回答：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;PROMPT: The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: &apos;各个年龄段的人数都是多少？&apos;

The DataFrame was produced using this query: SELECT 
    CASE 
        WHEN age BETWEEN 0 AND 10 THEN &apos;0-10&apos;
        WHEN age BETWEEN 10 AND 20 THEN &apos;10-20&apos;
        WHEN age BETWEEN 20 AND 30 THEN &apos;20-30&apos;
        WHEN age BETWEEN 30 AND 40 THEN &apos;30-40&apos;
        WHEN age BETWEEN 40 AND 50 THEN &apos;40-50&apos;
        WHEN age BETWEEN 50 AND 60 THEN &apos;50-60&apos;
        WHEN age BETWEEN 60 AND 70 THEN &apos;60-70&apos;
        WHEN age BETWEEN 70 AND 80 THEN &apos;70-80&apos;
        ELSE &apos;80+&apos;
    END AS age_group,
    COUNT(*) AS count
FROM user
GROUP BY age_group;

The following is information about the resulting pandas DataFrame &apos;df&apos;: 
Running df.dtypes gives:
 age_group    object
count         int64
dtype: object
ANSWER: ```python
import plotly.express as px
import plotly.graph_objects as go

if len(df) == 1:
    fig = go.Figure(go.Indicator(
        mode = &quot;number&quot;,
        value = df[&apos;count&apos;].values[0],
        title = {&quot;text&quot;: f&quot;Age Group: {df[&apos;age_group&apos;].values[0]}&amp;#x3C;br&gt;Count&quot;}))
else:
    fig = px.bar(df, x=&apos;age_group&apos;, y=&apos;count&apos;, title=&apos;Number of People in Each Age Group&apos;)
fig.show()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它会将我们的问题和生成的SQL语句传入给大模型，再生成表格的代码。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/text2sql/text2sql_blogs-1/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/text2sql/text2sql_blogs-1/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>Text2SQL（二）Vanna源码解读之train和ask</title><link>https://astro-pure.js.org/blog/text2sql/text2sql_blogs-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/text2sql/text2sql_blogs-2</guid><description>记录Test2SQL学习的内容。</description><pubDate>Sat, 31 Jan 2026 15:59:00 GMT</pubDate><content:encoded>&lt;p&gt;本节源码基于&lt;a href=&quot;https://github.com/vanna-ai/vanna/tree/main&quot;&gt;官方&lt;/a&gt;。&lt;/p&gt;
&lt;p&gt;由上一章节可以知道Vanna主要由两个函数起作用，第一个是train函数，第二个是ask函数，这两个函数都封装在VannaBase类中。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;VannaBase 是Vanna框架的核心基类，它提供了一个完整的文本到SQL生成系统的架构。&lt;/li&gt;
&lt;li&gt;train 函数用于训练Vanna模型，使其能够更好地将自然语言问题转换为SQL查询。&lt;/li&gt;
&lt;li&gt;ask 函数是Vanna的主要交互接口，用于回答用户提出的自然语言问题。&lt;/li&gt;
&lt;/ul&gt;
&lt;h2&gt;train函数&lt;/h2&gt;
&lt;p&gt;源码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def train(
        self,
        question: str = None,
        sql: str = None,
        ddl: str = None,
        documentation: str = None,
        plan: TrainingPlan = None,
    ) -&gt; str:
        &quot;&quot;&quot;
        **Example:**
        ```python
        vn.train()
        ```

        Train Vanna.AI on a question and its corresponding SQL query.
        If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database.
        If you call it with the sql argument, it&apos;s equivalent to [`vn.add_question_sql()`][vanna.base.base.VannaBase.add_question_sql].
        If you call it with the ddl argument, it&apos;s equivalent to [`vn.add_ddl()`][vanna.base.base.VannaBase.add_ddl].
        If you call it with the documentation argument, it&apos;s equivalent to [`vn.add_documentation()`][vanna.base.base.VannaBase.add_documentation].
        Additionally, you can pass a [`TrainingPlan`][vanna.types.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_generic()`][vanna.base.base.VannaBase.get_training_plan_generic].

        Args:
            question (str): The question to train on.
            sql (str): The SQL query to train on.
            ddl (str):  The DDL statement.
            documentation (str): The documentation to train on.
            plan (TrainingPlan): The training plan to train on.
        &quot;&quot;&quot;

        if question and not sql:
            raise ValidationError(&quot;Please also provide a SQL query&quot;)

        if documentation:
            print(&quot;Adding documentation....&quot;)
            return self.add_documentation(documentation)

        if sql:
            if question is None:
                question = self.generate_question(sql)
                print(&quot;Question generated with sql:&quot;, question, &quot;\nAdding SQL...&quot;)
            return self.add_question_sql(question=question, sql=sql)

        if ddl:
            print(&quot;Adding ddl:&quot;, ddl)
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;首先是方法定义部分：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def train(
        self,
        question: str = None,
        sql: str = None,
        ddl: str = None,
        documentation: str = None,
        plan: TrainingPlan = None,
    ) -&gt; str:
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;表示方法接收&lt;code&gt;question&lt;/code&gt;、&lt;code&gt;sql&lt;/code&gt;、&lt;code&gt;ddl&lt;/code&gt;、&lt;code&gt;documentation&lt;/code&gt;和&lt;code&gt;plan&lt;/code&gt;，并返回一个字符串结果。如果不带参数调用，它会检查是否连接到数据库，并尝试在该数据库的元数据上进行训练。 如果使用sql参数调用，它等同于[vn.add_question_sql()][vanna.base.base.VannaBase.add_question_sql]。 如果使用ddl参数调用，它等同于[vn.add_ddl()][vanna.base.base.VannaBase.add_ddl]。 如果使用documentation参数调用，它等同于[vn.add_documentation()][vanna.base.base.VannaBase.add_documentation]。 此外，您可以传递一个[TrainingPlan][vanna.types.TrainingPlan]对象。使用[vn.get_training_plan_generic()][vanna.base.base.VannaBase.get_training_plan_generic]获取训练计划。&lt;/p&gt;
&lt;p&gt;结合源码可以知道：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if ddl:
   print(&quot;Adding ddl:&quot;, ddl)
   return self.add_ddl(ddl)
if question and not sql:
   raise ValidationError(&quot;Please also provide a SQL query&quot;)
if documentation:
     print(&quot;Adding documentation....&quot;)
     return self.add_documentation(documentation)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这些都比较简单就是简单的方法调用，不过这些都是抽象方法需要具体实现：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -&gt; str:
    &quot;&quot;&quot;
    This method is used to add a DDL statement to the training data.

    Args:
        ddl (str): The DDL statement to add.

    Returns:
        str: The ID of the training data that was added.
    &quot;&quot;&quot;
    pass

@abstractmethod
def add_documentation(self, documentation: str, **kwargs) -&gt; str:
    &quot;&quot;&quot;
    This method is used to add documentation to the training data.

    Args:
        documentation (str): The documentation to add.

    Returns:
        str: The ID of the training data that was added.
    &quot;&quot;&quot;
    pass
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;下面来看看：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if sql:
   if question is None:
       question = self.generate_question(sql)
       print(&quot;Question generated with sql:&quot;, question, &quot;\nAdding SQL...&quot;)
   return self.add_question_sql(question=question, sql=sql)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;当传入SQL语句，没有传入Question时，Vanna会自动根据SQL生成问题，效果如下：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftext2sql%2Ftext2sql_blogs-2%2F1.png%3ForigWidth%3D1200%26origHeight%3D158%26origFormat%3Dpng&amp;#x26;w=1200&amp;#x26;h=158&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
深入看一下&lt;code&gt;generate_question&lt;/code&gt;方法：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def generate_question(self, sql: str, **kwargs) -&gt; str:
     response = self.submit_prompt(
         [
             self.system_message(
                 &quot;The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question.&quot;
             ),
             self.user_message(sql),
         ],
         **kwargs,
     )

     return response
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;方法很简单，就是构造一个系统提示词和用户提供的SQL语句，然后传给抽象方法&lt;code&gt;submit_prompt&lt;/code&gt;，这个方法是由我们自己实现的，在QwenLLM类中：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def submit_prompt(self,prompt,**kwargs):
    resp=Generation.call(
      model=self.model,
      messages=prompt,
      seed=random.randint(1, 10000),
      result_format=&apos;message&apos;,
      api_key=self.api_key)
    answer=resp.output.choices[0].message.content
    global DEBUG_INFO
    DEBUG_INFO=(prompt,answer)
    return answer
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;submit_prompt&lt;/code&gt;方法收到了&lt;code&gt;prompt&lt;/code&gt;，然后调用模型进行推理解析并返回&lt;code&gt;answer&lt;/code&gt;。此时，我们已经通过SQL生成了Question，然后调用&lt;code&gt;add_question_sql(question=question, sql=sql)&lt;/code&gt;方法，这个方法是将Question-SQL对添加到训练数据中，返回唯一的ID，这个抽象方法需要我们具体实现，保存的训练数据格式是一个json列表，每个json是一个训练样本，例如：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;{
      &quot;question&quot;:&quot;what are 5 most grossing movies in IMDB top 1000 &quot;,
      &quot;answer&quot;:&quot;SELECT series_title,\n       gross\nFROM   imdb.public.movies\nORDER BY gross desc limit 5;&quot;
    }
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;接下来看&lt;code&gt;train&lt;/code&gt;方法中比较复杂的一部分：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if plan:
   for item in plan._plan:
       if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
           self.add_ddl(item.item_value)
       elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
           self.add_documentation(item.item_value)
       elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
           self.add_question_sql(question=item.item_name, sql=item.item_value)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;传入的&lt;code&gt;plan&lt;/code&gt;是一个&lt;code&gt;TrainingPlan&lt;/code&gt;类的实例，TrainingPlan类是Vanna框架中用于管理和组织训练数据的核心类，它代表了一个结构化的训练计划。其中&lt;code&gt;_plan&lt;/code&gt;是一个&lt;code&gt;TrainingPlanItem&lt;/code&gt;的列表，&lt;code&gt;TrainingPlanItem&lt;/code&gt;包括以下属性：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class TrainingPlanItem:
    item_type: str
    item_group: str
    item_name: str
    item_value: str

    def __str__(self):
        if self.item_type == self.ITEM_TYPE_SQL:
            return f&quot;Train on SQL: {self.item_group} {self.item_name}&quot;
        elif self.item_type == self.ITEM_TYPE_DDL:
            return f&quot;Train on DDL: {self.item_group} {self.item_name}&quot;
        elif self.item_type == self.ITEM_TYPE_IS:
            return f&quot;Train on Information Schema: {self.item_group} {self.item_name}&quot;

    ITEM_TYPE_SQL = &quot;sql&quot;
    ITEM_TYPE_DDL = &quot;ddl&quot;
    ITEM_TYPE_IS = &quot;is&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;item_type: 训练项类型（SQL查询、DDL语句、信息模式）&lt;/li&gt;
&lt;li&gt;item_group: 训练项分组（如数据库名.模式名）&lt;/li&gt;
&lt;li&gt;item_name: 训练项名称（如表名）&lt;/li&gt;
&lt;li&gt;item_value: 训练项具体内容&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;以&lt;code&gt;item.item_type==sql&lt;/code&gt;为例，将它的值执行抽象方法&lt;code&gt;add_ddl&lt;/code&gt;，并返回唯一的ID。执行完&lt;code&gt;train&lt;/code&gt;方法的这一段代码后，Vanna AI模型将会：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;学习到数据库的表结构(通过DDL)&lt;/li&gt;
&lt;li&gt;获得额外的上下文信息(通过文档)&lt;/li&gt;
&lt;li&gt;掌握更多问题与SQL查询的对应关系(通过问答对)&lt;/li&gt;
&lt;li&gt;提升将自然语言转换为SQL查询的准确率&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这实际上是批量训练模型的过程，将训练计划中所有类型的训练数据都添加到模型的检索层中，以增强模型的性能。&lt;code&gt;add_ddl&lt;/code&gt;、&lt;code&gt;add_documentation&lt;/code&gt;和&lt;code&gt;add_question_sql&lt;/code&gt;需要在子类中具体实现，比如使用ChromaDB作为向量存储时，会在vanna.chromadb_vector.ChromaDB_VectorStore类中实现。&lt;/p&gt;
&lt;p&gt;训练好的数据在生成SQL时会被检索和使用：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;get_similar_question_sql()：检索相似的问题-SQL对
get_related_ddl()：检索相关的DDL语句
get_related_documentation()：检索相关的文档
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这些检索到的信息会被组合成提示词，提供给大语言模型生成最终的SQL查询。&lt;/p&gt;
&lt;p&gt;非常好，Vanna提供了该子类的实现代码，接下来我们需要在&lt;code&gt;ChromaDB_VectorStore&lt;/code&gt;学习。&lt;/p&gt;
&lt;h3&gt;ChromaDB_VectorStore类&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class ChromaDB_VectorStore(VannaBase):
    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            config = {}

        path = config.get(&quot;path&quot;, &quot;.&quot;)
        self.embedding_function = config.get(&quot;embedding_function&quot;, default_ef)
        curr_client = config.get(&quot;client&quot;, &quot;persistent&quot;)
        collection_metadata = config.get(&quot;collection_metadata&quot;, None)
        self.n_results_sql = config.get(&quot;n_results_sql&quot;, config.get(&quot;n_results&quot;, 10))
        self.n_results_documentation = config.get(&quot;n_results_documentation&quot;, config.get(&quot;n_results&quot;, 10))
        self.n_results_ddl = config.get(&quot;n_results_ddl&quot;, config.get(&quot;n_results&quot;, 10))

        if curr_client == &quot;persistent&quot;:
            self.chroma_client = chromadb.PersistentClient(
                path=path, settings=Settings(anonymized_telemetry=False)
            )
        elif curr_client == &quot;in-memory&quot;:
            self.chroma_client = chromadb.EphemeralClient(
                settings=Settings(anonymized_telemetry=False)
            )
        elif isinstance(curr_client, chromadb.api.client.Client):
            ## allow providing client directly
            self.chroma_client = curr_client
        else:
            raise ValueError(f&quot;Unsupported client was set in config: {curr_client}&quot;)

        self.documentation_collection = self.chroma_client.get_or_create_collection(
            name=&quot;documentation&quot;,
            embedding_function=self.embedding_function,
            metadata=collection_metadata,
        )
        self.ddl_collection = self.chroma_client.get_or_create_collection(
            name=&quot;ddl&quot;,
            embedding_function=self.embedding_function,
            metadata=collection_metadata,
        )
        self.sql_collection = self.chroma_client.get_or_create_collection(
            name=&quot;sql&quot;,
            embedding_function=self.embedding_function,
            metadata=collection_metadata,
        )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;ChromaDB_VectorStore&lt;/code&gt;继承&lt;code&gt;VannaBase&lt;/code&gt;类，实现VannaBase类提供的一些抽象方法：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;path = config.get(&quot;path&quot;, &quot;.&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;path: ChromaDB 数据持久化存储的路径。默认值为当前目录(&quot;.&quot;)，即在当前目录下创建和存储向量数据库。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;self.embedding_function = config.get(&quot;embedding_function&quot;, default_ef)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;embedding_function: 用于将文本转换为向量的嵌入函数。默认使用 DefaultEmbeddingFunction。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;curr_client = config.get(&quot;client&quot;, &quot;persistent&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;client: ChromaDB 客户端类型。&quot;persistent&quot;: 持久化客户端，数据存储在磁盘上（默认）。&quot;in-memory&quot;: 内存客户端，数据仅在内存中，程序退出后丢失。chromadb.api.client.Client: 直接提供已配置的客户端实例。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;collection_metadata = config.get(&quot;collection_metadata&quot;, None)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;collection_metadata: 集合的元数据信息。用于存储集合的额外信息，如版本号、创建时间等。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;self.n_results_sql = config.get(&quot;n_results_sql&quot;, config.get(&quot;n_results&quot;, 10))
self.n_results_documentation = config.get(&quot;n_results_documentation&quot;, config.get(&quot;n_results&quot;, 10))
self.n_results_ddl = config.get(&quot;n_results_ddl&quot;, config.get(&quot;n_results&quot;, 10))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;查询结果数量配置：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;n_results_sql: 查询相似 SQL 问题时返回的结果数量，默认10个&lt;/li&gt;
&lt;li&gt;n_results_documentation: 查询相关文档时返回的结果数量，默认10个&lt;/li&gt;
&lt;li&gt;n_results_ddl: 查询相关 DDL 语句时返回的结果数量，默认10个&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if curr_client == &quot;persistent&quot;:
    self.chroma_client = chromadb.PersistentClient(
        path=path, settings=Settings(anonymized_telemetry=False)
    )
elif curr_client == &quot;in-memory&quot;:
    self.chroma_client = chromadb.EphemeralClient(
        settings=Settings(anonymized_telemetry=False)
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;根据配置创建相应的 ChromaDB 客户端。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;self.documentation_collection = self.chroma_client.get_or_create_collection(
    name=&quot;documentation&quot;,
    embedding_function=self.embedding_function,
    metadata=collection_metadata,
)
## 类似地创建 ddl_collection 和 sql_collection
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;创建三个向量集合（Collections）用于存储不同类型的数据：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;documentation collection: 存储文档信息&lt;/li&gt;
&lt;li&gt;ddl collection: 存储数据定义语言（表结构等）&lt;/li&gt;
&lt;li&gt;sql collection: 存储问题和 SQL 查询对&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;下面还是以&lt;code&gt;add_ddl&lt;/code&gt;方法为例，由&lt;code&gt;train&lt;/code&gt;函数中的&lt;code&gt;self.add_ddl(item.item_value)&lt;/code&gt;将值传入进来，然后经过：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def add_ddl(self, ddl: str, **kwargs) -&gt; str:
        id = deterministic_uuid(ddl) + &quot;-ddl&quot;
        self.ddl_collection.add(
            documents=ddl,
            embeddings=self.generate_embedding(ddl),
            ids=id,
        )
        return id
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;首先生成唯一的ID，然后收集&lt;code&gt;documents&lt;/code&gt;、对应的&lt;code&gt;embedding&lt;/code&gt;结果和ID。同样的收集&lt;code&gt;question_sql&lt;/code&gt;和&lt;code&gt;documentation&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;执行&lt;code&gt;get_training_data&lt;/code&gt;方法的结果如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;                                         id  \
0  04a88b26-6984-5521-b897-73798ce0001f-sql   
1  e5102160-2dbf-5300-98f5-24d762a12b59-sql   
2  0189b3e3-c135-5bfe-a9f8-7faabd751813-sql   
3  eb6bbff7-a89c-51bc-a58d-ebf6dc181ae3-sql   
4  54db6ffd-201b-59a2-8568-cd05d82db461-sql   
5  9ccf7bcd-5091-5b97-bf72-af9d41e526a5-sql   
6  270bfd96-c340-5b21-afe9-0d14c23fd8bd-sql   
0  ab0ac208-2f5e-50b0-9177-423427220940-ddl   
0  8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc   
1  82e9153e-0b3b-5aca-ac66-31e65eb61d36-doc   

                                            question  \
0          Who are the users aged between 10 and 20?   
1                                             小鱼儿的年龄   
2                                             小猪猪的年龄   
3                                            用户的平均年龄   
4                              打算给一批员工送福报，把他们的名字过滤出来   
5  What are the names of users whose age is betwe...   
6                                      各个年龄段的人数都是多少？   
0                                               None   
0                                               None   
1                                               None   

                                             content training_data_type  
0  select name from user where age between 10 and 20                sql  
1              select age from user where name=&quot;小鱼儿&quot;                sql  
2              select age from user where name=&quot;小猪猪&quot;                sql  
3                          select avg(age) from user                sql  
4              select name from user where age &gt;= 35                sql  
5  select name from user where age between 10 and 20                sql  
6  SELECT \n    CASE \n        WHEN age BETWEEN 0...                sql  
0  CREATE TABLE IF NOT EXISTS user (\n        id ...                ddl  
0                       &quot;福报&quot;是指age&gt;=35岁，也就是可以向社会输送的人才      documentation  
1  用户年龄段划分逻辑：0-10,10-20,20-30,30-40,40-50,50-60,6...      documentation  
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;返回的内容是一个表格，可以看到包括四个属性&lt;code&gt;id&lt;/code&gt; 、&lt;code&gt;question&lt;/code&gt;、&lt;code&gt;content &lt;/code&gt;和&lt;code&gt;training_data_type &lt;/code&gt;，值得注意的是，对于&lt;code&gt;ddl&lt;/code&gt;和&lt;code&gt;documentation&lt;/code&gt;类型的训练数据，是没有&lt;code&gt;question&lt;/code&gt;值的。&lt;/p&gt;
&lt;p&gt;在&lt;code&gt;ChromaDB_VectorStore&lt;/code&gt;类的方法里还有&lt;code&gt;remove_training_data&lt;/code&gt;和&lt;code&gt;remove_collection&lt;/code&gt;方法，可以删除数据。&lt;/p&gt;
&lt;p&gt;具体的训练流程就是，当在&lt;code&gt;VannaBase&lt;/code&gt;基类的&lt;code&gt;train&lt;/code&gt;方法中设置&lt;code&gt;plan&lt;/code&gt;为True时，它会将接收到新的&lt;code&gt;ddl&lt;/code&gt;、&lt;code&gt;documentation&lt;/code&gt;和&lt;code&gt;question-sql&lt;/code&gt;对，以及它们对应的&lt;code&gt;embedding&lt;/code&gt;值存储到向量库中。如果提供&lt;code&gt;json&lt;/code&gt;格式的训练集，就只需要对该训练集进行遍历存储到向量库中即可。&lt;/p&gt;
&lt;h2&gt;ask函数&lt;/h2&gt;
&lt;p&gt;接下来分析另一个重要的函数ask函数，首先看它的初始化：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def ask(
        self,
        question: Union[str, None] = None,
        print_results: bool = True,
        auto_train: bool = True,
        visualize: bool = True,  ## if False, will not generate plotly code
        allow_llm_to_see_data: bool = False,
    ) -&gt; Union[
        Tuple[
            Union[str, None],
            Union[pd.DataFrame, None],
            Union[plotly.graph_objs.Figure, None],
        ],
        None,
    ]:
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;question (Union[str, None]) - 用户要询问的问题字符串，如果为 None 则会提示用户输入&lt;/li&gt;
&lt;li&gt;print_results (bool) - 是否打印结果，默认为 True&lt;/li&gt;
&lt;li&gt;auto_train (bool) - 是否自动训练，默认为 True，会将问题和 SQL 查询对添加到训练数据中&lt;/li&gt;
&lt;li&gt;visualize (bool) - 是否生成图表，默认为 True，会根据数据生成 Plotly 图表&lt;/li&gt;
&lt;li&gt;allow_llm_to_see_data (bool) - 是否允许 LLM 查看数据，默认为 False&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;具体返回一个三元组 (Tuple) 或 None：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;SQL 查询字符串 (str) - 生成的 SQL 查询语句&lt;/li&gt;
&lt;li&gt;数据结果 (pd.DataFrame) - SQL 查询执行后的结果数据&lt;/li&gt;
&lt;li&gt;图表对象 (plotly.graph_objs.Figure) - 根据数据生成的可视化图表&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;try:
    sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
except Exception as e:
    print(e)
    return None, None, None
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;根据用户的提问生成sql语句，并确定是否让LLM看到数据。查看&lt;code&gt;generate_sql&lt;/code&gt;函数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if self.config is not None:
    initial_prompt = self.config.get(&quot;initial_prompt&quot;, None)
else:
    initial_prompt = None
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_sql_prompt(
    initial_prompt=initial_prompt,
    question=question,
    question_sql_list=question_sql_list,
    ddl_list=ddl_list,
    doc_list=doc_list,
    **kwargs,
)
self.log(title=&quot;SQL Prompt&quot;, message=prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(title=&quot;LLM Response&quot;, message=llm_response)

if &apos;intermediate_sql&apos; in llm_response:
    if not allow_llm_to_see_data:
        return &quot;The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this.&quot;

    if allow_llm_to_see_data:
        intermediate_sql = self.extract_sql(llm_response)

        try:
            self.log(title=&quot;Running Intermediate SQL&quot;, message=intermediate_sql)
            df = self.run_sql(intermediate_sql)

            prompt = self.get_sql_prompt(
                initial_prompt=initial_prompt,
                question=question,
                question_sql_list=question_sql_list,
                ddl_list=ddl_list,
                doc_list=doc_list+[f&quot;The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n&quot; + df.to_markdown()],
                **kwargs,
            )
            self.log(title=&quot;Final SQL Prompt&quot;, message=prompt)
            llm_response = self.submit_prompt(prompt, **kwargs)
            self.log(title=&quot;LLM Response&quot;, message=llm_response)
        except Exception as e:
            return f&quot;Error running intermediate SQL: {e}&quot;
return self.extract_sql(llm_response)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;调用三个关键方法获取生成 SQL 所需的上下文：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;get_similar_question_sql()：获取相似的问题-SQL 对&lt;/li&gt;
&lt;li&gt;get_related_ddl()：获取相关的数据定义语言（表结构）&lt;/li&gt;
&lt;li&gt;get_related_documentation()：获取相关的文档说明&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这三个方法在子类&lt;code&gt;ChromaDB_VectorStore&lt;/code&gt;中实现，不同的向量数据库有不同的实现，这里简单看&lt;code&gt;get_similar_question_sql&lt;/code&gt;方法的具体实现：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def get_similar_question_sql(self, question: str, **kwargs) -&gt; list:
        return ChromaDB_VectorStore._extract_documents(
            self.sql_collection.query(
                query_texts=[question],
                n_results=self.n_results_sql,
            )
        )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;它是将&lt;code&gt;sql_collection.query&lt;/code&gt;的结果传入&lt;code&gt;_extract_documents&lt;/code&gt;方法。&lt;strong&gt;sql_collection.query方法是如何实现的？这里挖个坑。&lt;/strong&gt;
&lt;code&gt;_extract_documents&lt;/code&gt;方法将 ChromaDB 查询返回的原始数据结构转换为可以直接使用的文档列表。主要用于处理以下三种查询的返回结果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;## 1. 获取相似问题-SQL对
self.get_similar_question_sql(question) 
## 返回: [{&quot;question&quot;: &quot;...&quot;, &quot;sql&quot;: &quot;...&quot;}, {...}]

## 2. 获取相关DDL语句
self.get_related_ddl(question)
## 返回: [&quot;CREATE TABLE ...&quot;, &quot;CREATE TABLE ...&quot;]

## 3. 获取相关文档
self.get_related_documentation(question)
## 返回: [&quot;文档内容1&quot;, &quot;文档内容2&quot;]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;保存为&lt;code&gt;question_sql_list&lt;/code&gt;，同样的还有&lt;code&gt;ddl_list&lt;/code&gt;和&lt;code&gt;doc_list&lt;/code&gt;，然后一起组装成&lt;code&gt;prompt&lt;/code&gt;:&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;prompt = self.get_sql_prompt(
         initial_prompt=initial_prompt,
         question=question,
         question_sql_list=question_sql_list,
         ddl_list=ddl_list,
         doc_list=doc_list,
         **kwargs,
     )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;具体实现如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def get_sql_prompt(
    self,
    initial_prompt : str,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs,
):
    &quot;&quot;&quot;
    Example:
    ```python
    vn.get_sql_prompt(
        question=&quot;What are the top 10 customers by sales?&quot;,
        question_sql_list=[{&quot;question&quot;: &quot;What are the top 10 customers by sales?&quot;, &quot;sql&quot;: &quot;SELECT * FROM customers ORDER BY sales DESC LIMIT 10&quot;}],
        ddl_list=[&quot;CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)&quot;],
        doc_list=[&quot;The customers table contains information about customers and their sales.&quot;],
    )

    ```

    This method is used to generate a prompt for the LLM to generate SQL.

    Args:
        question (str): The question to generate SQL for.
        question_sql_list (list): A list of questions and their corresponding SQL statements.
        ddl_list (list): A list of DDL statements.
        doc_list (list): A list of documentation.

    Returns:
        any: The prompt for the LLM to generate SQL.
    &quot;&quot;&quot;

    if initial_prompt is None:
        initial_prompt = f&quot;You are a {self.dialect} expert. &quot; + \
        &quot;Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. &quot;

    initial_prompt = self.add_ddl_to_prompt(
        initial_prompt, ddl_list, max_tokens=self.max_tokens
    )

    if self.static_documentation != &quot;&quot;:
        doc_list.append(self.static_documentation)

    initial_prompt = self.add_documentation_to_prompt(
        initial_prompt, doc_list, max_tokens=self.max_tokens
    )

    initial_prompt += (
        &quot;===Response Guidelines \n&quot;
        &quot;1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n&quot;
        &quot;2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n&quot;
        &quot;3. If the provided context is insufficient, please explain why it can&apos;t be generated. \n&quot;
        &quot;4. Please use the most relevant table(s). \n&quot;
        &quot;5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n&quot;
        f&quot;6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n&quot;
    )

    message_log = [self.system_message(initial_prompt)]

    for example in question_sql_list:
        if example is None:
            print(&quot;example is None&quot;)
        else:
            if example is not None and &quot;question&quot; in example and &quot;sql&quot; in example:
                message_log.append(self.user_message(example[&quot;question&quot;]))
                message_log.append(self.assistant_message(example[&quot;sql&quot;]))

    message_log.append(self.user_message(question))

    return message_log
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;然后将生成的prompt提交&lt;code&gt;llm_response = self.submit_prompt(prompt, **kwargs)&lt;/code&gt;到大模型得到返回结果，继续执行代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if &apos;intermediate_sql&apos; in llm_response:
   if not allow_llm_to_see_data:
       return &quot;The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this.&quot;

   if allow_llm_to_see_data:
       intermediate_sql = self.extract_sql(llm_response)

       try:
           self.log(title=&quot;Running Intermediate SQL&quot;, message=intermediate_sql)
           df = self.run_sql(intermediate_sql)

           prompt = self.get_sql_prompt(
               initial_prompt=initial_prompt,
               question=question,
               question_sql_list=question_sql_list,
               ddl_list=ddl_list,
               doc_list=doc_list+[f&quot;The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n&quot; + df.to_markdown()],
               **kwargs,
           )
           self.log(title=&quot;Final SQL Prompt&quot;, message=prompt)
           llm_response = self.submit_prompt(prompt, **kwargs)
           self.log(title=&quot;LLM Response&quot;, message=llm_response)
       except Exception as e:
           return f&quot;Error running intermediate SQL: {e}&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;由前面生成的&lt;code&gt;prompt&lt;/code&gt;的&lt;code&gt;Response Guidelines&lt;/code&gt;的第2条可以看到，当提交上下文的内容不是很充分时，需要执行一个中间的查询操作，也就是提示词里会出现&lt;code&gt;intermediate&lt;/code&gt;，然后执行&lt;code&gt;extract_sql&lt;/code&gt;方法，这个方法就是当 LLM 生成响应时，通常不仅包含 SQL 查询，还可能包含解释、注释或其他文本。这个方法的作用就是从复杂的响应中准确提取出真正的 SQL 查询语句：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;def extract_sql(self, llm_response: str) -&gt; str:
	  import re
	  ## Match CREATE TABLE ... AS SELECT
    sqls = re.findall(r&quot;\bCREATE\s+TABLE\b.*?\bAS\b.*?;&quot;, llm_response, re.DOTALL | re.IGNORECASE)
    if sqls:
        sql = sqls[-1]
        self.log(title=&quot;Extracted SQL&quot;, message=f&quot;{sql}&quot;)
        return sql

    ## Match WITH clause (CTEs)
    sqls = re.findall(r&quot;\bWITH\b .*?;&quot;, llm_response, re.DOTALL | re.IGNORECASE)
    if sqls:
        sql = sqls[-1]
        self.log(title=&quot;Extracted SQL&quot;, message=f&quot;{sql}&quot;)
        return sql

    ## Match SELECT ... ;
    sqls = re.findall(r&quot;\bSELECT\b .*?;&quot;, llm_response, re.DOTALL | re.IGNORECASE)
    if sqls:
        sql = sqls[-1]
        self.log(title=&quot;Extracted SQL&quot;, message=f&quot;{sql}&quot;)
        return sql

    ## Match ```sql ... ```blocks
    sqls = re.findall(r&quot;```sql\s*\n(.*?)```&quot;, llm_response, re.DOTALL | re.IGNORECASE)
    if sqls:
        sql = sqls[-1].strip()
        self.log(title=&quot;Extracted SQL&quot;, message=f&quot;{sql}&quot;)
        return sql

    ## Match any ```... ```code blocks
    sqls = re.findall(r&quot;```(.*?)```&quot;, llm_response, re.DOTALL | re.IGNORECASE)
    if sqls:
        sql = sqls[-1].strip()
        self.log(title=&quot;Extracted SQL&quot;, message=f&quot;{sql}&quot;)
        return sql

    return llm_response
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;他是使用正则化找到LLM返回内容中的SQL语句，例如大模型返回内容：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;根据您的问题，我建议使用以下 SQL 查询：

```sql
SELECT customer_name, SUM(sales) as total_sales 
FROM customers 
GROUP BY customer_name 
ORDER BY total_sales DESC 
LIMIT 10;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;经过&lt;code&gt;extract_sql&lt;/code&gt;处理后：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;SELECT customer_name, SUM(sales) as total_sales FROM customers GROUP BY customer_name ORDER BY total_sales DESC LIMIT 10;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;再执行中间SQL&lt;code&gt;df = self.run_sql(intermediate_sql)&lt;/code&gt;得到数据表然后执行：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;prompt = self.get_sql_prompt(
        initial_prompt=initial_prompt,
        question=question,
        question_sql_list=question_sql_list,
        ddl_list=ddl_list,
        doc_list=doc_list+[f&quot;The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n&quot; + df.to_markdown()],
        **kwargs,
    )
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这个操作就是将表的信息放到&lt;code&gt;doc_list&lt;/code&gt;中，生成&lt;code&gt;prompt&lt;/code&gt;，这样就弥补了之前信息不充分的缺点。最后再提交一次给大模型，提取sql返回即可：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;llm_response = self.submit_prompt(prompt, **kwargs)
return self.extract_sql(llm_response)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;上面就是&lt;code&gt;ask&lt;/code&gt;方法中的&lt;code&gt;generate_sql&lt;/code&gt;的逻辑，总结就是，根据用户额度提问，从&lt;code&gt;DDL&lt;/code&gt;、&lt;code&gt;Document&lt;/code&gt;和&lt;code&gt;question_sql&lt;/code&gt;中分别检索返回&lt;code&gt;top_k&lt;/code&gt;个相关的内容，然后组成初始的&lt;code&gt;prompt&lt;/code&gt;，如果不需要生成中间&lt;code&gt;SQL&lt;/code&gt;，则直接提取回答中的&lt;code&gt;SQL&lt;/code&gt;并返回，若需要执行，则经过查询一些表格的内容信息并重新组成&lt;code&gt;prompt&lt;/code&gt;传给大模型进行处理，最后经过提取回答中的&lt;code&gt;SQL&lt;/code&gt;并返回。&lt;/p&gt;
&lt;p&gt;接下来就是打印结果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if print_results:
   try:
       Code = __import__(&quot;IPython.display&quot;, fromList=[&quot;Code&quot;]).Code
       display(Code(sql))
   except Exception as e:
       print(sql)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这段代码以语法高亮的方式显示生成的 SQL 查询，以美观的格式显示生成的 SQL 查询，而不是简单地打印纯文本。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if self.run_sql_is_set is False:
   print(
       &quot;If you want to run the SQL query, connect to a database first.&quot;
   )

   if print_results:
       return None
   else:
       return sql, None, None
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;最后就是执行SQL语句、绘制表格和数据图：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;if self.run_sql_is_set is False:
   print(
       &quot;If you want to run the SQL query, connect to a database first.&quot;
   )

   if print_results:
       return None
   else:
       return sql, None, None

try:
   df = self.run_sql(sql)

   if print_results:
       try:
           display = __import__(
               &quot;IPython.display&quot;, fromList=[&quot;display&quot;]
           ).display
           display(df)
       except Exception as e:
           print(df)

   if len(df) &gt; 0 and auto_train:
       self.add_question_sql(question=question, sql=sql)
   ## Only generate plotly code if visualize is True
   if visualize:
       try:
           plotly_code = self.generate_plotly_code(
               question=question,
               sql=sql,
               df_metadata=f&quot;Running df.dtypes gives:\n {df.dtypes}&quot;,
           )
           fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)
           if print_results:
               try:
                   display = __import__(
                       &quot;IPython.display&quot;, fromlist=[&quot;display&quot;]
                   ).display
                   Image = __import__(
                       &quot;IPython.display&quot;, fromlist=[&quot;Image&quot;]
                   ).Image
                   img_bytes = fig.to_image(format=&quot;png&quot;, scale=2)
                   display(Image(img_bytes))
               except Exception as e:
                   fig.show()
       except Exception as e:
           ## Print stack trace
           traceback.print_exc()
           print(&quot;Couldn&apos;t run plotly code: &quot;, e)
           if print_results:
               return None
           else:
               return sql, df, None
   else:
       return sql, df, None

except Exception as e:
   print(&quot;Couldn&apos;t run sql: &quot;, e)
   if print_results:
       return None
   else:
       return sql, None, None
return sql, df, fig
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;至此，Vanna的两大关键方法已经解读地差不多了。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/text2sql/text2sql_blogs-2/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/text2sql/text2sql_blogs-2/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>Vision-Language Models（VLM）学习（一）原理</title><link>https://astro-pure.js.org/blog/vlm_blogs/vlm_blog-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/vlm_blogs/vlm_blog-1</guid><description>记录VLM学习的内容。</description><pubDate>Sat, 31 Jan 2026 15:59:00 GMT</pubDate><content:encoded>&lt;p&gt;视觉大模型（VLM）是多模态人工智能模型，能够同时处理和理解图像和文本数据。与传统的单模态模型（如仅处理图像的卷积神经网络或仅处理文本的语言模型）相比，VLM通过学习视觉和语言之间的关联，能够处理更复杂的任务。例如，VLM可以根据图像生成描述性文本（图像描述），回答基于图像的问题（视觉问答），甚至根据文本生成图像（多模态生成）。&lt;/p&gt;
&lt;p&gt;VLM的出现得益于大规模预训练技术的发展。研究人员利用网络上几乎无限的图像-文本对数据（如网页抓取数据）进行预训练，使VLM能够在零样本（Zero-Shot）场景下完成多种视觉识别任务，而无需针对每个任务进行专门的训练。这种能力显著降低了模型开发的时间和成本。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么VLM重要？&lt;/strong&gt;
VLM的重要性体现在以下几个方面：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;多模态任务处理：VLM能够同时处理视觉和语言信息，适用于需要跨模态理解的任务，如图像描述、视觉问答和多模态推理。&lt;/li&gt;
&lt;li&gt;广泛的应用场景：VLM在医疗诊断（如分析医学影像并生成报告）、教育（生成交互式学习内容）、娱乐（生成创意图像或视频）和机器人（理解环境并执行语言指令）等领域具有巨大潜力。&lt;/li&gt;
&lt;li&gt;零样本学习能力：通过大规模预训练，VLM可以在未见过的数据上表现出色，减少了对标注数据的需求。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;例如，CLIP（2021年由OpenAI发布）通过学习图像-文本对，展示了强大的零样本图像分类能力，超越了许多专门训练的视觉模型。这种能力为后续VLM的发展奠定了基础。&lt;/p&gt;
&lt;p&gt;结合下面的介绍一下VLM模型的原理：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fvlm_blogs%2Fvlm_blog-1%2F1.png%3ForigWidth%3D1096%26origHeight%3D547%26origFormat%3Dpng&amp;#x26;w=1096&amp;#x26;h=547&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;strong&gt;核心思想：将图片使用&lt;code&gt;Visual encoder&lt;/code&gt;（一般都是&lt;code&gt;Vision Transformer&lt;/code&gt;）进行编码，然后与文本的编码特征进行拼接，再送到大模型里面进行预测后面的内容。&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;对于&lt;code&gt;Text&lt;/code&gt;，跟大模型的操作一样，首先经过&lt;code&gt;tokenizer&lt;/code&gt;生成&lt;code&gt;token IDs&lt;/code&gt;，注意token IDs的长度为&lt;code&gt;196+9&lt;/code&gt;，这个&lt;code&gt;196&lt;/code&gt;是占位符，用于后续替换成&lt;code&gt;Image&lt;/code&gt;的&lt;code&gt;feature&lt;/code&gt;，这个例子中文本被编码的长度是&lt;code&gt;9&lt;/code&gt;。然后对&lt;code&gt;token IDs&lt;/code&gt;做线性层映射，目的是映射到与LLM的输入维度&lt;code&gt;dim&lt;/code&gt;一样长。&lt;/p&gt;
&lt;p&gt;对于&lt;code&gt;Image&lt;/code&gt;，首先将一张图片分成块（例如&lt;code&gt;16*16&lt;/code&gt;），然后展开成一列&lt;code&gt;196&lt;/code&gt;维，每一个维度都映射到长度为&lt;code&gt;768&lt;/code&gt;的空间，就变成了&lt;code&gt;(196, 768)&lt;/code&gt;的向量，然后与文本的编码特征向量进行&lt;code&gt;concat&lt;/code&gt;操作（注意，这里不是赋值替换占位符，需要保留反向传播梯度图的可导性采用&lt;code&gt;concat&lt;/code&gt;），这样就组成了大模型的输入特征，然后进行预测后面的位置的概率分布（也就是答案）。&lt;/p&gt;
&lt;p&gt;对于&lt;code&gt;LLM&lt;/code&gt;来说，他只认识来自&lt;code&gt;Text&lt;/code&gt;的编码特征，因为这是之前预训练过的，并不认识来自&lt;code&gt;Image&lt;/code&gt;的编码特征。所以这个流程可以理解为，&lt;code&gt;Image&lt;/code&gt;编码特征是来辅助&lt;code&gt;Text&lt;/code&gt;编码特征进行预测的，在训练阶段，模型就会结合这些特征做预测，在这个过程中去理解图像和问题。需要注意的是，输入的数据包括（图像、问题和答案），在计算损失的时候，&lt;code&gt;Image&lt;/code&gt;部分是不需要进行计算&lt;code&gt;Loss&lt;/code&gt;的，只对文本计算&lt;code&gt;Loss&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;训练过程中，一般来说，视觉大模型（&lt;code&gt;Visual encoder&lt;/code&gt;）是已经预训练过的，以及大语言模型&lt;code&gt;LLM&lt;/code&gt;也是预训练过的，所以在某些阶段是可以冻结的。主要训练的是&lt;code&gt;project&lt;/code&gt;层，这个层给人的感受就是，将&lt;code&gt;Image Embedding&lt;/code&gt;映射成&lt;code&gt;Text Embedding&lt;/code&gt;。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/vlm_blogs/vlm_blog-1/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/vlm_blogs/vlm_blog-1/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>Vision-Language Models（VLM）学习（二）CLIP模型训练原理</title><link>https://astro-pure.js.org/blog/vlm_blogs/vlm_blog-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/vlm_blogs/vlm_blog-2</guid><description>记录VLM学习的内容。</description><pubDate>Sat, 31 Jan 2026 15:59:00 GMT</pubDate><content:encoded>&lt;p&gt;代码来源&lt;a href=&quot;https://www.bilibili.com/video/BV13K421v7Ar?spm_id_from=333.788.videopod.sections&amp;#x26;vd_source=52455a50a39ab9ee183496a6de048a09&quot;&gt;UP主&lt;/a&gt;，感谢大佬开源~&lt;/p&gt;
&lt;p&gt;CLIP（Contrastive Language-Image Pre-training）是OpenAI在2021年提出的突破性多模态模型，它通过对比学习的方式将图像和文本映射到同一个语义空间中。&lt;/p&gt;
&lt;p&gt;核心思想：CLIP同时训练图像编码器和文本编码器，让配对的图像-文本对在向量空间中尽可能接近，而不配对的图像-文本对尽可能远离。这种对比学习范式使得模型能够理解图像和文本之间的语义关联。&lt;/p&gt;
&lt;p&gt;多模态就是指不同领域的输入数据，比如文字、图片、语音、视频等。在传统方法中，每个领域都有一些经典的处理算法，比如用于处理文本的RNN、LSTM、Transformer，用于处理图像的各类卷积神经网络等，&lt;strong&gt;各领域间相对独立&lt;/strong&gt;。但是，&lt;strong&gt;人们总会遇到需要联合领域数据的时候&lt;/strong&gt;，比如给一张图片，输出一段关于这个图片的描述；或者给一段文字，输出一张符合文字描述的图片。而实现这一目标的难点在于：不同领域数据间的特征分布、特征信息是不一样的。因此多模态模型的总体目标就是：&lt;strong&gt;训练一个模型，一方面能统一特征表达，另一方面又能让不同模态特征间学到相关性。&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;下面介绍CLIP模型的训练：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fvlm_blogs%2Fvlm_blog-2%2F1.png%3ForigWidth%3D1131%26origHeight%3D406%26origFormat%3Dpng&amp;#x26;w=1131&amp;#x26;h=406&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
主要流程分为两步：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;使用对比学习进行预训练&lt;/li&gt;
&lt;li&gt;进行零样本预测&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;CLIP的训练数据是&amp;#x3C;图像，文本&gt;&lt;code&gt;pair&lt;/code&gt;。如图所示，一个&lt;code&gt;batch&lt;/code&gt;的数据里，有若干张图像，每张图像配有相应的文字描述信息（&lt;code&gt;Prompt&lt;/code&gt;），比如：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;一张小狗图片，Prompt为&amp;#x3C;dog&gt;，或者为&amp;#x3C;A photo of a dog&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;值得一提的是，&lt;code&gt;CLIP&lt;/code&gt;的作者发现，&lt;code&gt;Prompt&lt;/code&gt;的设计也会影响模型最终的效果。比如把&lt;code&gt;Prompt&lt;/code&gt;从单词&lt;code&gt;&amp;#x3C;dog&gt;&lt;/code&gt;换成句子&lt;code&gt;&amp;#x3C;A photo of a dog&gt;&lt;/code&gt;后，模型在&lt;code&gt;ImageNet&lt;/code&gt;分类任务上的准确率直接提高了&lt;code&gt;1.3%&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;&lt;code&gt;CLIP&lt;/code&gt;模型由两个主体部分组成：&lt;code&gt;Text Encoder&lt;/code&gt;和&lt;code&gt;Image Encoder&lt;/code&gt;：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;对于&lt;code&gt;Text Encoder&lt;/code&gt;，&lt;code&gt;CLIP&lt;/code&gt;借鉴的是&lt;code&gt;GPT2&lt;/code&gt;的架构。对于每条&lt;code&gt;Prompt&lt;/code&gt;，在进入&lt;code&gt;Text Encoder&lt;/code&gt;前，都会添加表示开始和结束的符号&lt;code&gt;[SOS]&lt;/code&gt;与&lt;code&gt;[EOS]&lt;/code&gt;。最终将最后一层&lt;code&gt;[EOS]&lt;/code&gt;位置的向量作为该&lt;code&gt;Prompt&lt;/code&gt;的特征表示向量，也就是图中所绘的$T_i$。&lt;/p&gt;
&lt;p&gt;举个例子：
输入 &quot;a photo of a cat&quot; 会变成：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;[SOS] a photo of a cat [EOS]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;“最后一层 [EOS] 位置的向量”是什么意思？&lt;/strong&gt;
经过 &lt;code&gt;Text Encoder&lt;/code&gt;（多层 &lt;code&gt;Transformer&lt;/code&gt;）处理后，每一层都会为每个 &lt;code&gt;token&lt;/code&gt; 输出一个向量。我们关注的是 最后一层（final layer） 的输出。在这个输出序列中，对应 &lt;code&gt;[EOS]&lt;/code&gt; 这个 &lt;code&gt;token&lt;/code&gt; 位置的那个向量，就被选作整个句子的“代表”&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;对于&lt;code&gt;Image Encoder&lt;/code&gt;，&lt;code&gt;CLIP&lt;/code&gt;则尝试过5种不同的ResNet架构和3种VIT架构，最终选用的是架构为&lt;code&gt;Large,patch_size=14&lt;/code&gt;的&lt;code&gt;ViT&lt;/code&gt;，同时在整个&lt;code&gt;CLIP&lt;/code&gt;预训练结束后，用更高分辨率（&lt;code&gt;336*336&lt;/code&gt;）的图片做了一个&lt;code&gt;epoch&lt;/code&gt;的&lt;code&gt;fine-tune&lt;/code&gt;，目的是让&lt;code&gt;CLIP&lt;/code&gt;能涌现出更好的效果。与&lt;code&gt;Text Encoder&lt;/code&gt;类似，每张图片对应一个最终特征表示向量$I_i$。&lt;/p&gt;
&lt;p&gt;&lt;code&gt;Vision Transformer（ViT）&lt;/code&gt;在处理图像时，会：
（1）将输入图像（如 &lt;code&gt;224×224&lt;/code&gt;）切分成若干个 &lt;code&gt;patch&lt;/code&gt;（例如 &lt;code&gt;patch_size=14&lt;/code&gt; → 每个 patch 是 &lt;code&gt;14×14&lt;/code&gt; 像素）。
（2）每个 &lt;code&gt;patch&lt;/code&gt; 被展平并通过一个线性层映射为一个嵌入向量。
（3）在 &lt;code&gt;patch&lt;/code&gt; 序列最前面额外添加一个可学习的特殊 &lt;code&gt;token&lt;/code&gt;，称为 &lt;code&gt;[class] token&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;所以输入序列是：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;[class], patch_1, patch_2, ..., patch_N
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;经过多层 &lt;code&gt;Transformer&lt;/code&gt; 编码器后，最后一层会输出与输入等长的向量序列。	其中，第一个位置（即 &lt;code&gt;[class] token&lt;/code&gt; 对应的位置）的输出向量，被认为聚合了整张图像的全局信息。	这个向量通常记作 &lt;code&gt;z_cls&lt;/code&gt;，维度为 &lt;code&gt;ViT&lt;/code&gt; 的隐藏层维度（例如 &lt;code&gt;ViT-L/14&lt;/code&gt; 的 &lt;code&gt;hidden size = 1024&lt;/code&gt;）。在&lt;code&gt;CLIP&lt;/code&gt;中会使用这个向量&lt;code&gt;z_cls&lt;/code&gt;，将其映射到一个共享的多模态嵌入空间。这个投影后的向量才被归一化并用作最终的图像特征表示 $I_i$。&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;接下来是第一步对比学习：&lt;/p&gt;
&lt;p&gt;假设一个&lt;code&gt;batch&lt;/code&gt;中共有&lt;code&gt;N&lt;/code&gt;对&amp;#x3C;图像，文字&gt;对，那么过完各自的&lt;code&gt;Encoder&lt;/code&gt;后，就会分别产生：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;N条文字向量[$T_1, T_2, ..., T_N$]&lt;/li&gt;
&lt;li&gt;N条图片向量[$I_1, I_2, ..., I_N$]&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这两组向量，将会分别过一次多模态&lt;code&gt;Embedding(Multimodal Embedding)&lt;/code&gt;，也就是在途中代表文字的紫色向量下，还有一层参数$W_t$（图中没有画出来），文字向量需要先和$W_t$做矩阵相乘后，才能得到最终文字向量。对图片向量，同理也有个对应的$W_i$。$W_t$和$W_i$的作用可以理解成把文字、图片特征投影到多模态的特征空间中去。&lt;/p&gt;
&lt;p&gt;经过多模态&lt;code&gt;Embedding&lt;/code&gt;的处理，我们得到了最终的[$T_1, T_2, ..., T_N$]和[$I_1, I_2, ..., I_N$]。接下来，模型就能**通过“对比学习”，找到图像和文字的相似关系。**做法也很简单，对于图中列出的&lt;code&gt;N*N&lt;/code&gt;个格子，只需要计算每个格子上对应的向量点积（余弦相似度）即可。对于对角线上的图片-文字对是&lt;code&gt;Ground Truth&lt;/code&gt;，希望对角线上的相似度可以最大，据此可以设置交叉熵函数，来球的每个&lt;code&gt;batch&lt;/code&gt;下的&lt;code&gt;Loss&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;需要注意的是，在计算&lt;code&gt;Loss&lt;/code&gt;时，会算两个&lt;code&gt;Loss&lt;/code&gt;再取平均，这是因为&lt;code&gt;CLIP&lt;/code&gt;按行计算&lt;code&gt;Loss&lt;/code&gt;和按列计算&lt;code&gt;Loss&lt;/code&gt;。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;按行计算&lt;code&gt;Loss&lt;/code&gt;，在每一行的范围内做&lt;code&gt;softmax&lt;/code&gt;，然后计算&lt;code&gt;cross_entropy&lt;/code&gt;（蓝色格子部分是&lt;code&gt;Ground Truth&lt;/code&gt;）。这样计算&lt;code&gt;Loss&lt;/code&gt;的意义是：对于每一张图片，我们都希望找到和它最相似的文字。&lt;/li&gt;
&lt;li&gt;按列计算&lt;code&gt;Loss&lt;/code&gt;，在每一列的范围内做&lt;code&gt;softmax&lt;/code&gt;，然后计算&lt;code&gt;cross_entropy&lt;/code&gt;（蓝色格子部分是&lt;code&gt;Ground Truth&lt;/code&gt;）。这样计算&lt;code&gt;Loss&lt;/code&gt;的意义是：对于每一张文字，我们都希望找到和它最相似的图片。&lt;/li&gt;
&lt;li&gt;最后将这两个&lt;code&gt;Loss&lt;/code&gt;相加取平均，代表在模型优化过程中考虑了“图片-&gt;文字”和“文字-&gt;图片”的双向关系。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;第二部就是用训练好的模型来做&lt;code&gt;zero-shot&lt;/code&gt;预测了，流程如下：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;首先，创建一个标签全集，如图中（2）所示，并得到每一个标签的特征向量&lt;/li&gt;
&lt;li&gt;然后，取一张图片，如图中（3）所示，过&lt;code&gt;Image Encoder&lt;/code&gt;后得到该图片的特征向量&lt;/li&gt;
&lt;li&gt;最后，计算图片向量和文字向量的相似度，取相似度最高的那条&lt;code&gt;label&lt;/code&gt;即可。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;值得注意的是，对于标签来说，&lt;code&gt;CLIP&lt;/code&gt;需要一个标签全集。也就是说，当喂给&lt;code&gt;CLIP&lt;/code&gt;一张图时，不管这张图片它是否有见过，&lt;code&gt;CLIP&lt;/code&gt;都不会生成一个全新的标签，而是去全集标签中找一个最相似的给你（这也是&lt;code&gt;CLIP&lt;/code&gt;的缺陷之一）。&lt;/p&gt;
&lt;p&gt;至此，&lt;code&gt;CLIP&lt;/code&gt;技术部分已经讲完。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;那么CLIP模型有什么缺陷呢？&lt;/strong&gt;
&lt;strong&gt;缺陷一&lt;/strong&gt;：Zero-shot能力很强，但不是最强的
根据实验结果，&lt;code&gt;CLIP&lt;/code&gt;从来没有用&lt;code&gt;ImageNet&lt;/code&gt;的数据训练过，但它在&lt;code&gt;ImageNet&lt;/code&gt;的预测效果可以达到&lt;code&gt;76.2%&lt;/code&gt;，和用&lt;code&gt;ImageNet&lt;/code&gt;做训练集的&lt;code&gt;ResNet50&lt;/code&gt;基本一致。但&lt;code&gt;ResNet50&lt;/code&gt;并不是在&lt;code&gt;ImageNet&lt;/code&gt;分类任务上表现最&lt;code&gt;SOTA&lt;/code&gt;的模型，例如&lt;code&gt;MAE&lt;/code&gt;之类在&lt;code&gt;ImageNet&lt;/code&gt;上可以达到80%+。虽然&lt;code&gt;CLIP&lt;/code&gt;同样具有涌现能力，即当模型变大时，模型的效果会更好，但是因为&lt;code&gt;CLIP&lt;/code&gt;训练昂贵的原因，为了提升预测百分点而需要的代价是巨大的。因此这也是&lt;code&gt;CLIP&lt;/code&gt;当前的限制之一。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;缺陷二&lt;/strong&gt;：CLIP无法处理更抽象的任务
抽象的任务是指：输出图片中物体的个数等需要一定逻辑思维推理的任务。在论文的实验中也有给出一些说明，下图中刻画了CLIP和ResNet在不同数据集任务上的表现情况。绿色表示CLIP表现更好的数据集，蓝色标线ResNet表现更好的数据集。注意到蓝色部分的DTD（纹理分类）和CLEVRCountS（给图中物体计数）这两个数据集，都是相对抽象的任务，在这方面CLIP表现明显不如ResNet。
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fvlm_blogs%2Fvlm_blog-2%2F2.png%3ForigWidth%3D520%26origHeight%3D550%26origFormat%3Dpng&amp;#x26;w=520&amp;#x26;h=550&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
&lt;strong&gt;缺陷三&lt;/strong&gt;：当测试数据集分布严重偏移时，CLIP也束手无策
虽然 CLIP 以 Zero-shot 标榜，但是当测试数据集分布相对于训练数据集分布存在严重偏移情况时，CLIP 的表现也不理想。论文中提出了一个很有代表性的例子：MNIST（手写数字数据集）。这样一个简单的数据集，可能由 CV/M 都能做到 90% 以上的准确率了，但 CLIP 在上面的表现只有 88%。原因就在于 CLIP 的训练数据主要来自互联网图像和文本对（如网页上的图片和标题），其风格和内容偏向于自然场景、日常物体等，而 MNIST 是一种高度结构化、低分辨率、黑白线条构成的手写数字图像，与 CLIP 训练数据的视觉风格差异巨大。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;缺陷四&lt;/strong&gt;：文字标签是个闭集
前文说过，在对 CLIP 做 zero-shot 预测时，我们的文字标签是一个闭集（closed set），模型输入一张可能从未见过的图片，然后从这个预定义的标签集合中找出最匹配的一个，而不是去预测一个全新的文字标签。从这一点来看，CLIP 依然不够自动化。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;缺陷五&lt;/strong&gt;：受限于计算资源，无法做图像-文本的生成式网络
这个在 CLIP 看来是缺陷的问题，不久之后已经被我们熟知的 DALL·E 2 和 Stable Diffusion 所解决（没错，正是站在 CLIP 的肩膀上）。因此，这是 CLIP 的一个限制，但同时也为后续研究提供了重要的启发点。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/vlm_blogs/vlm_blog-2/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/vlm_blogs/vlm_blog-2/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>Vision-Language Models（VLM）学习（三）复现OpenAI的CLIP模型</title><link>https://astro-pure.js.org/blog/vlm_blogs/vlm_blog-3</link><guid isPermaLink="true">https://astro-pure.js.org/blog/vlm_blogs/vlm_blog-3</guid><description>记录VLM学习的内容。</description><pubDate>Sat, 31 Jan 2026 15:59:00 GMT</pubDate><content:encoded>&lt;p&gt;前面已经学习了CLIP模型的原理，本节就基于MNIST手写数字数据集实现一个CLIP模型。&lt;/p&gt;
&lt;p&gt;首先实现&lt;code&gt;Images Encoder&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=stride)
        self.bn1=nn.BatchNorm2d(out_channels)
        
        self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1)
        self.bn2=nn.BatchNorm2d(out_channels)
        
        self.conv3=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,padding=0,stride=stride)
    
    def forward(self,x):
        y=F.relu(self.bn1(self.conv1(x)))
        y=self.bn2(self.conv2(y))
        z=self.conv3(x)
        return F.relu(y+z)
        

class ImgEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.res_block1=ResidualBlock(in_channels=1,out_channels=16,stride=2) # (batch,16,14,14)
        self.res_block2=ResidualBlock(in_channels=16,out_channels=4,stride=2) # (batch,4,7,7)
        self.res_block3=ResidualBlock(in_channels=4,out_channels=1,stride=2) # (batch,1,4,4)
        self.wi=nn.Linear(in_features=16,out_features=8)
        self.ln=nn.LayerNorm(8)
        
    def forward(self,x):
        x=self.res_block1(x)
        x=self.res_block2(x)
        x=self.res_block3(x)
        x=self.wi(x.view(x.size(0),-1))
        x=self.ln(x)
        return x
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;Images Encoder&lt;/code&gt;是由一个&lt;code&gt;ResNet&lt;/code&gt;网络实现的。经过卷积操作后，通过&lt;code&gt;x=self.wi(x.view(x.size(0),-1))&lt;/code&gt;将特征展平，然后用一个&lt;code&gt;Linear&lt;/code&gt;层，将维度映射到&lt;code&gt;8&lt;/code&gt;。&lt;/p&gt;
&lt;p&gt;接下来是&lt;code&gt;Text Encoder&lt;/code&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb=nn.Embedding(num_embeddings=10,embedding_dim=16)
        self.dense1=nn.Linear(in_features=16,out_features=64)
        self.dense2=nn.Linear(in_features=64,out_features=16)
        self.wt=nn.Linear(in_features=16,out_features=8)
        self.ln=nn.LayerNorm(8)
    
    def forward(self,x):
        x=self.emb(x)
        x=F.relu(self.dense1(x))
        x=F.relu(self.dense2(x))
        x=self.wt(x)
        x=self.ln(x)
        return x
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;code&gt;Text Encoder&lt;/code&gt;是由几个&lt;code&gt;Linear&lt;/code&gt;构成的，首先输入的就是一个长度为&lt;code&gt;10&lt;/code&gt;的&lt;code&gt;one-hot&lt;/code&gt;向量，经过编码得到特征，也是一个长度为&lt;code&gt;8&lt;/code&gt;的特征向量。&lt;/p&gt;
&lt;p&gt;接下来就是&lt;code&gt;CLIP&lt;/code&gt;代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;class CLIP(nn.Module):
    def __init__(self,):
        super().__init__()
        self.img_enc=ImgEncoder()
        self.text_enc=TextEncoder()

    def forward(self,img_x,text_x):
        img_emb=self.img_enc(img_x)
        text_emb=self.text_enc(text_x)
        return img_emb@text_emb.T
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里就直接调用&lt;code&gt;Encoder&lt;/code&gt;进行编码，然后进行矩阵乘法操作。&lt;/p&gt;
&lt;p&gt;接下来就是训练代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;DEVICE=&apos;cuda&apos; if torch.cuda.is_available() else &apos;cpu&apos;   # 设备

dataset=MNIST() # 数据集

model=CLIP().to(DEVICE) # 模型

try:    # 加载模型
    model.load_state_dict(torch.load(&apos;model.pth&apos;))
except:
    pass 

optimzer=torch.optim.Adam(model.parameters(),lr=1e-3)   # 优化器

&apos;&apos;&apos;
    训练模型
&apos;&apos;&apos;
ITER_BATCH_COUNT=100000    # 迭代次数
BATCH_SIZE=64   # 从batch内选出10个不一样的数字
TARGET_COUNT=10 # 共10种数字

dataloader=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=10,persistent_workers=True)    # 数据加载器

for i in range(ITER_BATCH_COUNT):
    while True:
        imgs,labels=next(iter(dataloader))
        if torch.unique(labels).shape[0]&amp;#x3C;TARGET_COUNT:  # 未覆盖10种数字
            continue
        # 挑选出10个数字
        target=set()    
        indexes=[]
        for j in range(BATCH_SIZE):
            if labels[j].item() in target:
                continue 
            target.add(labels[j].item())
            indexes.append(j)
            if len(target)==TARGET_COUNT:
                break
        imgs=imgs[indexes]
        labels=labels[indexes]
        break

    logits=model(imgs.to(DEVICE),labels.to(DEVICE))
    
    targets=torch.arange(0,TARGET_COUNT).to(DEVICE)
    loss_i=F.cross_entropy(logits,targets)
    loss_t=F.cross_entropy(logits.permute(1,0),targets)
    loss=(loss_i+loss_t)/2
    
    optimzer.zero_grad()
    loss.backward()
    optimzer.step()
    if i%1000==0:
        print(&apos;iter:{},loss:{}&apos;.format(i,loss))
        torch.save(model.state_dict(),&apos;.model.pth&apos;)
        os.replace(&apos;.model.pth&apos;,&apos;model.pth&apos;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;值的注意的是，在训练的过程中，每一个批次的样本都要包括10个不同手写数字的样本，这是因为我们在做CLIP训练时，如果同一个批次出现两张同类别的手写数字图，就会出现下面的情况：
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Fvlm_blogs%2Fvlm_blog-3%2F1.png%3ForigWidth%3D694%26origHeight%3D461%26origFormat%3Dpng&amp;#x26;w=694&amp;#x26;h=461&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
假设$T_1$和$T_N$是同一个标签，假设是&lt;code&gt;9&lt;/code&gt;，而$I_1$是数字&lt;code&gt;9&lt;/code&gt;对应的图片，此时在计算Loss时，会强制认为$T_1$和$I_1$是匹配的，从而打压模型对$T_N$和$I_1$的判断（事实上$T_N$和$I_1$也是匹配的，但是这时模型强制认为$T_1$和$I_1$是匹配的），所以会导致模型“脑裂”。&lt;/p&gt;
&lt;p&gt;所以训练代码中的：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;while True:
        imgs,labels=next(iter(dataloader))
        if torch.unique(labels).shape[0]&amp;#x3C;TARGET_COUNT:  # 未覆盖10种数字
            continue
        # 挑选出10个数字
        target=set()    
        indexes=[]
        for j in range(BATCH_SIZE):
            if labels[j].item() in target:
                continue 
            target.add(labels[j].item())
            indexes.append(j)
            if len(target)==TARGET_COUNT:
                break
        imgs=imgs[indexes]
        labels=labels[indexes]
        break
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;是来保证每轮训练的 &lt;code&gt;batch&lt;/code&gt; 必须包含所有10个类别。&lt;/p&gt;
&lt;p&gt;最后就是推理代码：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;DEVICE=&apos;cuda&apos; if torch.cuda.is_available() else &apos;cpu&apos;   # 设备

dataset=MNIST() # 数据集

model=CLIP().to(DEVICE) # 模型
model.load_state_dict(torch.load(&apos;model.pth&apos;))

model.eval()    # 预测模式

&apos;&apos;&apos;
1、对图片分类
&apos;&apos;&apos;
image,label=dataset[0]
print(&apos;正确分类:&apos;,label)
plt.imshow(image.permute(1,2,0))
plt.show()

targets=torch.arange(0,10)  #10种分类
logits=model(image.unsqueeze(0).to(DEVICE),targets.to(DEVICE)) # 1张图片 vs 10种分类
print(logits)
print(&apos;CLIP分类:&apos;,logits.argmax(-1).item())

&apos;&apos;&apos;
2、图像相似度
&apos;&apos;&apos;
other_images=[]
other_labels=[]
for i in range(1,101):
    other_image,other_label=dataset[i]
    other_images.append(other_image)
    other_labels.append(other_label)

# 其他100张图片的向量
other_img_embs=model.img_enc(torch.stack(other_images,dim=0).to(DEVICE))

# 当前图片的向量
img_emb=model.img_enc(image.unsqueeze(0).to(DEVICE))

# 计算当前图片和100张其他图片的相似度
logtis=img_emb@other_img_embs.T
values,indexs=logtis[0].topk(5) # 5个最相似的

plt.figure(figsize=(15,15))
for i,img_idx in enumerate(indexs):
    plt.subplot(1,5,i+1)
    plt.imshow(other_images[img_idx].permute(1,2,0))
    plt.title(other_labels[img_idx])
    plt.axis(&apos;off&apos;)
plt.show()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这里是实现了两种任务，第一种就是分类任务，第二种就是通过图片相似度实现以图搜图的任务。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/vlm_blogs/vlm_blog-3/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/vlm_blogs/vlm_blog-3/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>RAG实战（一）构建QA系统</title><link>https://astro-pure.js.org/blog/rag_blogs/rag_blogs-1</link><guid isPermaLink="true">https://astro-pure.js.org/blog/rag_blogs/rag_blogs-1</guid><description>记录RAG的学习。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/RAG_Learning&quot;&gt;Github地址&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;构建本地Q&amp;#x26;A系统&lt;/h2&gt;
&lt;p&gt;本系列代码都来自&lt;a href=&quot;https://rag.deeptoai.com/docs/rag-charts-reference&quot;&gt;博客&lt;/a&gt;，在博主的基础上实现了本地调用模型，同时基本跑通所有代码。实现了一个从本地构建知识库，并进行问答的系统。&lt;/p&gt;
&lt;p&gt;这个实战可以帮助理解RAG从文本分块-构建向量数据库-检索-生成的整个流程以及优化技巧等。&lt;/p&gt;
&lt;h3&gt;加载文档&lt;/h3&gt;
&lt;p&gt;在构建 RAG（Retrieval-Augmented Generation）系统时，第一步是将外部知识源（如网页、PDF、Word 文档等）加载为程序可处理的格式。LangChain 提供了丰富的 &lt;strong&gt;Document Loaders&lt;/strong&gt;，用于从不同来源提取文本内容，并统一封装为 &lt;code&gt;Document&lt;/code&gt; 对象。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;什么是 Document？&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在 LangChain 中，&lt;code&gt;Document&lt;/code&gt; 是一个轻量级的数据结构，用于表示一段文本及其相关的元信息（metadata）。其基本结构如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_core.documents import Document

doc = Document(
    page_content=&quot;这是文档的正文内容。&quot;,
    metadata={&quot;source&quot;: &quot;example.pdf&quot;, &quot;page&quot;: 1, &quot;title&quot;: &quot;示例标题&quot;}
)
&lt;/code&gt;&lt;/pre&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;page_content&lt;/code&gt;：字符串类型，存储实际的文本内容。&lt;/li&gt;
&lt;li&gt;&lt;code&gt;metadata&lt;/code&gt;：字典类型，包含与该文档片段相关的附加信息，例如来源 URL、页码、标题、作者、创建时间等。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这种设计使得后续的文本分割、向量化、检索等步骤能够保留上下文和溯源信息，对于构建可解释、可追踪的 RAG 系统至关重要。&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;加载 Web 文档&lt;/h3&gt;
&lt;p&gt;我们使用 &lt;code&gt;WebBaseLoader&lt;/code&gt; 从指定的 CSDN 博客链接加载内容。该 loader 基于 &lt;code&gt;requests&lt;/code&gt; 和 &lt;code&gt;BeautifulSoup&lt;/code&gt;，能自动解析 HTML 并提取正文文本。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_community.document_loaders import WebBaseLoader

loader = WebBaseLoader(&quot;https://blog.csdn.net/weixin_44919384/article/details/154616759?spm=1001.2014.3001.5501&quot;)
docs = loader.load()

print(f&quot;加载了 {len(docs)} 个文档&quot;)
title = docs[0].metadata.get(&apos;title&apos;, &apos;N/A&apos;)
print(f&quot;第一个文档的标题：{title}&quot;)
print(f&quot;第一个文档长度: {len(docs[0].page_content)} 字符&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;加载了 1 个文档
第一个文档的标题：模型训练（四）梯度累计Gradient Accumulation-CSDN博客
第一个文档长度: 11642 字符
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;提示&lt;/strong&gt;：&lt;code&gt;WebBaseLoader&lt;/code&gt; 默认会尝试提取页面的 &lt;code&gt;&amp;#x3C;title&gt;&lt;/code&gt; 和部分元标签作为 &lt;code&gt;metadata&lt;/code&gt;，但具体效果取决于目标网站的 HTML 结构。对于复杂或动态渲染的网页，可能需要结合 &lt;code&gt;Selenium&lt;/code&gt; 或 &lt;code&gt;Playwright&lt;/code&gt; 等工具。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;加载 PDF 文档&lt;/h3&gt;
&lt;p&gt;对于本地 PDF 文件，我们使用 &lt;code&gt;PyPDFLoader&lt;/code&gt;。它基于 &lt;code&gt;pypdf&lt;/code&gt; 库，按页读取 PDF 内容，每一页会被封装为一个独立的 &lt;code&gt;Document&lt;/code&gt; 对象。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader(&quot;./Dataset/PDF/基于视-触觉融合感知的机器人抓取滑动检测与力控研究_闫腾.pdf&quot;)
docs = loader.load()

first_doc = docs[0]
print(&quot;meta_data:&quot;, first_doc.metadata)
print(&quot;content:&quot;, first_doc.page_content[:500] + &quot;...&quot; if len(first_doc.page_content) &gt; 500 else first_doc.page_content)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;meta_data: {&apos;producer&apos;: &apos;TTKN&apos;, &apos;creator&apos;: &apos;ReaderEx_DIS 2.5.0 Build 4088&apos;, &apos;creationdate&apos;: &apos;2025-11-03T14:24:27-08:00&apos;, &apos;author&apos;: &apos;CNKI&apos;, &apos;source&apos;: &apos;./Dataset/PDF/基于视-触觉融合感知的机器人抓取滑动检测与力控研究_闫腾.pdf&apos;, &apos;total_pages&apos;: 79, &apos;page&apos;: 0, &apos;page_label&apos;: &apos;1&apos;}
content: 硕士学位论 文
学位申请人姓名 闫腾
学位申请人学号 2200411007
专 业 名 称 机械工程
学 科 门 类 工学
学院（部、研究院） 应用技术学院
导 师 姓 名 李文贤
二〇二五年六月
分类号 学校代码 10590
UDC 密 级 公开
基于视-触觉融合感知的机器人
抓取滑动检测与力控研究
ധᎪն࿐
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;📌 &lt;strong&gt;注意&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;PyPDFLoader&lt;/code&gt; 对扫描版 PDF（即图片型 PDF）无效，仅适用于文字可复制的 PDF。&lt;/li&gt;
&lt;li&gt;每个 &lt;code&gt;Document&lt;/code&gt; 的 &lt;code&gt;metadata&lt;/code&gt; 中通常包含 &lt;code&gt;&quot;source&quot;&lt;/code&gt;（文件路径）和 &lt;code&gt;&quot;page&quot;&lt;/code&gt;（页码），便于后续定位原文位置。&lt;/li&gt;
&lt;/ul&gt;
&lt;/blockquote&gt;
&lt;p&gt;通过上述步骤，我们成功将异构数据源统一转换为 LangChain 的 &lt;code&gt;Document&lt;/code&gt; 格式，为下一步的&lt;strong&gt;文本分割&lt;/strong&gt;和&lt;strong&gt;向量嵌入&lt;/strong&gt;做好了准备。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part1: 文本分块（Text Chunking）&lt;/h2&gt;
&lt;p&gt;在将原始文档加载为 &lt;code&gt;Document&lt;/code&gt; 对象后，下一步是将其切分为更小的“文本块”（&lt;code&gt;chunks&lt;/code&gt;）。这一步对 RAG 系统的性能和效果至关重要。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么需要分块？&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;LLM 上下文长度限制&lt;/strong&gt;&lt;br&gt;
当前主流大语言模型（如 GPT-4、Claude、Llama 等）都有最大 token 输入限制（例如 8K、32K 或 128K）。若直接将整篇长文档送入模型，会超出上下文窗口，导致截断或报错。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;提升检索精度&lt;/strong&gt;&lt;br&gt;
向量数据库在检索时，会将用户查询与每个文本块的嵌入向量进行相似度匹配。较小且语义完整的文本块更容易与特定问题对齐，避免无关信息干扰。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;降低计算与推理成本&lt;/strong&gt;&lt;br&gt;
RAG 只需将最相关的几个文本块送入 LLM 进行生成，而非整篇文档。这显著减少了 token 消耗和响应延迟，尤其在调用付费 API 时能有效控制成本。&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;hr&gt;
&lt;h3&gt;1.1 创建文本分块器&lt;/h3&gt;
&lt;p&gt;LangChain 提供了多种文本分块策略。我们首先使用推荐的 &lt;strong&gt;递归字符分块器（RecursiveCharacterTextSplitter）&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建文本分块器
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,        # 每块最大字符数
    chunk_overlap=200,      # 块之间的重叠
    length_function=len,    # 长度计算函数
    is_separator_regex=False,
)

# 分块文档
splits = text_splitter.split_documents(docs)

print(f&quot;原始文档: {len(docs)} 个&quot;)
print(f&quot;分块后: {len(splits)} 个&quot;)
print(f&quot;\n第一个分块示例:\n{splits[0].page_content[:200]}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出示例：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;原始文档: 79 个
分块后: 99 个

第一个分块示例:
硕士学位论 文
学位申请人姓名 闫腾
学位申请人学号 2200411007
专 业 名 称 机械工程
学 科 门 类 工学
学院（部、研究院） 应用技术学院
导 师 姓 名 李文贤
二〇二五年六月
分类号 学校代码 10590
UDC 密 级 公开
基于视-触觉融合感知的机器人
抓取滑动检测与力控研究
ധᎪն࿐...
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;chunk_overlap 的作用&lt;/strong&gt;：通过保留相邻块的部分重叠内容，可减少因切分导致的关键信息丢失（例如一个句子被切成两半），提升后续检索与生成的连贯性。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;1.2 分块策略对比&lt;/h3&gt;
&lt;p&gt;LangChain 支持多种分块方式，适用于不同场景。下面我们对比四种常见策略：&lt;/p&gt;
&lt;h4&gt;1. 字符分块（CharacterTextSplitter）&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;原理&lt;/strong&gt;：按固定字符数硬切分。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;优点&lt;/strong&gt;：实现简单、速度快。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;缺点&lt;/strong&gt;：极易切断句子或段落，破坏语义完整性。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;适用场景&lt;/strong&gt;：对语义要求不高的粗略处理。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_text_splitters import CharacterTextSplitter
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;2. 递归分块（RecursiveCharacterTextSplitter）✅ &lt;strong&gt;推荐默认&lt;/strong&gt;&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;原理&lt;/strong&gt;：按优先级尝试在语义边界（如 &lt;code&gt;\n\n&lt;/code&gt; → &lt;code&gt;\n&lt;/code&gt; → 空格 → 任意字符）处分割，尽可能保持段落完整。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;优点&lt;/strong&gt;：兼顾效率与语义，适用于大多数文本（论文、网页、报告等）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;缺点&lt;/strong&gt;：仍基于字符长度，无法精确控制 token 数。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;适用场景&lt;/strong&gt;：通用 RAG 项目首选。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    separators=[&quot;\n\n&quot;, &quot;\n&quot;, &quot; &quot;, &quot;&quot;]
)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;3. Token 分块（TokenTextSplitter）&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;原理&lt;/strong&gt;：使用指定 tokenizer（如 GPT 的 tiktoken）精确按 token 切分。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;优点&lt;/strong&gt;：严格控制输入长度，避免超限。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;缺点&lt;/strong&gt;：依赖具体模型的 tokenizer，速度较慢；仍可能切断句子。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;适用场景&lt;/strong&gt;：对接特定 LLM 且对 token 预算敏感的任务。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_text_splitters import TokenTextSplitter
text_splitter = TokenTextSplitter(chunk_size=256, chunk_overlap=50)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;4. 语义分块（SemanticChunker）&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;原理&lt;/strong&gt;：利用嵌入模型计算句子间语义相似度，在“语义突变点”处分割。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;优点&lt;/strong&gt;：块内语义高度一致，检索质量高。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;缺点&lt;/strong&gt;：计算开销大，需调用 embedding 模型；分块大小不固定。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;适用场景&lt;/strong&gt;：对检索精度要求极高的专业领域（如法律、医疗）。&lt;/li&gt;
&lt;/ul&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
text_splitter = SemanticChunker(embeddings)
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;1.3 分块策略对比总结&lt;/h3&gt;
&lt;p&gt;| 策略 | 优点 | 缺点 | 适用场景 |
|------|------|------|----------|
| &lt;strong&gt;字符分块&lt;/strong&gt; | 简单、快速 | 易切断语义，质量差 | 快速原型、非关键任务 |
| &lt;strong&gt;递归分块&lt;/strong&gt; ✅ | 语义友好、高效、通用 | 基于字符，非 token 精确 | &lt;strong&gt;大多数 RAG 项目推荐&lt;/strong&gt; |
| &lt;strong&gt;Token 分块&lt;/strong&gt; | 精确控制 token 长度 | 依赖 tokenizer，可能断句 | 对接特定 LLM，严格 token 限制 |
| &lt;strong&gt;语义分块&lt;/strong&gt; | 语义连贯性最佳 | 计算开销大，速度慢 | 高精度检索（法律、科研等） |&lt;/p&gt;
&lt;p&gt;接下来，我们将把分块后的文本转换为向量，并存入向量数据库，为检索阶段做准备。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part2: 向量化：将文本转化为语义向量&lt;/h2&gt;
&lt;p&gt;在 RAG 系统中，&lt;strong&gt;向量化（Embedding）&lt;/strong&gt; 是实现语义检索的核心步骤。其目标是将分块后的文本转换为高维数学向量，使得语义相近的文本在向量空间中距离更近，从而支持高效的相似性搜索。&lt;/p&gt;
&lt;h3&gt;2.1 为什么需要向量化？&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;超越关键词匹配&lt;/strong&gt;：传统关键词检索无法理解“苹果”和“水果”的语义关系，而向量嵌入能捕捉深层语义。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;支持语义搜索&lt;/strong&gt;：通过计算查询与文档块之间的向量相似度（如余弦相似度），可返回最相关的内容，即使措辞不同。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;为向量数据库提供输入&lt;/strong&gt;：后续我们将这些向量存入 FAISS、Chroma、Milvus 等向量数据库，实现毫秒级检索。&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;h3&gt;2.2 常用嵌入模型对比&lt;/h3&gt;
&lt;p&gt;选择合适的嵌入模型对 RAG 效果至关重要。以下是主流模型的横向对比：&lt;/p&gt;
&lt;p&gt;| 模型 | 提供商 | 维度 | 成本 | 性能 | 特点 |
|------|--------|------|------|------|------|
| &lt;code&gt;text-embedding-3-small&lt;/code&gt; | OpenAI | 1536 | $ | ⭐ 高性价比 | 通用场景首选，速度快、效果好 |
| &lt;code&gt;text-embedding-3-large&lt;/code&gt; | OpenAI | 3072 | $$ | 最高质量 | 适合高精度要求任务 |
| &lt;code&gt;text-embedding-ada-002&lt;/code&gt; | OpenAI | 1536 | $ | 上一代 | 兼容旧系统，逐渐被 small 替代 |
| &lt;code&gt;all-MiniLM-L6-v2&lt;/code&gt; | Hugging Face | 384 | 免费 | 轻量快速 | 适合本地部署，英文为主 |
| &lt;code&gt;bce-embedding-base_v1&lt;/code&gt; | 百度 / ModelScope | 768 | 免费 | ⭐ 中文优化 | &lt;strong&gt;中文任务表现优异，推荐本地使用&lt;/strong&gt; |
| &lt;code&gt;bce-reranker-base_v1&lt;/code&gt; | 百度 | - | 免费 | 重排序专用 | 用于检索后精排，非嵌入模型 |&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;中文项目建议&lt;/strong&gt;：&lt;br&gt;
若你的数据以中文为主（如学位论文、技术文档），&lt;strong&gt;&lt;code&gt;bce-embedding-base_v1&lt;/code&gt;&lt;/strong&gt; 是目前开源免费模型中表现最出色的之一，专为中文语义理解优化，且支持本地 GPU 加速。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;2.3 使用 BCE 嵌入模型进行本地向量化&lt;/h3&gt;
&lt;p&gt;我们通过 ModelScope 下载百度开源的 &lt;code&gt;bce-embedding-base_v1&lt;/code&gt; 模型，并使用 LangChain 封装调用：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_community.embeddings import HuggingFaceEmbeddings
from modelscope import snapshot_download
import os

# 创建本地模型目录
local_models_dir = &quot;./Models&quot;
os.makedirs(local_models_dir, exist_ok=True)

# 从 ModelScope 下载模型
model_id = &quot;maidalun/bce-embedding-base_v1&quot;
local_model_path = snapshot_download(model_id, cache_dir=local_models_dir)

# 初始化嵌入模型（启用 GPU）
embeddings = HuggingFaceEmbeddings(
    model_name=local_model_path,
    model_kwargs={&quot;device&quot;: &quot;cuda&quot;},               # 使用 GPU 加速
    encode_kwargs={&quot;normalize_embeddings&quot;: True}   # 归一化便于计算余弦相似度
)

# 测试单条文本向量化
text = &quot;RAG是一种强大的AI技术&quot;
vector = embeddings.embed_query(text)

print(f&quot;文本: {text}&quot;)
print(f&quot;向量维度: {len(vector)}&quot;)
print(f&quot;向量前5个值: {[round(x, 4) for x in vector[:5]]}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;输出结果&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;文本: RAG是一种强大的AI技术
向量维度: 768
向量前5个值: [0.0048, 0.0216, -0.005, 0.02, -0.0113]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;✅ 成功生成 768 维的语义向量！&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;2.4 向量相似度验证：语义是否被正确捕捉？&lt;/h3&gt;
&lt;p&gt;我们通过余弦相似度验证模型的语义理解能力：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

text1 = &quot;苹果是一种水果&quot;
text2 = &quot;香蕉是一种水果&quot;
text3 = &quot;苹果是一种好吃的水果&quot;

# 获取向量
v1 = np.array(embeddings.embed_query(text1)).reshape(1, -1)
v2 = np.array(embeddings.embed_query(text2)).reshape(1, -1)
v3 = np.array(embeddings.embed_query(text3)).reshape(1, -1)

# 计算相似度
sim_1_2 = cosine_similarity(v1, v2)[0][0]
sim_1_3 = cosine_similarity(v1, v3)[0][0]

print(f&quot;「{text1}」 vs 「{text2}」 相似度: {sim_1_2:.4f}&quot;)
print(f&quot;「{text1}」 vs 「{text3}」 相似度: {sim_1_3:.4f}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;输出结果&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;「苹果是一种水果」 vs 「香蕉是一种水果」 相似度: 0.7481
「苹果是一种水果」 vs 「苹果是一种好吃的水果」 相似度: 0.9040
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;🔍 &lt;strong&gt;分析&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;两者都提到“水果”，因此相似度较高（&gt;0.7）；&lt;/li&gt;
&lt;li&gt;text1 与 text3 主体完全一致（“苹果”），仅增加形容词“好吃”，语义更接近，相似度达 &lt;strong&gt;0.904&lt;/strong&gt;，说明模型有效捕捉了中文语义细微差别。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;下一步：构建向量数据库
现在，我们已将每个文本块转换为 768 维向量。接下来，我们将把这些向量与原始文本一起存入 &lt;strong&gt;向量数据库&lt;/strong&gt;（如 FAISS 或 Chroma），为用户查询提供高效、精准的语义检索能力。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part3: 存储到向量数据库&lt;/h2&gt;
&lt;p&gt;在完成文档加载、分块和向量化后，下一步是将这些文本块及其对应的嵌入向量持久化存储到&lt;strong&gt;向量数据库&lt;/strong&gt;中。这是 RAG 系统实现高效语义检索的关键基础设施。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么需要向量数据库？&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;快速相似性搜索&lt;/strong&gt;：面对成千上万的文本块，暴力计算余弦相似度效率极低。向量数据库通过近似最近邻（ANN）算法（如 HNSW、IVF）实现毫秒级检索。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;元数据关联&lt;/strong&gt;：除了向量，还能存储原始文本、来源文件、页码等 metadata，便于溯源和结果展示。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;持久化与复用&lt;/strong&gt;：一次构建，多次查询，避免重复加载和嵌入计算。&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;3.1 主流向量数据库对比&lt;/h3&gt;
&lt;p&gt;| 数据库 | 类型 | 优势 | 适用场景 |
|--------|------|------|----------|
| &lt;strong&gt;Chroma&lt;/strong&gt; | 嵌入式 | 简单易用，无需额外服务，LangChain 深度集成 | ✅ 开发、小规模应用、本地原型 |
| Pinecone | 云服务 | 高性能、全托管、自动扩缩容 | 生产环境（需网络 &amp;#x26; 账号） |
| Weaviate | 自建/云 | 功能丰富（支持 GraphQL、分类、生成），开源 | 大规模部署、企业级应用 |
| FAISS | 库（非数据库） | 速度快，Meta 开源，适合研究 | 离线实验、临时索引 |&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;🚀 &lt;strong&gt;本项目选择 Chroma&lt;/strong&gt;：因其轻量、本地运行、与 LangChain 无缝对接，非常适合学术论文类 RAG 应用的开发与测试。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;3.2 批量加载目录下所有 PDF 并预处理&lt;/h3&gt;
&lt;p&gt;我们的目标是将 &lt;code&gt;./Dataset/PDF/&lt;/code&gt; 目录下的 &lt;strong&gt;10 篇机器人抓取相关硕士论文&lt;/strong&gt;统一加载、分块、清洗并存入向量库。&lt;/p&gt;
&lt;h4&gt;步骤 1：递归加载所有 PDF&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

def load_all_pdfs_from_directory(directory_path=&quot;./Dataset/PDF/&quot;):
    &quot;&quot;&quot;加载目录下所有PDF文件，并为每个分块添加唯一ID&quot;&quot;&quot;
    all_splits = []
    pdf_files = []
    
    if not os.path.exists(directory_path):
        print(f&quot;❌ 目录不存在: {directory_path}&quot;)
        return all_splits, pdf_files
    
    # 获取所有 PDF 文件
    for filename in os.listdir(directory_path):
        if filename.lower().endswith(&apos;.pdf&apos;):
            pdf_files.append(filename)
    
    if not pdf_files:
        print(f&quot;⚠️  目录中没有找到PDF文件: {directory_path}&quot;)
        return all_splits, pdf_files
    
    print(f&quot;📁 找到 {len(pdf_files)} 个PDF文件:&quot;)
    
    # 创建分块器（注意：此处 chunk_size=100 是示例，实际可调大）
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,
        chunk_overlap=20,
        length_function=len,
    )
    
    global_chunk_index = 0
    for filename in pdf_files:
        file_path = os.path.join(directory_path, filename)
        base_name = os.path.splitext(filename)[0]
        
        try:
            print(f&quot;📖 正在加载: {filename}&quot;)
            loader = PyPDFLoader(file_path)
            docs = loader.load()
            splits = text_splitter.split_documents(docs)
            
            for i, split in enumerate(splits):
                page_num = split.metadata.get(&apos;page&apos;, &apos;unknown&apos;)
                split_id = f&quot;{base_name}_p{page_num}_c{i}&quot;
                split.metadata.update({
                    &apos;id&apos;: split_id,
                    &apos;source_file&apos;: filename
                })
                all_splits.append(split)
            
            print(f&quot;✅ {filename}: {len(docs)}页 -&gt; {len(splits)}个分块&quot;)
        except Exception as e:
            print(f&quot;❌ 加载失败 {filename}: {str(e)}&quot;)
    
    print(f&quot;🎉 总共加载 {len(all_splits)} 个文本分块（均已添加 id）&quot;)
    return all_splits, pdf_files

# 执行加载
splits, loaded_files = load_all_pdfs_from_directory(&quot;./Dataset/PDF/&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📁 找到 10 个PDF文件:
✅ ...（略）
🎉 总共加载 12729 个文本分块（均已添加 id）
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;⚠️ 注意：&lt;code&gt;chunk_size=100&lt;/code&gt; 在演示中偏小（导致分块数多），实际建议设为 &lt;strong&gt;500–1000 字符&lt;/strong&gt; 以平衡粒度与上下文完整性。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h4&gt;步骤 2：清洗非法 Unicode 字符（关键！）&lt;/h4&gt;
&lt;p&gt;在尝试保存到 Chroma 时，我们遇到了以下错误：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;UnicodeEncodeError: &apos;utf-8&apos; codec can&apos;t encode characters in position 795-798: surrogates not allowed
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是因为某些 PDF 解析后包含 &lt;strong&gt;非法代理字符（surrogate characters）&lt;/strong&gt;，而 ChromaDB（底层基于 Rust）要求所有字符串必须是合法 UTF-8。&lt;/p&gt;
&lt;h5&gt;解决方案：文本清洗函数&lt;/h5&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def clean_text(text: str) -&gt; str:
    &quot;&quot;&quot;移除无法编码为 UTF-8 的非法字符，并清理首尾空白&quot;&quot;&quot;
    return text.encode(&apos;utf-8&apos;, errors=&apos;ignore&apos;).decode(&apos;utf-8&apos;).strip(&apos;\n&apos;).strip()

# 对所有分块内容进行清洗
for doc in splits:
    doc.page_content = clean_text(doc.page_content)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;✅ 清洗后即可安全写入 Chroma。&lt;/p&gt;
&lt;hr&gt;
&lt;h4&gt;步骤 3：智能创建并填充 Chroma 向量库&lt;/h4&gt;
&lt;p&gt;为避免重复创建或覆盖问题，我们设计一个“智能初始化”函数：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_chroma import Chroma

def smart_vectorstore_creation(documents, embeddings, persist_directory=&quot;./chroma_db&quot;, collection_name=&quot;langchain&quot;):
    &quot;&quot;&quot;智能向量数据库创建：自动清理旧数据，分批写入&quot;&quot;&quot;
    print(f&quot;🔧 处理集合: {collection_name}&quot;)
    os.makedirs(persist_directory, exist_ok=True)
    
    try:
        # 尝试连接现有集合
        vectorstore = Chroma(
            collection_name=collection_name,
            persist_directory=persist_directory,
            embedding_function=embeddings
        )
        collection_info = vectorstore.get()
        if collection_info[&apos;ids&apos;]:
            print(f&quot;🗑️ 清理集合中的 {len(collection_info[&apos;ids&apos;])} 个文档&quot;)
            vectorstore.delete(ids=collection_info[&apos;ids&apos;])
        else:
            print(&quot;✅ 集合为空，无需清理&quot;)
    except Exception:
        print(f&quot;🆕 集合不存在，将创建新集合&quot;)
        vectorstore = None

    # 分批写入（Chroma 单次写入不宜过大）
    total_docs = len(documents)
    batch_size = 1000
    for i in range(0, total_docs, batch_size):
        batch = documents[i:i + batch_size]
        batch_num = i // batch_size + 1
        total_batches = (total_docs - 1) // batch_size + 1
        print(f&quot;📦 添加批次 {batch_num}/{total_batches}: {len(batch)} 个文档&quot;)
        
        if vectorstore is None and i == 0:
            vectorstore = Chroma.from_documents(
                documents=batch,
                embedding=embeddings,
                persist_directory=persist_directory,
                collection_name=collection_name
            )
        else:
            vectorstore.add_documents(batch)
    
    print(f&quot;✅ 完成！集合 &apos;{collection_name}&apos; 现有 {total_docs} 个文档&quot;)
    return vectorstore

# 执行存储
vectorstore = smart_vectorstore_creation(splits, embeddings, collection_name=&quot;robot_grasping_rag&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;执行日志&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🔧 处理集合: robot_grasping_rag
✅ 集合为空，无需清理
📦 添加批次 1/13: 1000 个文档
...
📦 添加批次 13/13: 729 个文档
✅ 完成！集合 &apos;robot_grasping_rag&apos; 现有 12729 个文档
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;💾 数据已持久化到 &lt;code&gt;./chroma_db/&lt;/code&gt; 目录，下次可直接加载复用，无需重新嵌入！&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;下一步：语义检索与问答&lt;/p&gt;
&lt;p&gt;现在，我们的 10 篇中文论文已被切分为 12,729 个清洗后的文本块，并成功存入 Chroma 向量数据库。接下来，我们将实现：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;用户自然语言查询的向量化&lt;/li&gt;
&lt;li&gt;Top-K 语义相似块检索&lt;/li&gt;
&lt;li&gt;将检索结果注入 LLM 生成最终答案&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;这是 RAG 的核心流程！&lt;/p&gt;
&lt;hr&gt;
&lt;h1&gt;Part4: 检索&lt;/h1&gt;
&lt;p&gt;向量数据库构建完成后，RAG 系统的核心能力之一——&lt;strong&gt;语义检索&lt;/strong&gt;——正式启用。这一步的目标是：&lt;strong&gt;根据用户自然语言问题，从海量文档块中精准召回最相关的上下文片段&lt;/strong&gt;，为后续大模型生成答案提供高质量依据。&lt;/p&gt;
&lt;h3&gt;4.1 常见检索策略对比&lt;/h3&gt;
&lt;p&gt;LangChain 支持多种检索方式，适用于不同场景：&lt;/p&gt;
&lt;p&gt;| 策略 | 配置示例 | 优势 | 劣势 | 适用场景 |
|------|--------|------|------|----------|
| &lt;strong&gt;相似度搜索 (Similarity)&lt;/strong&gt; | &lt;code&gt;search_type=&quot;similarity&quot;, k=5&lt;/code&gt; | 简单快速，直接返回最相关结果 | 可能返回高度重复内容 | 默认首选，通用问答 |
| &lt;strong&gt;最大边际相关性 (MMR)&lt;/strong&gt; | &lt;code&gt;search_type=&quot;mmr&quot;, k=5, fetch_k=20, lambda_mult=0.5&lt;/code&gt; | 平衡相关性与多样性 | 计算稍慢 | 需要多角度信息（如综述类问题） |
| &lt;strong&gt;相似度阈值过滤&lt;/strong&gt; | &lt;code&gt;search_type=&quot;similarity_score_threshold&quot;, score_threshold=0.5&lt;/code&gt; | 过滤低质量结果，保证精度 | 可能无结果返回 | 对答案可靠性要求极高 |&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;MMR 工作原理示意&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;以查询 &lt;strong&gt;“机器学习算法”&lt;/strong&gt; 为例：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;第一步&lt;/strong&gt;：先用相似度搜索获取 top-20 候选：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;#1: “深度学习是机器学习的一个分支...” （相似度 0.95）&lt;/li&gt;
&lt;li&gt;#2: “深度学习使用神经网络...” （0.94，与 #1 高度重合）&lt;/li&gt;
&lt;li&gt;#3: “决策树是一种机器学习算法...” （0.90）&lt;/li&gt;
&lt;li&gt;#4: “支持向量机(SVM)用于分类...” （0.88）&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;第二步&lt;/strong&gt;：MMR 选择时，优先选 #1（最相关），然后跳过 #2（太相似），转而选择 #3 和 #4 以增加&lt;strong&gt;主题多样性&lt;/strong&gt;。&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 在学术论文检索中，MMR 尤其有用——避免只返回同一章节的重复段落。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;4.2 加载向量库并初始化检索器&lt;/h3&gt;
&lt;p&gt;我们首先从磁盘加载已持久化的 Chroma 数据库：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
from langchain_chroma import Chroma

if os.path.exists(&quot;./chroma_db&quot;):
    vectorstore = Chroma(
        persist_directory=&quot;./chroma_db&quot;,
        embedding_function=embeddings
    )
    collection = vectorstore._collection
    print(f&quot;✅ 向量数据库加载成功！&quot;)
    print(f&quot;   已存在 {collection.count()} 个文档块&quot;)
    print(f&quot;   Collection 名称: {collection.name}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;✅ 向量数据库加载成功！
   已存在 12721 个文档块
   Collection 名称: langchain
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;📌 &lt;strong&gt;向量数据库 vs Collection 类比&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;向量数据库&lt;/strong&gt; ≈ MySQL 实例（如 &lt;code&gt;my_company_db&lt;/code&gt;）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Collection&lt;/strong&gt; ≈ 数据表（如 &lt;code&gt;papers&lt;/code&gt; 表）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;每条记录&lt;/strong&gt; = 文本块 + 向量 + 元数据（来源、页码、ID）&lt;/li&gt;
&lt;/ul&gt;
&lt;/blockquote&gt;
&lt;p&gt;接着创建基础检索器（返回 top-5）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;retriever = vectorstore.as_retriever(
    search_type=&quot;similarity&quot;,
    search_kwargs={&quot;k&quot;: 5}
)
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;4.3 执行检索并清理乱码文本&lt;/h3&gt;
&lt;p&gt;PDF 解析常引入排版残留字符（如 &lt;code&gt;&lt;/code&gt;, &lt;code&gt;&lt;/code&gt;, 多余换行等），需在展示前清洗。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;增强版中文文本清洗函数&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import re

def clean_text(text: str) -&gt; str:
    &quot;&quot;&quot;
    移除中文字符之间的任意空白（包括全角空格、不间断空格等），
    同时保留英文/数字间的正常空格。
    &quot;&quot;&quot;
    chinese_char = r&apos;[\u4e00-\u9fff]&apos;
    any_whitespace = r&apos;[\s\u00A0\u2000-\u200F\u2028-\u202F\u3000]+&apos;
    pattern = f&apos;({chinese_char}){any_whitespace}(?={chinese_char})&apos;
    result = re.sub(pattern, r&apos;\1&apos;, text)  # 中文间空白直接删除
    result = re.sub(r&apos;\s+&apos;, &apos; &apos;, result)   # 其他区域合并多余空格
    return result.strip()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;检索示例：查询“什么是视触觉？”&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;query = &quot;什么是视触觉？&quot;
docs = retriever.invoke(query)

for i, doc in enumerate(docs, 1):
    cleaned_content = clean_text(doc.page_content)
    source = doc.metadata.get(&apos;source_file&apos;, &apos;N/A&apos;)
    doc_id = doc.metadata.get(&apos;id&apos;, &apos;N/A&apos;)
    print(f&quot;\n📄 结果 {i}:&quot;)
    print(f&quot;内容: {cleaned_content[:300]}...&quot;)
    print(f&quot;来源: {source}&quot;)
    print(f&quot;ID: {doc_id}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出节选&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📄 结果 1:
内容: perception 皮肤。使用一种被称为Gelsight的触觉传感器作为机器人的指尖触觉感受器...
来源: 基于触觉图像序列的机器人抓取目标状态感知_韩筱.pdf
ID: 基于触觉图像序列的机器人抓取目标状态感知_韩筱_p23_c383

📄 结果 4:
内容: 人们对一个物体的描述往往是从多个角度去描述的，依靠的是多个器官的共同感知，其中最重要的是视觉感知和触觉感知...
来源: 基于视触觉融合的机械手分类抓取方法研究_余航.pdf
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;✅ 成功召回多篇论文中关于“视触觉融合”的定义与技术描述！&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;5.4 高级检索：自定义检索器&lt;/h3&gt;
&lt;p&gt;基础检索有时不够精准。我们可通过以下四步构建&lt;strong&gt;高级检索流水线&lt;/strong&gt;：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;扩大候选范围&lt;/strong&gt;（如 k=10）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;按元数据过滤&lt;/strong&gt;（如仅限某作者论文）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;重排序&lt;/strong&gt;（用更精细打分模型）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;返回 Top-K&lt;/strong&gt;&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def custom_retriever(question: str) -&gt; list:
    # 1. 扩大检索范围
    base_docs = vectorstore.similarity_search(question, k=10)
    
    # 2. 按来源过滤（例如只看“张静”的论文）
    filtered_docs = [
        doc for doc in base_docs
        if &quot;张静&quot; in doc.metadata.get(&apos;source_file&apos;, &apos;&apos;)
    ]
    
    # 3. 重排序（此处简化为关键词匹配，实际可用 BCE-Reranker）
    def calculate_relevance_score(doc, query):
        return sum(1 for word in query if word in doc.page_content)
    
    scored_docs = [(doc, calculate_relevance_score(doc, question)) for doc in filtered_docs]
    scored_docs.sort(key=lambda x: x[1], reverse=True)
    
    # 4. 返回 top-5
    return [doc for doc, _ in scored_docs[:5]]
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;🔍 &lt;strong&gt;进阶建议&lt;/strong&gt;：可集成 &lt;code&gt;bce-reranker-base_v1&lt;/code&gt;（百度开源重排序模型）对初检结果重新打分，显著提升 top-1 准确率。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;5.5 评估检索质量&lt;/h3&gt;
&lt;p&gt;仅靠人工判断不够客观。我们设计简单指标评估检索器性能：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def evaluate_retrieval(retriever, test_cases):
    metrics = {&quot;precision&quot;: [], &quot;recall&quot;: []}
    for query, expected_ids in test_cases:
        retrieved = retriever.invoke(query)
        retrieved_ids = [d.metadata[&apos;id&apos;] for d in retrieved]
        
        relevant = set(retrieved_ids) &amp;#x26; set(expected_ids)
        precision = len(relevant) / len(retrieved_ids) if retrieved_ids else 0
        recall = len(relevant) / len(expected_ids) if expected_ids else 0
        
        metrics[&quot;precision&quot;].append(precision)
        metrics[&quot;recall&quot;].append(recall)
    
    return {
        &quot;avg_precision&quot;: sum(metrics[&quot;precision&quot;]) / len(metrics[&quot;precision&quot;]),
        &quot;avg_recall&quot;: sum(metrics[&quot;recall&quot;]) / len(metrics[&quot;recall&quot;])
    }

# 测试用例
test_cases = [
    (&quot;什么是抓取检测？&quot;, [
        &quot;基于视触感知协同的多指灵巧手抓取方法研究_张静_p164_c3266&quot;,
        &quot;基于视触感知协同的机器人抓取技术研究_祝会龙_p39_c500&quot;,
        # ...其他相关 ID
    ]),
    (&quot;什么是机械臂？&quot;, [&quot;基于触觉图像序列的机器人抓取目标状态感知_韩筱_p25_c407&quot;]),
]

results = evaluate_retrieval(retriever, test_cases)
print(f&quot;平均精确率: {results[&apos;avg_precision&apos;]:.2%}&quot;)
print(f&quot;平均召回率: {results[&apos;avg_recall&apos;]:.2%}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;平均精确率: 40.00%
平均召回率: 87.50%
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;🔍 &lt;strong&gt;分析&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;高召回率&lt;/strong&gt;：说明系统能覆盖大部分相关文档（不错过重要信息）；&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;较低精确率&lt;/strong&gt;：部分返回结果不相关，可能因 chunk_size 过小或 PDF 噪声干扰。&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;✅ &lt;strong&gt;优化方向&lt;/strong&gt;：增大分块粒度、引入重排序、添加元数据过滤（如排除目录页）。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;下一步：整合 LLM 生成答案&lt;/p&gt;
&lt;p&gt;现在，我们已能从 10 篇中文论文中高效检索相关内容。下一步将把这些上下文注入大语言模型（如 Qwen、ChatGLM），让其基于&lt;strong&gt;真实文献&lt;/strong&gt;回答用户问题，真正实现 &lt;strong&gt;“有据可依” 的智能问答系统&lt;/strong&gt;。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part5: 生成：基于检索结果让 LLM 生成精准答案&lt;/h2&gt;
&lt;p&gt;检索到相关文档后，RAG 系统进入&lt;strong&gt;最终环节——生成&lt;/strong&gt;。这一步的目标是：&lt;strong&gt;将检索到的上下文信息与大语言模型的知识相结合，生成准确、流畅且基于事实的答案&lt;/strong&gt;。&lt;/p&gt;
&lt;h3&gt;5.1 本地问答系统架构设计&lt;/h3&gt;
&lt;p&gt;我们构建了一个完整的本地问答系统 &lt;code&gt;LocalQASystem&lt;/code&gt;，核心组件包括：&lt;/p&gt;
&lt;p&gt;| 组件 | 技术选型 | 作用 | 关键配置 |
|------|---------|------|----------|
| &lt;strong&gt;对话模型&lt;/strong&gt; | Qwen-7B-Chat-Int8 + vLLM | 生成答案 | GPTQ量化，推理加速 |
| &lt;strong&gt;嵌入模型&lt;/strong&gt; | bce-embedding-base_v1 | 查询向量化 | 与构建时保持一致 |
| &lt;strong&gt;向量数据库&lt;/strong&gt; | ChromaDB | 存储和检索 | 持久化存储 |&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;系统初始化流程&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class LocalQASystem:
    def __init__(self, model_dir, chroma_db_path=&quot;./chroma_db&quot;, embeddings_model_path=None):
        self.model_dir = self._setup_model_dir(model_dir)  # 模型路径处理
        self.chroma_db_path = chroma_db_path
        self.embeddings_model_path = embeddings_model_path
        self._setup_model()           # 初始化vLLM
        self._setup_exact_embeddings() # 关键：使用相同嵌入模型
        self._setup_vectorstore()     # 加载向量库
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;初始化输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📁 使用本地模型: ../Qwen-vllm/Models/Qwen/Qwen-7B-Chat-Int8
🤖 vLLM模型初始化完成
✅ 使用嵌入模型: ./Models/maidalun/bce-embedding-base_v1
🔢 嵌入模型维度: 768
🗂️  向量数据库加载成功: ./chroma_db
📄 文档数量: 12721
✅ 问答系统初始化完成
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;⚠️ &lt;strong&gt;关键点&lt;/strong&gt;：必须使用&lt;strong&gt;与构建向量库时完全相同的嵌入模型&lt;/strong&gt;，否则向量空间不一致会导致检索失败！&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;5.2 问答流程四步走&lt;/h3&gt;
&lt;h4&gt;步骤1：检索相关文档&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def retrieve_with_exact_embedding(self, query, n_results=5):
    &quot;&quot;&quot;使用精确匹配的嵌入模型进行检索&quot;&quot;&quot;
    query_embedding = self.embeddings.embed_query(query)  # 关键步骤！
    results = self.collection.query(
        query_embedding=[query_embedding],  # 使用相同模型生成向量
        n_results=n_results
    )
    return results
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;步骤2：上下文清洗与预处理&lt;/h4&gt;
&lt;p&gt;PDF解析常产生格式问题，需专门清洗：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def _clean_context(self, text):
    &quot;&quot;&quot;清洗PDF解析产生的格式问题&quot;&quot;&quot;
    # 1. 合并被错误分割的文字（如：\n运\n动\n → 运动）
    cleaned = re.sub(r&apos;(?&amp;#x3C;=[^\s])\n(?=[^\s])&apos;, &apos;&apos;, text)
    # 2. 处理多余空白和空行
    cleaned = re.sub(r&apos;\n\s+\n&apos;, &apos;\n\n&apos;, cleaned)
    return cleaned.strip()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;清洗效果对比&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;清洗前&lt;/strong&gt;： &lt;code&gt;机\n器\n人\n 抓\n取\n 技\n术\n 研\n究&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;清洗后&lt;/strong&gt;： &lt;code&gt;机器人抓取技术研究&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;步骤3：构建精准提示词（Prompt Engineering）&lt;/h3&gt;
&lt;p&gt;采用ChatML格式，明确约束LLM行为：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def _build_prompt(self, question, context):
    cleaned_context = self._clean_context(context)
    
    return f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
			你是一个专业的AI助手。请严格基于以下上下文信息回答问题：
			
			{cleaned_context}
			
			请遵循以下规则：
			1. 只使用上下文中的信息回答
			2. 如果上下文不包含相关信息，请回答&quot;我不知道&quot;
			3. 保持回答准确、简洁
			4. 不要编造信息&amp;#x3C;|im_end|&gt;
			&amp;#x3C;|im_start|&gt;user
			{question}&amp;#x3C;|im_end|&gt;
			&amp;#x3C;|im_start|&gt;assistant
			&quot;&quot;&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;提示词设计原则&lt;/strong&gt;：明确角色、限定知识范围、设定回答规则、防止幻觉。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h4&gt;步骤4：vLLM高效生成答案&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def _generate_answer(self, prompt, max_tokens, temperature):
    sampling_params = SamplingParams(
        max_tokens=max_tokens,      # 控制生成长度
        temperature=temperature,    # 控制随机性（0.1-0.3更确定）
        top_p=0.8,                  # 核采样，提高相关性
        stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]  # 停止标记
    )
    
    outputs = self.llm.generate([prompt], sampling_params)
    return outputs[0].outputs[0].text
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;5.3 实战测试：验证系统效果&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;questions = [
    &quot;什么是机械臂？它有什么能力？&quot;,
    &quot;解释一下滑动检测的基本概念&quot;,
]

print(&quot;🚀 开始测试问答系统&quot;)
print(&quot;=&quot; * 70)

for i, question in enumerate(questions, 1):
    print(f&quot;\n🎯 问题 {i}: {question}&quot;)
    print(&quot;-&quot; * 50)
    
    result = qa_system.ask(question, max_tokens=400, temperature=0.3)
    print(f&quot;💡 答案: {result[&apos;answer&apos;]}&quot;)
    print(f&quot;📚 参考文档: {result[&apos;sources&apos;]} 个&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 开始测试问答系统
======================================================================

🎯 问题 1: 什么是机械臂？它有什么能力？
--------------------------------------------------
🔍 查询嵌入维度: 768
✅ 检索到 5 个相关文档
💡 答案: 机械臂是一种可以自动执行任务的机器人手臂，它具有精确度高、可重复性好、负载能力强、工作半径大、自由度多等能力。
📚 参考文档: 5 个

🎯 问题 2: 解释一下滑动检测的基本概念
--------------------------------------------------
🔍 查询嵌入维度: 768
✅ 检索到 5 个相关文档  
💡 答案: 滑动检测是一种用于检测物体是否发生滑动的算法。它通过比较帧间的标准差来判断物体是否发生滑动。如果帧间的标准差超过设定的阈值，那么就认为物体发生了滑动。滑动检测算法通常用于机械臂的抓取任务中，以防止物体在抓取过程中滑落。
📚 参考文档: 5 个
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;✅ &lt;strong&gt;成功指标&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;答案准确基于论文内容（非模型固有知识）&lt;/li&gt;
&lt;li&gt;回答简洁专业，符合学术规范&lt;/li&gt;
&lt;li&gt;检索到多个相关文档作为支撑&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;h3&gt;5.4 进阶话题：Chain Types 详解&lt;/h3&gt;
&lt;p&gt;LangChain 提供了多种文档处理链类型，适用于不同场景：&lt;/p&gt;
&lt;h4&gt;5.4.1 四种链类型对比&lt;/h4&gt;
&lt;p&gt;| 链类型 | 工作原理 | 优点 | 缺点 | 适用场景 |
|--------|---------|------|------|----------|
| &lt;strong&gt;Stuff&lt;/strong&gt; | 所有文档拼接后一次性提问 | 简单高效，一次LLM调用 | 文档多时会超长 | 文档少(&amp;#x3C;10)，默认首选 |
| &lt;strong&gt;Map-Reduce&lt;/strong&gt; | 先分别处理每个文档，再合并答案 | 可处理大量文档，支持并行 | 成本高，丢失文档关联 | 文档非常多时 |
| &lt;strong&gt;Refine&lt;/strong&gt; | 迭代处理，用后续文档改进答案 | 答案质量高，保持关联 | 顺序敏感，不能并行 | 需要高质量答案 |
| &lt;strong&gt;Map-Rerank&lt;/strong&gt; | 对每个文档生成答案并评分，选最佳 | 自动选择最相关答案 | 每个文档都需LLM调用 | 找最准确答案 |&lt;/p&gt;
&lt;h4&gt;5.4.2 链类型选择指南&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 根据场景选择合适的链类型
chain_configs = {
    &quot;default&quot;: {&quot;chain_type&quot;: &quot;stuff&quot;, &quot;k&quot;: 5},
    &quot;many_docs&quot;: {&quot;chain_type&quot;: &quot;map_reduce&quot;, &quot;k&quot;: 20},
    &quot;high_quality&quot;: {&quot;chain_type&quot;: &quot;refine&quot;, &quot;k&quot;: 8},
    &quot;most_relevant&quot;: {&quot;chain_type&quot;: &quot;map_rerank&quot;, &quot;k&quot;: 10}
}

def select_chain_type(scenario, document_count):
    if document_count &gt; 15:
        return &quot;map_reduce&quot;  # 文档太多，需要分治
    elif scenario == &quot;precision_critical&quot;:
        return &quot;map_rerank&quot;  # 精度要求高
    elif scenario == &quot;quality_first&quot;:
        return &quot;refine&quot;      # 质量优先
    else:
        return &quot;stuff&quot;       # 默认选择
&lt;/code&gt;&lt;/pre&gt;
&lt;blockquote&gt;
&lt;p&gt;🔍 &lt;strong&gt;实践建议&lt;/strong&gt;：从 &lt;code&gt;stuff&lt;/code&gt; 开始测试，如遇上下文长度问题再切换到 &lt;code&gt;map_reduce&lt;/code&gt;。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;hr&gt;
&lt;h3&gt;5.5 生成质量评估与优化&lt;/h3&gt;
&lt;h3&gt;评估指标&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;相关性&lt;/strong&gt;：答案是否直接回应问题&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;准确性&lt;/strong&gt;：是否基于提供的上下文&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;完整性&lt;/strong&gt;：是否覆盖问题的关键方面&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;可读性&lt;/strong&gt;：语言是否流畅自然&lt;/li&gt;
&lt;/ul&gt;
&lt;h4&gt;常见问题与解决方案&lt;/h4&gt;
&lt;p&gt;| 问题现象 | 可能原因 | 解决方案 |
|---------|---------|----------|
| 答案与上下文无关 | 提示词约束不够强 | 加强system提示词约束 |
| 出现&quot;幻觉&quot;信息 | 温度参数过高 | 降低temperature(0.1-0.3) |
| 答案过于简短 | max_tokens设置太小 | 适当增加生成长度 |
| 包含无关内容 | 检索文档不相关 | 优化检索策略，增加重排序 |&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 质量优化配置
optimized_params = {
    &quot;temperature&quot;: 0.2,      # 降低随机性，提高确定性
    &quot;top_p&quot;: 0.85,           # 平衡相关性和多样性
    &quot;max_tokens&quot;: 512,        # 保证答案完整
    &quot;stop_tokens&quot;: [&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;\n\n&quot;]  # 合理终止
}
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;总结：RAG 流程闭环&lt;/h2&gt;
&lt;p&gt;至此，我们完成了完整的 RAG 流水线：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;文档处理&lt;/strong&gt; → PDF解析、文本分块、向量化&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;向量存储&lt;/strong&gt; → ChromaDB持久化存储&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;语义检索&lt;/strong&gt; → 相似度搜索、MMR多样性优化&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;答案生成&lt;/strong&gt; → 提示词工程、vLLM高效推理&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;✅ &lt;strong&gt;核心成就&lt;/strong&gt;：构建了一个能够基于10篇中文学术论文进行&lt;strong&gt;有据可查、准确可靠&lt;/strong&gt;的智能问答系统。系统完全在本地运行，保障数据安全，且答案可追溯至具体文献来源。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>RAG实战（三）路由与查询构建</title><link>https://astro-pure.js.org/blog/rag_blogs/rag_blogs-3</link><guid isPermaLink="true">https://astro-pure.js.org/blog/rag_blogs/rag_blogs-3</guid><description>记录RAG的学习。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/RAG_Learning&quot;&gt;Github地址&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;RAG路由与查询构建：智能检索的核心技术&lt;/h2&gt;
&lt;p&gt;在前两章中，我们学习了RAG系统的基础和查询优化技术。但是，当面对多个数据源或需要结构化查询时，如何智能地选择正确的数据源和构建合适的查询呢？本章将介绍路由机制和查询构建技术。&lt;/p&gt;
&lt;h3&gt;为什么需要路由和查询构建？&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;实际场景中的挑战&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;场景1: 多个数据源&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 企业知识库中的多个数据源
数据源1: 技术文档数据库
数据源2: 用户手册数据库  
数据源3: FAQ知识库
数据源4: API参考文档

用户查询: &quot;如何使用Python SDK连接数据库？&quot;

# 应该查询哪个数据源？
→ 单一数据源可能不够
→ 查询所有数据源效率低
→ 需要智能路由机制
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;场景2: 复杂查询条件&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 向量数据库包含:
- 文档内容 (embedding)
- 元数据: 
  - 作者
  - 发布日期
  - 文档类型
  - 标签

用户查询: &quot;找出2023年发布的关于机器学习的文章&quot;

# 需要同时考虑:
→ 语义相似度 (机器学习)
→ 结构化条件 (日期 &gt;= 2023-01-01)
→ 需要查询构建技术
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;本章内容概览&lt;/h3&gt;
&lt;p&gt;| 技术 | 核心功能 | 适用场景 | 复杂度 |
|------|---------|----------|--------|
| &lt;strong&gt;逻辑路由&lt;/strong&gt; | 基于规则的路由 | 确定性路由 | ⭐ |
| &lt;strong&gt;语义路由&lt;/strong&gt; | 基于LLM的路由 | 灵活路由 | ⭐⭐ |
| &lt;strong&gt;结构化查询&lt;/strong&gt; | 构建filter条件 | 带元数据查询 | ⭐⭐ |
| &lt;strong&gt;自查询检索器&lt;/strong&gt; | 自动分离查询意图 | 复杂查询 | ⭐⭐⭐ |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;环境准备与数据源设置&lt;/h3&gt;
&lt;h4&gt;创建多数据源环境&lt;/h4&gt;
&lt;p&gt;首先，我们需要为系统准备多个数据源。我们将创建两个集合：学术论文集合和网页内容集合。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import asyncio
from concurrent.futures import ThreadPoolExecutor
import time

# 初始化嵌入模型
embeddings = HuggingFaceEmbeddings(
    model_name=&quot;./Models/maidalun/bce-embedding-base_v1&quot;,
    model_kwargs={&quot;device&quot;: &quot;cuda&quot;},
    encode_kwargs={&quot;normalize_embeddings&quot;: True}
)

# 创建 web_content 集合
vectorstore = Chroma(
    collection_name=&quot;web_content&quot;,  # 指定集合名称
    persist_directory=&quot;./chroma_db&quot;,
    embedding_function=embeddings
)

print(&quot;✅ 创建 web_content 集合成功&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;并行加载网页数据&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def load_single_url(url: str):
    &quot;&quot;&quot;加载单个URL&quot;&quot;&quot;
    try:
        loader = WebBaseLoader(url)
        docs = loader.load()
        print(f&quot;✅ 成功加载: {url} -&gt; {len(docs)} 个文档&quot;)
        return docs
    except Exception as e:
        print(f&quot;❌ 加载失败: {url} -&gt; 错误: {e}&quot;)
        return []

async def load_urls_parallel(urls: list, max_workers: int = 5):
    &quot;&quot;&quot;并行加载多个URL&quot;&quot;&quot;
    print(f&quot;🚀 开始并行加载 {len(urls)} 个URL...&quot;)
    
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # 创建并行任务
        tasks = [
            loop.run_in_executor(executor, load_single_url, url)
            for url in urls
        ]
        
        # 并行执行
        all_docs = await asyncio.gather(*tasks)
    
    # 合并所有文档
    flat_docs = []
    for docs in all_docs:
        flat_docs.extend(docs)
    
    print(f&quot;📊 总共加载了 {len(flat_docs)} 个文档&quot;)
    return flat_docs

def add_docs_to_vectorstore(docs: list):
    &quot;&quot;&quot;将文档添加到向量数据库&quot;&quot;&quot;
    if not docs:
        print(&quot;⚠️ 没有文档可添加&quot;)
        return
    
    print(&quot;💾 正在将文档添加到向量数据库...&quot;)
    
    # 添加文档到集合
    vectorstore.add_documents(docs)
    
    # 持久化保存
    vectorstore.persist()
    
    print(f&quot;✅ 成功添加 {len(docs)} 个文档到 web_content 集合&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;数据加载实战&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 使用示例
async def main():
    # URL列表
    url_list = [
        &quot;https://blog.csdn.net/qq_40081208/article/details/111053208&quot;,
        &quot;https://blog.csdn.net/Yyuan12345678/article/details/142108850&quot;,
        &quot;https://blog.csdn.net/WhiffeYF/article/details/111031270&quot;,
        &quot;https://www.hanspub.org/journal/PaperInformation?paperID=84081&quot;,
        &quot;https://blog.csdn.net/WhiffeYF/article/details/110829105&quot;
    ]
    
    if not url_list:
        print(&quot;⚠️ 请先在 url_list 中添加URL链接&quot;)
        return
    
    # 1. 并行加载所有URL
    all_docs = await load_urls_parallel(url_list)
    
    # 2. 添加到向量数据库
    add_docs_to_vectorstore(all_docs)
    
    # 3. 验证添加结果
    print(&quot;\n🔍 验证数据库内容:&quot;)
    collection_info = vectorstore._client.get_collection(&quot;web_content&quot;)
    print(f&quot;集合中的文档数量: {collection_info.count()}&quot;)
    
    # 4. 测试检索
    test_query = &quot;抓取检测&quot;
    results = vectorstore.similarity_search(test_query, k=2)
    print(f&quot;\n🔎 测试检索 &apos;{test_query}&apos;:&quot;)
    for i, doc in enumerate(results, 1):
        print(f&quot;   {i}. 来源: {doc.metadata.get(&apos;source&apos;, &apos;N/A&apos;)}&quot;)
        print(f&quot;      内容预览: {doc.page_content[:100]}...&quot;)

await main()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;✅ 创建 web_content 集合成功
🚀 开始并行加载 5 个URL...
✅ 成功加载: https://www.hanspub.org/journal/PaperInformation?paperID=84081 -&gt; 1 个文档
✅ 成功加载: https://blog.csdn.net/WhiffeYF/article/details/110829105 -&gt; 1 个文档
✅ 成功加载: https://blog.csdn.net/Yyuan12345678/article/details/142108850 -&gt; 1 个文档
✅ 成功加载: https://blog.csdn.net/WhiffeYF/article/details/111031270 -&gt; 1 个文档
✅ 成功加载: https://blog.csdn.net/qq_40081208/article/details/111053208 -&gt; 1 个文档
📊 总共加载了 5 个文档
💾 正在将文档添加到向量数据库...
✅ 成功添加 5 个文档到 web_content 集合

🔍 验证数据库内容:
集合中的文档数量: 5

🔎 测试检索 &apos;抓取检测&apos;:
   1. 来源: https://blog.csdn.net/qq_40081208/article/details/111053208
      内容预览: 抓取检测之Dex-Net 2.0_dexnet-CSDN博客...
   2. 来源: https://blog.csdn.net/WhiffeYF/article/details/111031270
      内容预览: 械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状...
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;Part 1: 逻辑路由 - Logical Routing&lt;/h2&gt;
&lt;h3&gt;1.1 核心概念&lt;/h3&gt;
&lt;p&gt;逻辑路由使用基于规则的方法来决定将查询发送到哪个数据源。它通过LLM理解查询内容，然后根据预定义的规则选择合适的数据源。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;工作原理&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询
    ↓
LLM分析查询意图
    ↓
匹配预定义的路由规则
    ↓
选择目标数据源
    ↓
执行检索
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.2 基础路由实现&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from vllm import SamplingParams

# 定义路由提示词（ChatML格式）
route_prompt_template = &quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
		你是一个路由助手，负责将用户查询发送到正确的数据源。
		
		可用的数据源:
		- langchain: 学术论文和技术文档，包含抓取检测、滑动检测相关的研究论文、技术文档
		- web_content: 网页内容，包含CSDN博客、技术教程、实践指南等网页文章
		
		请分析用户查询，返回最合适的数据源名称(只返回名称，不要其他内容)。&amp;#x3C;|im_end|&gt;
		&amp;#x3C;|im_start|&gt;user
		{question}&amp;#x3C;|im_end|&gt;
		&amp;#x3C;|im_start|&gt;assistant
		数据源：&quot;&quot;&quot;

def route_chain(question: str, llm) -&gt; str:
    &quot;&quot;&quot;基础路由链&quot;&quot;&quot;
    prompt = route_prompt_template.format(question=question)
    
    # 使用 vLLM 的正确调用方式
    sampling_params = SamplingParams(
        temperature=0.1,
        top_p=0.9,
        max_tokens=50,
        stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
    )
    
    # 生成响应
    outputs = llm.generate([prompt], sampling_params)
    
    # 提取结果
    if outputs and len(outputs) &gt; 0:
        output = outputs[0]
        if hasattr(output, &apos;outputs&apos;) and output.outputs:
            return output.outputs[0].text.strip()
    
    return &quot;web_content&quot;  # 默认返回

# 测试路由
question1 = &quot;如何在数据库中查询？&quot;
route1 = route_chain(question1, llm)
print(f&quot;Query: {question1}&quot;)
print(f&quot;Route: {route1}\n&quot;)

question2 = &quot;今天的天气怎么样？&quot;
route2 = route_chain(question2, llm)
print(f&quot;Query: {question2}&quot;)
print(f&quot;Route: {route2}\n&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  3.42it/s]
Query: 如何在数据库中查询？
Route: 数据库

Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00, 18.21it/s]
Query: 今天的天气怎么样？
Route: web_content
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.3 完整逻辑路由系统&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.vectorstores import Chroma
from typing import Dict, Any, List

class LogicalRouter:
    &quot;&quot;&quot;完整的逻辑路由器&quot;&quot;&quot;
    
    def __init__(self, llm, embeddings, persist_directory=&quot;./chroma_db&quot;):
        self.llm = llm
        self.embeddings = embeddings
        self.persist_directory = persist_directory
        self.routes = {}
        self.route_descriptions = {}
        
        # 初始化集合
        self._init_collections()
        print(&quot;✅ 逻辑路由器初始化完成&quot;)
    
    def _init_collections(self):
        &quot;&quot;&quot;初始化向量数据库集合&quot;&quot;&quot;
        # 1. langchain 集合（论文）
        self.langchain_collection = Chroma(
            collection_name=&quot;langchain&quot;,
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
        
        # 2. web_content 集合（网页内容）
        self.web_content_collection = Chroma(
            collection_name=&quot;web_content&quot;,
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
        
        # 自动添加路由
        self.add_route(&quot;langchain&quot;, self.langchain_collection.as_retriever(), 
                      &quot;学术论文和技术文档，包含抓取检测、滑动检测相关的研究论文、技术文档&quot;)
        self.add_route(&quot;web_content&quot;, self.web_content_collection.as_retriever(), 
                      &quot;网页内容，包含CSDN博客、技术教程、实践指南等网页文章&quot;)
    
    def add_route(self, name: str, retriever, description: str = None):
        &quot;&quot;&quot;添加路由&quot;&quot;&quot;
        self.routes[name] = retriever
        if description:
            self.route_descriptions[name] = description
    
    def _call_llm(self, prompt: str, max_tokens: int = 100) -&gt; str:
        &quot;&quot;&quot;调用LLM进行路由决策&quot;&quot;&quot;
        sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            max_tokens=max_tokens,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        
        if outputs and outputs[0].outputs:
            return outputs[0].outputs[0].text.strip()
        return &quot;&quot;
    
    def route(self, question: str) -&gt; tuple:
        &quot;&quot;&quot;执行路由决策&quot;&quot;&quot;
        # 构建路由提示词
        route_descriptions = self._get_route_descriptions()
        
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个路由助手，负责将用户查询发送到正确的数据源。
            
            可用的数据源：
            {route_descriptions}
            
            请分析用户查询，返回最合适的数据源名称（只返回名称，不要其他内容）。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            问题：{question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            数据源：&quot;&quot;&quot;
        
        # 调用LLM进行路由决策
        route_name = self._call_llm(prompt, max_tokens=50)
        route_name = route_name.strip().lower()
        
        print(f&quot;🤖 LLM路由决策: &apos;{route_name}&apos;&quot;)
        
        # 检查路由是否存在，支持模糊匹配
        if route_name not in self.routes:
            for name in self.routes.keys():
                if name in route_name or route_name in name:
                    route_name = name
                    print(f&quot;🔄 模糊匹配到: {route_name}&quot;)
                    break
            else:
                # 默认回退
                print(f&quot;⚠️ 未找到路由: {route_name}，使用默认路由: web_content&quot;)
                route_name = &quot;web_content&quot;
        
        return route_name, self.routes[route_name]
    
    def _get_route_descriptions(self) -&gt; str:
        &quot;&quot;&quot;获取路由描述&quot;&quot;&quot;
        descriptions = []
        for name, retriever in self.routes.items():
            desc = self.route_descriptions.get(name, f&apos;{name}数据源&apos;)
            descriptions.append(f&quot;- {name}: {desc}&quot;)
        return &quot;\n&quot;.join(descriptions)
    
    def query(self, question: str, k: int = 4) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;执行完整查询&quot;&quot;&quot;
        print(f&quot;🎯 查询: {question}&quot;)
        
        # 路由到正确的数据源
        route_name, retriever = self.route(question)
        print(f&quot;📍 路由到: {route_name}&quot;)
        
        # 执行检索
        docs = retriever.get_relevant_documents(question, k=k)
        print(f&quot;📚 检索到 {len(docs)} 个文档&quot;)
        
        return {
            &quot;route&quot;: route_name,
            &quot;documents&quot;: docs,
            &quot;question&quot;: question,
            &quot;collection_size&quot;: self._get_collection_size(route_name)
        }
    
    def _get_collection_size(self, collection_name: str) -&gt; int:
        &quot;&quot;&quot;获取集合大小&quot;&quot;&quot;
        try:
            if collection_name == &quot;langchain&quot;:
                collection = self.langchain_collection._client.get_collection(&quot;langchain&quot;)
            else:
                collection = self.web_content_collection._client.get_collection(&quot;web_content&quot;)
            return collection.count()
        except:
            return 0
    
    def get_collection_info(self) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;获取集合信息&quot;&quot;&quot;
        info = {}
        for name in [&quot;langchain&quot;, &quot;web_content&quot;]:
            try:
                if name == &quot;langchain&quot;:
                    collection = self.langchain_collection._client.get_collection(&quot;langchain&quot;)
                else:
                    collection = self.web_content_collection._client.get_collection(&quot;web_content&quot;)
                info[name] = {
                    &quot;document_count&quot;: collection.count(),
                    &quot;description&quot;: self.route_descriptions.get(name, &apos;N/A&apos;)
                }
            except Exception as e:
                info[name] = {&quot;error&quot;: str(e)}
        
        return info
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.4 逻辑路由实战测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 初始化路由器
router = LogicalRouter(llm, embeddings, persist_directory=&quot;./chroma_db&quot;)

# 查看集合信息
print(&quot;📊 集合信息:&quot;)
info = router.get_collection_info()
for name, data in info.items():
    if &quot;document_count&quot; in data:
        print(f&quot;   {name}: {data[&apos;document_count&apos;]} 个文档 - {data[&apos;description&apos;]}&quot;)

# 测试路由查询
print(&quot;\n&quot; + &quot;=&quot;*50)
result = router.query(&quot;网页文档中关于抓取检测的介绍&quot;)
print(f&quot;路由结果: {result[&apos;route&apos;]}&quot;)
print(f&quot;文档数量: {len(result[&apos;documents&apos;])}&quot;)

# 显示文档预览
for i, doc in enumerate(result[&apos;documents&apos;][:2], 1):
    source = doc.metadata.get(&apos;source&apos;, &apos;未知来源&apos;)
    print(f&quot;  {i}. 来源: {source}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📊 集合信息:
   langchain: 12721 个文档 - 学术论文和技术文档，包含抓取检测、滑动检测相关的研究论文、技术文档
   web_content: 5 个文档 - 网页内容，包含CSDN博客、技术教程、实践指南等网页文章

🎯 查询: 网页文档中关于抓取检测的介绍
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.87it/s]
🤖 LLM路由决策: &apos;web_content&apos;
📍 路由到: web_content
📚 检索到 4 个文档
路由结果: web_content
文档数量: 4
  1. 来源: https://blog.csdn.net/qq_40081208/article/details/111053208
  2. 来源: https://blog.csdn.net/WhiffeYF/article/details/110829105
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.5 带回退机制的路由优化&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class LogicalRouterWithFallback:
    &quot;&quot;&quot;带回退机制的逻辑路由器&quot;&quot;&quot;
    
    def __init__(self, llm, embeddings, fallback_route: str = &quot;web_content&quot;, persist_directory=&quot;./chroma_db&quot;):
        self.llm = llm
        self.embeddings = embeddings
        self.persist_directory = persist_directory
        self.fallback_route = fallback_route
        self.routes = {}
        self.route_descriptions = {}
        
        # 初始化集合
        self._init_collections()
        print(&quot;✅ 带回退机制的逻辑路由器初始化完成&quot;)
    
    def _init_collections(self):
        &quot;&quot;&quot;初始化集合（与之前相同）&quot;&quot;&quot;
        self.langchain_collection = Chroma(
            collection_name=&quot;langchain&quot;,
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
        
        self.web_content_collection = Chroma(
            collection_name=&quot;web_content&quot;,
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
        
        # 添加路由
        self.add_route(&quot;langchain&quot;, self.langchain_collection.as_retriever(), 
                      &quot;学术论文和技术文档&quot;)
        self.add_route(&quot;web_content&quot;, self.web_content_collection.as_retriever(), 
                      &quot;网页内容&quot;)
    
    def add_route(self, name: str, retriever, description: str = None):
        &quot;&quot;&quot;添加路由&quot;&quot;&quot;
        self.routes[name] = retriever
        if description:
            self.route_descriptions[name] = description
    
    def route(self, question: str) -&gt; tuple:
        &quot;&quot;&quot;执行路由，带回退机制&quot;&quot;&quot;
        try:
            # 构建路由提示词
            route_descriptions = self._get_route_descriptions()
            
            prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
                你是一个路由助手，负责将用户查询发送到正确的数据源。
                
                可用的数据源：
                {route_descriptions}
                
                请分析用户查询，返回最合适的数据源名称（只返回名称，不要其他内容）。&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;user
                问题：{question}&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;assistant
                数据源：&quot;&quot;&quot;
            
            # 调用LLM
            route_name = self._call_llm(prompt, max_tokens=50)
            route_name = route_name.strip().lower()
            
            print(f&quot;🤖 LLM路由决策: &apos;{route_name}&apos;&quot;)
            
            # 检查路由是否存在
            if route_name not in self.routes:
                for name in self.routes.keys():
                    if name in route_name or route_name in name:
                        route_name = name
                        print(f&quot;🔄 模糊匹配到: {route_name}&quot;)
                        break
                else:
                    raise ValueError(f&quot;未找到路由: {route_name}&quot;)
            
            return route_name, self.routes[route_name]
            
        except Exception as e:
            print(f&quot;⚠️ 路由失败: {e}, 使用回退路由: {self.fallback_route}&quot;)
            return self.fallback_route, self.routes[self.fallback_route]
    
    def _call_llm(self, prompt: str, max_tokens: int = 100) -&gt; str:
        &quot;&quot;&quot;调用LLM&quot;&quot;&quot;
        sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            max_tokens=max_tokens,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        
        if outputs and outputs[0].outputs:
            return outputs[0].outputs[0].text.strip()
        return &quot;&quot;
    
    def query_multiple(self, question: str, max_routes: int = 2, k: int = 3) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;查询多个数据源（带回退机制）&quot;&quot;&quot;
        print(f&quot;🎯 多数据源查询: {question}&quot;)
        
        # 获取主路由
        primary_route, primary_retriever = self.route(question)
        print(f&quot;📍 主路由: {primary_route}&quot;)
        
        # 获取主路由文档
        primary_docs = primary_retriever.get_relevant_documents(question, k=k)
        results = {
            primary_route: primary_docs
        }
        print(f&quot;📚 {primary_route}: {len(primary_docs)} 个文档&quot;)
        
        # 如果需要，添加回退路由
        if len(results) &amp;#x3C; max_routes and primary_route != self.fallback_route:
            print(f&quot;🔄 添加回退路由: {self.fallback_route}&quot;)
            fallback_docs = self.routes[self.fallback_route].get_relevant_documents(question, k=k)
            results[self.fallback_route] = fallback_docs
            print(f&quot;📚 {self.fallback_route}: {len(fallback_docs)} 个文档&quot;)
        
        return results
    
    def _get_route_descriptions(self) -&gt; str:
        &quot;&quot;&quot;获取路由描述&quot;&quot;&quot;
        descriptions = []
        for name, retriever in self.routes.items():
            desc = self.route_descriptions.get(name, f&apos;{name}数据源&apos;)
            descriptions.append(f&quot;- {name}: {desc}&quot;)
        return &quot;\n&quot;.join(descriptions)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.6 多数据源查询实战&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建带回退机制的路由器
router = LogicalRouterWithFallback(
    llm=llm,
    embeddings=embeddings,
    fallback_route=&quot;web_content&quot;,
    persist_directory=&quot;./chroma_db&quot;
)

# 多数据源查询示例
query = &quot;论文中抓取检测的定义&quot;
results = router.query_multiple(query, max_routes=2, k=3)

for route_name, docs in results.items():
    print(f&quot;\n{route_name}: {len(docs)} 个文档&quot;)
    for i, doc in enumerate(docs[:2], 1):
        source = doc.metadata.get(&apos;source&apos;, &apos;未知来源&apos;)
        print(f&quot;  {i}. 来源: {source}&quot;)
        clean_content = doc.page_content[:80].replace(&apos;\n&apos;, &apos; &apos;).replace(&apos;\r&apos;, &apos;&apos;).strip()
        print(f&quot;     内容: {clean_content}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🎯 多数据源查询: 论文中抓取检测的定义
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00, 12.62it/s]
🤖 LLM路由决策: &apos;langchain&apos;
📍 主路由: langchain
📚 langchain: 4 个文档
🔄 添加回退路由: web_content
📚 web_content: 4 个文档

langchain: 4 个文档
  1. 来源: ./Dataset/PDF/基于视触觉融合的机器人物体识别和抓取稳定性检测的研究与应用_上官明雨.pdf
     内容: 1.4 本论文组织结构  本文总体框架图如图 1.4 所示。本文将研究内容分为六个章节进行阐述：  第一章，绪论。本章主要介绍了课题的研究背景与研究意义，分析了...
  2. 来源: ./Dataset/PDF/基于视触感知协同的机器人抓取技术研究_祝会龙.pdf
     内容: 西南科技大学硕士学位论文  32    图 4-2 抓取过程   Fig.4-2 Grabbing Process   （1）抓手打开阶段：抓手在初始位置 将抓...

web_content: 4 个文档
  1. 来源: https://blog.csdn.net/WhiffeYF/article/details/111031270
     内容: 械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂抓取位姿估计与轨迹优化研究-CSDN博...
  2. 来源: https://blog.csdn.net/qq_40081208/article/details/111053208
     内容: 抓取检测之Dex-Net 2.0_dexnet-CSDN博客...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.7 逻辑路由的优缺点分析&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;✅ 优点&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;可预测和可控&lt;/strong&gt;：基于预定义规则，行为可预测&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;易于理解和调试&lt;/strong&gt;：规则明确，便于调试&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;适合确定性场景&lt;/strong&gt;：对明确分类的问题效果好&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;快速且高效&lt;/strong&gt;：规则匹配速度快&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;❌ 缺点&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;灵活性有限&lt;/strong&gt;：难以处理复杂或边界情况&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;需要预定义规则&lt;/strong&gt;：需要人工设计路由规则&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;难以处理边界情况&lt;/strong&gt;：模糊查询效果不佳&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;可能需要频繁更新规则&lt;/strong&gt;：随着数据源变化需要更新&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;1.8 性能优化技巧&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class OptimizedLogicalRouter(LogicalRouterWithFallback):
    &quot;&quot;&quot;优化版逻辑路由器&quot;&quot;&quot;
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.route_cache = {}  # 路由缓存
        self.cache_size = 1000  # 缓存大小
    
    def _get_cache_key(self, question: str) -&gt; str:
        &quot;&quot;&quot;生成缓存键&quot;&quot;&quot;
        import hashlib
        return hashlib.md5(question.encode()).hexdigest()
    
    def route(self, question: str) -&gt; tuple:
        &quot;&quot;&quot;带缓存的路由&quot;&quot;&quot;
        cache_key = self._get_cache_key(question)
        
        # 检查缓存
        if cache_key in self.route_cache:
            print(&quot;💾 使用缓存路由&quot;)
            return self.route_cache[cache_key]
        
        # 执行路由
        result = super().route(question)
        
        # 更新缓存
        if len(self.route_cache) &gt;= self.cache_size:
            # 简单LRU策略：移除第一个
            first_key = next(iter(self.route_cache))
            del self.route_cache[first_key]
        
        self.route_cache[cache_key] = result
        return result
    
    def batch_route(self, questions: List[str]) -&gt; List[tuple]:
        &quot;&quot;&quot;批量路由&quot;&quot;&quot;
        results = []
        for question in questions:
            result = self.route(question)
            results.append((question, result))
        return results
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;1.9 总结&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;逻辑路由的核心价值&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;智能数据源选择&lt;/strong&gt;：根据查询意图自动选择最相关的数据源&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;性能优化&lt;/strong&gt;：避免查询所有数据源，提高检索效率&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;结果质量提升&lt;/strong&gt;：确保查询在最适合的数据源中执行&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;系统可扩展性&lt;/strong&gt;：支持轻松添加新的数据源&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;🎯 使用建议&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;| 场景 | 推荐策略 | 说明 |
|------|---------|------|
| &lt;strong&gt;数据源明确&lt;/strong&gt; | 单一路由 | 查询目标明确时使用 |
| &lt;strong&gt;结果全面性&lt;/strong&gt; | 多数据源+回退 | 需要全面覆盖时使用 |
| &lt;strong&gt;性能敏感&lt;/strong&gt; | 带缓存路由 | 高并发场景使用 |
| &lt;strong&gt;数据源变化频繁&lt;/strong&gt; | 动态路由 | 数据源经常变化时使用 |&lt;/p&gt;
&lt;h3&gt;1.10 最佳实践&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;合理设置回退机制&lt;/strong&gt;：确保系统鲁棒性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;使用路由缓存&lt;/strong&gt;：提升高频查询性能&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;监控路由准确性&lt;/strong&gt;：定期评估和优化路由规则&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;支持人工干预&lt;/strong&gt;：提供手动指定数据源的能力&lt;/li&gt;
&lt;/ol&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;经验总结&lt;/strong&gt;：逻辑路由是构建多数据源RAG系统的基础，通过智能的数据源选择，显著提升了系统的检索效率和准确性。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;在下一部分，我们将探讨更高级的&lt;strong&gt;语义路由&lt;/strong&gt;和&lt;strong&gt;查询构建&lt;/strong&gt;技术，进一步提升复杂查询的处理能力。&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;Part 2: 语义路由 - Semantic Routing&lt;/h3&gt;
&lt;p&gt;在Part1中，我们成功构建了一个包含多种数据源的本地知识库。现在面临一个核心挑战：当用户提出不同性质的问题时，系统应该如何智能地判断从哪个数据源中寻找答案？这就是&lt;strong&gt;语义路由（Semantic Routing）&lt;/strong&gt; 要解决的关键问题。&lt;/p&gt;
&lt;h3&gt;2.1 核心概念：从&quot;关键词&quot;到&quot;语义理解&quot;&lt;/h3&gt;
&lt;p&gt;传统的路由方式可能依赖于关键词匹配（例如，问题中包含&quot;论文&quot;就路由到学术库）。但这种方式非常僵化，无法处理复杂语义。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;语义路由&lt;/strong&gt;的核心理念是使用文本的&lt;strong&gt;嵌入向量（Embeddings）&lt;/strong&gt; 来进行路由决策。它为每个数据源（路由）定义一段&lt;strong&gt;描述性文本&lt;/strong&gt;，并将这段描述也转换为嵌入向量。当用户查询到来时，系统会计算查询的嵌入向量与所有路由描述嵌入向量的&lt;strong&gt;语义相似度&lt;/strong&gt;，然后选择最相似的路由进行检索。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;工作原理可以简化为以下流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询 (例如：&quot;如何实现抓取检测？&quot;)
    ↓
计算查询的嵌入向量
    ↓
计算与各路由描述的相似度
    ↓
选择最相似的路由
    ↓
在该路由对应的向量库中执行检索
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.2 代码实现：构建语义路由器&lt;/h3&gt;
&lt;p&gt;以下是一个完整的&lt;code&gt;SemanticRouter&lt;/code&gt;类实现，它基于LangChain和ChromaDB，完美适配本地环境。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import numpy as np
from typing import Dict, List, Tuple
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

class SemanticRouter:
    &quot;&quot;&quot;语义路由器（适配本地配置）&quot;&quot;&quot;
    
    def __init__(self, embeddings, persist_directory=&quot;./chroma_db&quot;):
        self.embeddings = embeddings
        self.persist_directory = persist_directory
        self.routes = {}
        self.route_embeddings = {}
        
        # 初始化向量集合
        self._init_collections()
    
    def _init_collections(self):
        &quot;&quot;&quot;初始化向量数据库集合&quot;&quot;&quot;
        # 学术论文集合
        self.langchain_collection = Chroma(
            collection_name=&quot;langchain&quot;,
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
        
        # 网页内容集合
        self.web_content_collection = Chroma(
            collection_name=&quot;web_content&quot;,
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
    
    def add_route(self, name: str, description: str, retriever):
        &quot;&quot;&quot;添加路由规则&quot;&quot;&quot;
        self.routes[name] = {
            &apos;description&apos;: description,
            &apos;retriever&apos;: retriever
        }
        
        # 预计算路由描述的嵌入向量
        self.route_embeddings[name] = self.embeddings.embed_query(description)
        print(f&quot;✅ 添加路由: {name} - {description}&quot;)
    
    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -&gt; float:
        &quot;&quot;&quot;计算余弦相似度&quot;&quot;&quot;
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        return dot_product / (norm1 * norm2) if norm1 and norm2 else 0.0
    
    def route(self, question: str, threshold: float = 0.3) -&gt; Tuple[str, float]:
        &quot;&quot;&quot;执行语义路由&quot;&quot;&quot;
        query_embedding = self.embeddings.embed_query(question)
        
        similarities = {}
        for name, route_embedding in self.route_embeddings.items():
            similarity = self._cosine_similarity(query_embedding, route_embedding)
            similarities[name] = similarity
        
        best_route = max(similarities, key=similarities.get)
        best_score = similarities[best_route]
        
        print(f&quot;🔍 路由分析:&quot;)
        for route, score in sorted(similarities.items(), key=lambda x: x[1], reverse=True):
            print(f&quot;   {route}: {score:.3f}&quot;)
        
        return best_route, best_score
    
    def query(self, question: str, k: int = 4):
        &quot;&quot;&quot;完整查询流程&quot;&quot;&quot;
        print(f&quot;🎯 查询: {question}&quot;)
        
        # 语义路由
        route_name, score = self.route(question)
        print(f&quot;📍 路由到: {route_name} (相似度: {score:.3f})&quot;)
        
        # 检索
        retriever = self.routes[route_name][&apos;retriever&apos;]
        docs = retriever.get_relevant_documents(question, k=k)
        
        return {
            &quot;route&quot;: route_name,
            &quot;score&quot;: score,
            &quot;documents&quot;: docs
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.3 初始化与路由配置&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建语义路由器
router = SemanticRouter(embeddings, persist_directory=&quot;./chroma_db&quot;)

# 配置路由规则
router.add_route(
    name=&quot;langchain&quot;,
    description=&quot;学术论文、技术文档、研究论文、抓取检测和滑动检测相关技术&quot;,
    retriever=router.langchain_collection.as_retriever()
)

router.add_route(
    name=&quot;web_content&quot;, 
    description=&quot;网页内容、技术教程、实践指南、编程教程、CSDN博客、实际操作步骤&quot;,
    retriever=router.web_content_collection.as_retriever()
)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.4 实战测试与结果分析&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;测试1：技术算法类问题&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 测试偏向理论算法的问题
result = router.query(&quot;如何实现一个抓取检测算法？&quot;)

print(&quot;📄 检索结果预览:&quot;)
for i, doc in enumerate(result[&apos;documents&apos;][:2], 1):
    source = doc.metadata.get(&apos;source&apos;, &apos;未知来源&apos;)
    content_preview = doc.page_content[:80].replace(&apos;\n&apos;, &apos; &apos;)
    print(f&quot;  {i}. 来源: {source}&quot;)
    print(f&quot;     内容: {content_preview}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🎯 查询: 如何实现一个抓取检测算法？
🔍 路由分析:
   langchain: 0.501
   web_content: 0.270
📍 路由到: langchain (相似度: 0.501)
📄 检索结果预览:
  1. 来源: ./Dataset/PDF/基于视触感知协同的机器人抓取技术研究_祝会龙.pdf
     内容: 西南科技大学硕士学位论文 32 图 4-2 抓取过程 Fig.4-2 Grabbing Process （1）抓手打开阶段...
  2. 来源: ./Dataset/PDF/基于视触感知协同的机器人抓取技术研究_祝会龙.pdf
     内容: 的采集，使用 Savitzky-Golay 滤波算法进行数据滤波，并进行了测试。然后，研究了 TSF...
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;测试2：实践操作类问题&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 测试偏向实践操作的问题
result = router.query(&quot;如何写一个技术博客？&quot;)

print(&quot;📄 检索结果预览:&quot;)
for i, doc in enumerate(result[&apos;documents&apos;][:2], 1):
    source = doc.metadata.get(&apos;source&apos;, &apos;未知来源&apos;)
    content_preview = doc.page_content[:80].replace(&apos;\n&apos;, &apos; &apos;)
    print(f&quot;  {i}. 来源: {source}&quot;)
    print(f&quot;     内容: {content_preview}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🎯 查询: 如何写一个技术博客？
🔍 路由分析:
   web_content: 0.487
   langchain: 0.401
📍 路由到: web_content (相似度: 0.487)
📄 检索结果预览:
  1. 来源: 链接1
     内容: 机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional...
  2. 来源: 链接2  
     内容: 械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.5 语义路由的技术优势&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;智能语义理解&lt;/strong&gt;：系统能够理解&quot;抓取检测算法&quot;是研究主题（路由到学术库），而&quot;写技术博客&quot;是实践操作（路由到网页库）&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;灵活可扩展&lt;/strong&gt;：新增数据源只需添加路由描述，无需修改核心逻辑&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;决策透明化&lt;/strong&gt;：路由得分可视化，便于调试和优化描述文本&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;检索效率提升&lt;/strong&gt;：避免全库搜索，针对性检索提升响应速度&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;2.6 多路查询与结果融合&lt;/h2&gt;
&lt;p&gt;除了单一路由，系统还支持&lt;strong&gt;多路并行查询&lt;/strong&gt;，对于复杂问题可以从多个角度获取信息。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def query_multiple(self, question: str, k: int = 3):
    &quot;&quot;&quot;多数据源并行查询&quot;&quot;&quot;
    scores = self.route_with_scores(question)
    
    results = {}
    for route_name, score in scores.items():
        if score &gt; 0.2:  # 相似度阈值
            retriever = self.routes[route_name][&apos;retriever&apos;]
            docs = retriever.get_relevant_documents(question, k=k)
            results[route_name] = {
                &apos;documents&apos;: docs,
                &apos;score&apos;: score
            }
    
    return results

# 使用示例
multi_result = router.query_multiple(&quot;抓取检测的最新进展&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;2.7 混合路由：智能选择最佳检索策略&lt;/h3&gt;
&lt;p&gt;在前面的章节中，我们分别介绍了逻辑路由和语义路由。这两种路由策略各有优劣，为了在实际应用中达到最佳效果，我们需要一个能够&lt;strong&gt;智能选择路由策略&lt;/strong&gt;的系统，这就是&lt;strong&gt;混合路由（Hybrid Routing）&lt;/strong&gt;。&lt;/p&gt;
&lt;h4&gt;2.7.1 混合路由的核心思想&lt;/h4&gt;
&lt;p&gt;混合路由的核心是根据查询的&lt;strong&gt;复杂度&lt;/strong&gt;和&lt;strong&gt;特性&lt;/strong&gt;，动态选择最适合的路由策略：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;简单查询&lt;/strong&gt; → 使用&lt;strong&gt;语义路由&lt;/strong&gt;（速度快、成本低）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;复杂查询&lt;/strong&gt; → 使用&lt;strong&gt;逻辑路由&lt;/strong&gt;（准确性高、可解释性强）&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;工作流程示意：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询: &quot;如何学习抓取检测算法？&quot;
    ↓
自适应路由分析
    ↓
复杂度判断: simple (简单查询)
    ↓
选择语义路由策略
    ↓
语义路由流程开始
    ↓
计算查询嵌入
    ↓
计算与各路由描述的相似度
    ↓
选择最相似的路由: langchain (分数: 0.45)
    ↓
执行检索: 从langchain集合检索文档
    ↓
返回结果: 4个相关文档
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;2.7.2 混合路由器的实现&lt;/h4&gt;
&lt;p&gt;以下是&lt;code&gt;HybridRouter&lt;/code&gt;类的完整实现，它集成了前面实现的语义路由器和逻辑路由器：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from typing import Dict, List, Any
import numpy as np

class HybridRouter:
    &quot;&quot;&quot;混合路由器：结合逻辑和语义路由（使用现有实现）&quot;&quot;&quot;
    
    def __init__(self, embeddings, llm, persist_directory=&quot;./chroma_db&quot;):
        self.embeddings = embeddings
        self.llm = llm
        self.persist_directory = persist_directory
        self.routes = {}
        
        # 使用之前实现的语义路由器和逻辑路由器
        self.semantic_router = SemanticRouter(embeddings, persist_directory)
        self.logical_router = LogicalRouter(llm, embeddings, persist_directory)
    
    def add_route(self, name: str, description: str, retriever):
        &quot;&quot;&quot;添加路由到两个路由器&quot;&quot;&quot;
        # 添加到语义路由器
        self.semantic_router.add_route(name, description, retriever)
        # 添加到逻辑路由器
        self.logical_router.add_route(name, retriever)
        # 添加到本地路由表
        self.routes[name] = retriever
        print(f&quot;✅ 添加混合路由: {name} - {description}&quot;)
    
    def route(self, question: str, use_semantic: bool = True, semantic_threshold: float = 0.3):
        &quot;&quot;&quot;执行混合路由
        
        Args:
            question: 用户查询
            use_semantic: 是否优先使用语义路由
            semantic_threshold: 语义路由阈值（本地模型阈值较低）
        &quot;&quot;&quot;
        print(f&quot;🎯 混合路由查询: {question}&quot;)
        
        if use_semantic:
            # 尝试语义路由
            try:
                route_name, score = self.semantic_router.route(question, threshold=semantic_threshold)
                
                if score &gt;= semantic_threshold:
                    print(f&quot;✅ 使用语义路由: {route_name} (分数: {score:.3f})&quot;)
                    return route_name, self.routes[route_name]
                else:
                    print(f&quot;⚠️ 语义路由分数过低 ({score:.3f} &amp;#x3C; {semantic_threshold})，切换到逻辑路由&quot;)
            except Exception as e:
                print(f&quot;⚠️ 语义路由失败: {e}，切换到逻辑路由&quot;)
        
        # 使用逻辑路由作为后备
        try:
            route_name, retriever = self.logical_router.route(question)
            print(f&quot;✅ 使用逻辑路由: {route_name}&quot;)
            return route_name, retriever
        except Exception as e:
            print(f&quot;❌ 逻辑路由失败: {e}，使用默认路由&quot;)
            # 回退到web_content
            return &quot;web_content&quot;, self.routes[&quot;web_content&quot;]
    
    def query(self, question: str, use_semantic: bool = True, k: int = 4):
        &quot;&quot;&quot;执行查询&quot;&quot;&quot;
        route_name, retriever = self.route(question, use_semantic)
        docs = retriever.get_relevant_documents(question, k=k)
        
        print(f&quot;📚 检索到 {len(docs)} 个文档&quot;)
        
        return {
            &quot;route&quot;: route_name,
            &quot;documents&quot;: docs,
            &quot;question&quot;: question
        }
    
    def query_adaptive(self, question: str, k: int = 4):
        &quot;&quot;&quot;自适应查询：根据查询复杂度选择最佳路由&quot;&quot;&quot;
        print(f&quot;🎯 自适应查询: {question}&quot;)
        
        # 分析查询复杂度
        complexity = self._analyze_complexity(question)
        print(f&quot;📊 查询复杂度: {complexity}&quot;)
        
        if complexity == &quot;simple&quot;:
            # 简单查询：使用语义路由
            return self.query(question, use_semantic=True, k=k)
        else:
            # 复杂查询：使用逻辑路由
            return self.query(question, use_semantic=False, k=k)
    
    def _analyze_complexity(self, question: str) -&gt; str:
        &quot;&quot;&quot;分析查询复杂度&quot;&quot;&quot;
        # 简单规则：根据查询长度和关键词判断
        if len(question) &amp;#x3C; 20 and any(keyword in question.lower() for keyword in [&quot;是什么&quot;, &quot;怎么用&quot;, &quot;如何&quot;, &quot;教程&quot;]):
            return &quot;simple&quot;
        else:
            return &quot;complex&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;2.7.3 初始化混合路由器&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建混合路由器
hybrid_router = HybridRouter(embeddings, llm, persist_directory=&quot;./chroma_db&quot;)

# 添加两个路由
hybrid_router.add_route(
    name=&quot;langchain&quot;,
    description=&quot;学术论文、技术文档、研究论文、抓取检测和滑动检测相关技术&quot;,
    retriever=hybrid_router.semantic_router.langchain_collection.as_retriever()
)

hybrid_router.add_route(
    name=&quot;web_content&quot;, 
    description=&quot;网页内容、技术教程、实践指南、编程教程、CSDN博客、实际操作步骤&quot;,
    retriever=hybrid_router.semantic_router.web_content_collection.as_retriever()
)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出示例：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;✅ 添加路由: langchain - 学术论文、技术文档、研究论文、抓取检测和滑动检测相关技术
✅ 添加混合路由: langchain - 学术论文、技术文档、研究论文、抓取检测和滑动检测相关技术
✅ 添加路由: web_content - 网页内容、技术教程、实践指南、编程教程、CSDN博客、实际操作步骤
✅ 添加混合路由: web_content - 网页内容、技术教程、实践指南、编程教程、CSDN博客、实际操作步骤
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;2.7.4 自适应路由测试&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;测试1：简单查询（自动选择语义路由）&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;query = &quot;如何学习抓取检测算法？&quot;
print(&quot;\n🎯 自适应路由测试:&quot;)
result = hybrid_router.query_adaptive(query)
print(f&quot;   路由结果: {result[&apos;route&apos;]}&quot;)
print(f&quot;   文档数量: {len(result[&apos;documents&apos;])}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🎯 自适应查询: 如何学习抓取检测算法？
📊 查询复杂度: simple
🎯 混合路由查询: 如何学习抓取检测算法？
🔍 路由分析:
   langchain: 0.497
   web_content: 0.305
✅ 使用语义路由: langchain (分数: 0.497)
📚 检索到 4 个文档
   路由结果: langchain
   文档数量: 4
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;测试2：复杂查询（自动选择逻辑路由）&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;complex_query = &quot;请详细分析抓取检测算法在工业机器人中的应用场景、技术挑战和未来发展趋势&quot;
print(&quot;\n🎯 复杂查询自适应路由测试:&quot;)
result = hybrid_router.query_adaptive(complex_query)
print(f&quot;   路由结果: {result[&apos;route&apos;]}&quot;)
print(f&quot;   文档数量: {len(result[&apos;documents&apos;])}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🎯 自适应查询: 请详细分析抓取检测算法在工业机器人中的应用场景、技术挑战和未来发展趋势
📊 查询复杂度: complex
🎯 混合路由查询: 请详细分析抓取检测算法在工业机器人中的应用场景、技术挑战和未来发展趋势
🤖 LLM分析: 这是一个复杂的综合分析请求，涉及应用场景、技术挑战和未来趋势...
✅ 使用逻辑路由: langchain
📚 检索到 4 个文档
   路由结果: langchain
   文档数量: 4
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;2.7.5 路由策略对比分析&lt;/h4&gt;
&lt;p&gt;下表详细对比了两种路由策略的特性：&lt;/p&gt;
&lt;h2&gt;| 特性 | 逻辑路由 | 语义路由 |
|------|----------|----------|
| &lt;strong&gt;决策依据&lt;/strong&gt; | 规则/LLM分类 | 嵌入相似度 |
| &lt;strong&gt;灵活性&lt;/strong&gt; | 中等 | 高 |
| &lt;strong&gt;准确性&lt;/strong&gt; | 高（规则明确时） | 中高 |
| &lt;strong&gt;速度&lt;/strong&gt; | 快 | 很快 ⚡ |
| &lt;strong&gt;成本&lt;/strong&gt; | 需LLM调用 | 仅需嵌入 |
| &lt;strong&gt;可解释性&lt;/strong&gt; | 高 | 中 |
| &lt;strong&gt;适用场景&lt;/strong&gt; | 复杂查询、多条件查询 | 简单查询、相似性查询 |
| &lt;strong&gt;维护成本&lt;/strong&gt; | 中（需维护规则） | 低 |&lt;/h2&gt;
&lt;h2&gt;Part 3: 查询构建 - Query Construction&lt;/h2&gt;
&lt;p&gt;在前面，我们构建了智能的混合路由系统，能够根据查询特性选择最佳的数据源。现在，我们将深入探讨如何让检索系统&lt;strong&gt;理解更复杂的查询意图&lt;/strong&gt;，这就是&lt;strong&gt;查询构建（Query Construction）&lt;/strong&gt; 要解决的核心问题。&lt;/p&gt;
&lt;h3&gt;3.1 为什么需要查询构建？&lt;/h3&gt;
&lt;p&gt;在实际应用中，用户的查询往往不仅包含对内容本身的语义描述，还包含对文档属性的明确要求。让我们通过一个具体例子来理解：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;场景：带元数据的文档检索&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;假设我们的文档库包含丰富的元数据：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;document = {
    &quot;content&quot;: &quot;深度学习入门教程：详细讲解了神经网络的基础概念...&quot;,
    &quot;metadata&quot;: {
        &quot;author&quot;: &quot;张三&quot;,
        &quot;date&quot;: &quot;2023-06-15&quot;, 
        &quot;category&quot;: &quot;机器学习&quot;,
        &quot;tags&quot;: [&quot;深度学习&quot;, &quot;神经网络&quot;],
        &quot;views&quot;: 1500
    }
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;用户提出复杂查询：&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;&quot;找出张三在2023年写的关于深度学习的文章&quot;&lt;/strong&gt;&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;这个查询包含两种需求：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;语义搜索需求&lt;/strong&gt;：内容需要关于 &lt;code&gt;&quot;深度学习&quot;&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;结构化过滤需求&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;author&lt;/code&gt; 等于 &lt;code&gt;&quot;张三&quot;&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;code&gt;date&lt;/code&gt; 在 &lt;code&gt;&quot;2023-01-01&quot;&lt;/code&gt; 到 &lt;code&gt;&quot;2023-12-31&quot;&lt;/code&gt; 之间&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;查询构建的作用&lt;/strong&gt;就是自动解析这种复杂意图，生成结合语义搜索和精确过滤的复合查询。&lt;/p&gt;
&lt;h3&gt;3.2 自查询检索器原理&lt;/h3&gt;
&lt;p&gt;自查询检索器（Self-Query Retriever）使用LLM来解析自然语言查询，将其转换为结构化查询条件：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询 → LLM解析 → 过滤条件 + 搜索词 → 向量数据库检索
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.3 实现自定义自查询检索器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.chains.query_constructor.base import AttributeInfo
from vllm import SamplingParams
import re

class CustomSelfQueryRetriever:
    &quot;&quot;&quot;自定义自查询检索器（适配vLLM配置）&quot;&quot;&quot;
    
    def __init__(self, llm, vectorstore, metadata_field_info, document_content_description):
        self.llm = llm
        self.vectorstore = vectorstore
        self.metadata_field_info = metadata_field_info
        self.document_content_description = document_content_description
    
    def _call_llm(self, prompt: str, max_tokens: int = 200) -&gt; str:
        &quot;&quot;&quot;调用vLLM模型进行查询解析&quot;&quot;&quot;
        sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            max_tokens=max_tokens,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs and outputs[0].outputs:
            return outputs[0].outputs[0].text.strip()
        return &quot;&quot;
    
    def _parse_filter_query(self, query: str) -&gt; dict:
        &quot;&quot;&quot;解析自然语言查询，生成过滤条件&quot;&quot;&quot;
        # 构建元数据字段描述
        metadata_fields = &quot;\n&quot;.join([
            f&quot;- {field.name}: {field.description} (类型: {field.type})&quot; 
            for field in self.metadata_field_info
        ])
        
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
					你是一个查询解析器，负责将自然语言查询转换为结构化过滤条件。
					
					可用的元数据字段：
					{metadata_fields}
					
					文档内容描述：{self.document_content_description}
					
					请将用户查询解析为过滤条件，格式为：
					filter_type:field_name:value
					
					支持的过滤操作：
					- eq: 等于
					- gt: 大于  
					- lt: 小于
					- contains: 包含（用于内容语义搜索）
					
					示例：
					查询: &quot;找出2023年关于Python的文章&quot;
					解析: eq:date:2023;contains:content:Python
					
					查询: &quot;阅读量超过1000的技术文章&quot;  
					解析: gt:views:1000;contains:content:技术
					
					只返回解析后的过滤条件，不要其他内容。&amp;#x3C;|im_end|&gt;
					&amp;#x3C;|im_start|&gt;user
					查询：{query}&amp;#x3C;|im_end|&gt;
					&amp;#x3C;|im_start|&gt;assistant
					解析：&quot;&quot;&quot;
        
        response = self._call_llm(prompt)
        print(f&quot;🤖 LLM解析结果: {response}&quot;)
        return self._parse_filter_response(response)
    
    def _parse_filter_response(self, response: str) -&gt; dict:
        &quot;&quot;&quot;解析LLM返回的过滤条件&quot;&quot;&quot;
        filters = {}
        
        # 解析格式: eq:date:2023;contains:content:Python
        filter_parts = response.split(&apos;;&apos;)
        
        for part in filter_parts:
            part = part.strip()
            if part.count(&apos;:&apos;) &gt;= 2:
                try:
                    # 拆分成 [操作符, 字段名, 值]
                    op, field, value = part.split(&apos;:&apos;, 2)
                    if field not in filters:
                        filters[field] = {}
                    filters[field][op] = value
                except ValueError:
                    continue
        
        return filters
    
    def get_relevant_documents(self, query: str, k: int = 4):
        &quot;&quot;&quot;获取相关文档（核心方法）&quot;&quot;&quot;
        print(f&quot;🎯 自查询: {query}&quot;)
        
        try:
            # 1. 解析查询生成过滤条件
            filters = self._parse_filter_query(query)
            print(f&quot;🔍 解析的过滤条件: {filters}&quot;)
            
            # 2. 构建过滤条件
            where_clauses = []
            search_query = query  # 默认搜索词
            
            for field, conditions in filters.items():
                for op, value in conditions.items():
                    if op == &apos;eq&apos;:
                        where_clauses.append({field: {&quot;$eq&quot;: value}})
                    elif op == &apos;gt&apos;:
                        where_clauses.append({field: {&quot;$gt&quot;: int(value)}})
                    elif op == &apos;lt&apos;:
                        where_clauses.append({field: {&quot;$lt&quot;: int(value)}})
                    elif op == &apos;contains&apos; and field == &apos;content&apos;:
                        # 使用解析出的内容作为搜索词
                        search_query = value
            
            # 3. 执行检索
            if where_clauses:
                # 组合过滤条件
                if len(where_clauses) == 1:
                    combined_filter = where_clauses[0]
                else:
                    combined_filter = {&quot;$and&quot;: where_clauses}
                
                docs = self.vectorstore.similarity_search(
                    search_query, 
                    k=k, 
                    filter=combined_filter
                )
            else:
                # 无过滤条件，直接搜索
                docs = self.vectorstore.similarity_search(search_query, k=k)
            
            return docs
            
        except Exception as e:
            print(f&quot;❌ 自查询失败: {e}，使用普通检索&quot;)
            # 回退到普通检索
            return self.vectorstore.similarity_search(query, k=k)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.4 配置与初始化&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 定义元数据字段信息
metadata_field_info = [
    AttributeInfo(
        name=&quot;author&quot;,
        description=&quot;文档作者&quot;,
        type=&quot;string&quot;
    ),
    AttributeInfo(
        name=&quot;date&quot;, 
        description=&quot;发布日期，格式: YYYY-MM-DD&quot;,
        type=&quot;string&quot;
    ),
    AttributeInfo(
        name=&quot;category&quot;,
        description=&quot;文档类别，如: Python, 机器学习, Web开发等&quot;,
        type=&quot;string&quot;
    ),
    AttributeInfo(
        name=&quot;views&quot;,
        description=&quot;浏览次数&quot;,
        type=&quot;integer&quot;
    ),
    AttributeInfo(
        name=&quot;source&quot;,
        description=&quot;文档来源&quot;,
        type=&quot;string&quot;
    ),
]

document_content_description = &quot;技术文章、教程和研究论文&quot;

# 创建自定义自查询检索器
self_query_retriever = CustomSelfQueryRetriever(
    llm=llm,
    vectorstore=vectorstore,  # 您的ChromaDB实例
    metadata_field_info=metadata_field_info,
    document_content_description=document_content_description
)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.5 实战测试&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;测试1：结合数值过滤的查询&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;query = &quot;找出点赞超过10次关于抓取检测的文章&quot;
print(f&quot;\n🔍 测试查询: {query}&quot;)
results = self_query_retriever.get_relevant_documents(query, k=3)

print(f&quot;✅ 找到 {len(results)} 个文档&quot;)

for i, doc in enumerate(results, 1):
    # 清理内容显示
    clean_content = &apos; &apos;.join(doc.page_content.replace(&apos;\n&apos;, &apos; &apos;).split())[:80]
    print(f&quot;\n{i}. 内容预览: {clean_content}...&quot;)
    print(f&quot;   元数据: { {k: v for k, v in doc.metadata.items() if k in [&apos;author&apos;, &apos;views&apos;, &apos;source&apos;]} }&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🔍 测试查询: 找出点赞超过10次关于抓取检测的文章
🎯 自查询: 找出点赞超过10次关于抓取检测的文章
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  3.69it/s]
🤖 LLM解析结果: gt:views:10;contains:content:抓取检测
🔍 解析的过滤条件: {&apos;views&apos;: {&apos;gt&apos;: &apos;10&apos;}, &apos;content&apos;: {&apos;contains&apos;: &apos;抓取检测&apos;}}
✅ 找到 3 个文档

1. 内容预览: 机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks...
   元数据: {&apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/110829105&apos;, &apos;views&apos;: 15}

2. 内容预览: 械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂抓取位姿估计...
   元数据: {&apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/111031270&apos;, &apos;views&apos;: 11}
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.6 查询构建的技术优势&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;自然语言理解&lt;/strong&gt;：用户可以用最自然的方式表达复杂查询需求&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;精确过滤&lt;/strong&gt;：结合元数据过滤，大幅提升检索准确性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;灵活组合&lt;/strong&gt;：支持多种条件组合（与、或、范围等）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;错误恢复&lt;/strong&gt;：解析失败时自动降级到普通检索&lt;/li&gt;
&lt;/ol&gt;
&lt;h3&gt;3.7 应用场景与局限性&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;适用场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;电商产品搜索（价格范围、品牌、类别）&lt;/li&gt;
&lt;li&gt;文献检索（作者、年份、期刊）&lt;/li&gt;
&lt;li&gt;内容管理（标签、状态、日期）&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;当前局限性：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;LLM解析可能存在误差&lt;/li&gt;
&lt;li&gt;复杂逻辑（OR条件）支持有限&lt;/li&gt;
&lt;li&gt;需要定义清晰的元数据schema&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;h2&gt;Part 4 自查询检索器 - Self-Query Retriever&lt;/h2&gt;
&lt;p&gt;在前面的章节中，我们探讨了逻辑路由、语义路由以及查询构建器，它们都需要我们显式地定义路由规则或查询结构。然而，在更智能的应用中，我们期望系统能自动理解用户的自然语言查询意图，并将其分解为适合检索的组件。这就是自查询检索器的用武之地。&lt;/p&gt;
&lt;h3&gt;4.1 核心概念&lt;/h3&gt;
&lt;p&gt;自查询检索器是LangChain提供的高级工具，它能够自动将自然语言查询分离为语义搜索部分和结构化过滤部分。其工作流程如下：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;自查询处理流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户自然语言查询
    ↓
LLM分析查询意图
    ↓
分离为两部分：
├─ 语义查询内容（用于向量搜索）
└─ 元数据过滤条件（用于结构化过滤）
    ↓
执行混合检索
    ↓
返回精确定位的结果
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.2 实现原理&lt;/h3&gt;
&lt;p&gt;自查询检索器的核心在于利用大语言模型（LLM）的语义理解能力，自动解析用户查询中的隐含过滤条件。以下是简化版实现：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.chains.query_constructor.base import (
    StructuredQueryOutputParser,
    get_query_constructor_prompt,
)
from langchain_core.output_parsers import StrOutputParser
from vllm import SamplingParams
import re

class SimpleSelfQueryRetriever:
    &quot;&quot;&quot;简化版自查询检索器&quot;&quot;&quot;
    
    def __init__(self, llm, vectorstore, metadata_field_info, document_content_description):
        self.llm = llm
        self.vectorstore = vectorstore
        self.metadata_field_info = metadata_field_info
        self.document_content_description = document_content_description
    
    def _call_llm(self, prompt: str, max_tokens: int = 300) -&gt; str:
        &quot;&quot;&quot;调用vLLM模型&quot;&quot;&quot;
        sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            max_tokens=max_tokens,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs and outputs[0].outputs:
            return outputs[0].outputs[0].text.strip()
        return &quot;&quot;
    
    def get_relevant_documents(self, query: str, k: int = 4):
        &quot;&quot;&quot;简化版自查询&quot;&quot;&quot;
        print(f&quot;🎯 查询: {query}&quot;)
        
        # 构建自查询提示词
        metadata_info = &quot;\n&quot;.join([f&quot;- {field.name}: {field.description} (类型: {field.type})&quot; 
                                 for field in self.metadata_field_info])
        
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个查询分析助手。请分析用户查询并提取：
            
            1. 搜索关键词（用于语义搜索）
            2. 过滤条件（基于元数据）
            
            可用的元数据字段：
            {metadata_info}
            
            文档类型：{self.document_content_description}
            
            返回格式：
            关键词：[搜索关键词]
            过滤：[字段名 操作符 值] 或多个条件用逗号分隔
            
            示例：
            查询：&quot;阅读量超过100的技术文章&quot;  
            返回：
            关键词：技术文章
            过滤：views &gt; 100

            查询：&quot;语言为中文的技术文章&quot;  
            返回：
            关键词：技术文章
            过滤：language == zh-CN
            
            只返回格式化的结果，不要解释。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            查询：{query}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        try:
            # 调用LLM解析查询
            response = self._call_llm(prompt, max_tokens=200)
            print(f&quot;🤖 LLM解析结果: {response}&quot;)
            
            # 解析响应
            search_keyword, filter_conditions = self._parse_llm_response(response, query)
            print(f&quot;🔍 搜索关键词: {search_keyword}&quot;)
            print(f&quot;📋 过滤条件: {filter_conditions}&quot;)
            
            # 构建过滤条件
            filter_dict = self._build_filter_dict(filter_conditions)
            
            # 执行检索
            if filter_dict:
                docs = self.vectorstore.similarity_search(
                    search_keyword, 
                    k=k, 
                    filter=filter_dict
                )
            else:
                docs = self.vectorstore.similarity_search(search_keyword, k=k)
            
            print(f&quot;✅ 找到 {len(docs)} 个文档&quot;)
            return docs
            
        except Exception as e:
            print(f&quot;❌ 自查询失败: {e}，使用普通检索&quot;)
            return self.vectorstore.similarity_search(query, k=k)
    
    def _parse_llm_response(self, response: str, original_query: str):
        &quot;&quot;&quot;解析LLM响应&quot;&quot;&quot;
        # 默认值
        search_keyword = original_query
        filter_conditions = []
        
        lines = response.split(&apos;\n&apos;)
        for line in lines:
            if line.startswith(&apos;关键词：&apos;):
                search_keyword = line.replace(&apos;关键词：&apos;, &apos;&apos;).strip()
            elif line.startswith(&apos;过滤：&apos;):
                filters = line.replace(&apos;过滤：&apos;, &apos;&apos;).strip()
                if filters and filters != &apos;无&apos;:
                    filter_conditions = [f.strip() for f in filters.split(&apos;,&apos;)]
        
        return search_keyword, filter_conditions
    
    def _build_filter_dict(self, filter_conditions):
        &quot;&quot;&quot;构建过滤字典&quot;&quot;&quot;
        if not filter_conditions:
            return None
        
        filter_dict = {}
        
        for condition in filter_conditions:
            parts = condition.split()
            if len(parts) &gt;= 3:
                field, op, value = parts[0], parts[1], &apos; &apos;.join(parts[2:])
                
                # 处理值类型
                if value.isdigit():
                    value = int(value)
                
                op_map = {
                    &apos;&gt;&apos;: &apos;$gt&apos;, &apos;&amp;#x3C;&apos;: &apos;$lt&apos;, &apos;&gt;=&apos;: &apos;$gte&apos;, &apos;&amp;#x3C;=&apos;: &apos;$lte&apos;,
                    &apos;=&apos;: &apos;$eq&apos;, &apos;contains&apos;: &apos;$regex&apos;
                }
                
                if op in op_map:
                    if op == &apos;contains&apos;:
                        filter_dict[field] = {op_map[op]: f&quot;.*{value}.*&quot;}
                    else:
                        filter_dict[field] = {op_map[op]: value}
        
        return filter_dict if filter_dict else None

&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.3 实际应用演示&lt;/h3&gt;
&lt;p&gt;以下是自查询检索器的实际测试结果：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 简化版自查询检索器测试
def demo_simple_self_query():
    &quot;&quot;&quot;演示简化版自查询检索器&quot;&quot;&quot;
    
    print(&quot;🚀 简化版自查询检索器测试&quot;)
    print(&quot;=&quot; * 60)
    
    # 4. 定义元数据
    metadata_field_info = [
        AttributeInfo(
            name=&quot;author&quot;,
            description=&quot;文档作者&quot;,
            type=&quot;string&quot;
        ),
        AttributeInfo(
            name=&quot;date&quot;, 
            description=&quot;发布日期，格式: YYYY-MM-DD&quot;,
            type=&quot;string&quot;
        ),
        AttributeInfo(
            name=&quot;category&quot;,
            description=&quot;文档类别，如: Python, 机器学习, Web开发等&quot;,
            type=&quot;string&quot; 
        ),
        AttributeInfo(
            name=&quot;views&quot;,
            description=&quot;浏览次数&quot;,
            type=&quot;integer&quot;
        ),
        AttributeInfo(
            name=&quot;language&quot;,
            description=&quot;语言&quot;,
            type=&quot;string&quot;
        ),
    ]
    
    document_content_description = &quot;技术文章和教程&quot;
    
    # 5. 直接创建简化版检索器
    print(&quot;📝 创建简化版自查询检索器...&quot;)
    retriever = SimpleSelfQueryRetriever(
        llm=llm,
        vectorstore=vectorstore, 
        metadata_field_info=metadata_field_info,
        document_content_description=document_content_description
    )
    
    print(&quot;✅ 简化版检索器创建成功！&quot;)
    
    # 6. 测试不同类型查询
    test_queries = [
        # 时间过滤查询
        {
            &quot;query&quot;: &quot;在CSDN上发布的博客&quot;,
            &quot;description&quot;: &quot;类别过滤&quot;
        },
        # 数值过滤查询
        {
            &quot;query&quot;: &quot;收藏超过100的技术文章&quot;,
            &quot;description&quot;: &quot;数值范围过滤&quot;
        },
        {
            &quot;query&quot;: &quot;语言为中文的技术文章&quot;,
            &quot;description&quot;: &quot;数值范围过滤&quot;
        }
    ]
    
    print(f&quot;\n🧪 开始测试 {len(test_queries)} 个查询...&quot;)
    
    for i, test_case in enumerate(test_queries, 1):
        query = test_case[&quot;query&quot;]
        description = test_case[&quot;description&quot;]
        
        print(f&quot;\n&quot; + &quot;=&quot;*60)
        print(f&quot;📋 测试 {i}/{len(test_queries)}: {description}&quot;)
        print(f&quot;🔍 查询: {query}&quot;)
        
        # 执行查询
        results = retriever.get_relevant_documents(query, k=2)
        
        # 显示结果
        if results:
            print(f&quot;✅ 找到 {len(results)} 个相关文档&quot;)
            for j, doc in enumerate(results, 1):
                # 清理输出
                preview = &apos; &apos;.join(doc.page_content.split())[:80]
                print(f&quot;   {j}. 标题: {preview}...&quot;)
                print(f&quot;      元数据: {doc.metadata}&quot;)
        else:
            print(&quot;❌ 未找到相关文档&quot;)
    
    return retriever


# 运行测试
if __name__ == &quot;__main__&quot;:
    # 测试简化版自查询检索器
    retriever = demo_simple_self_query()
    
    print(&quot;\n&quot; + &quot;=&quot;*60)
    print(&quot;🎉 简化版自查询检索器测试完成！&quot;)
    print(&quot;=&quot;*60)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;测试输出示例：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 简化版自查询检索器测试
============================================================
📝 创建简化版自查询检索器...
✅ 简化版检索器创建成功！

🧪 开始测试 3 个查询...

============================================================
📋 测试 1/3: 类别过滤
🔍 查询: 在CSDN上发布的博客
🎯 查询: 在CSDN上发布的博客
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.94it/s]
🤖 LLM解析结果: 关键词：CSDN
             过滤：author == &quot;CSDN&quot;
🔍 搜索关键词: CSDN
📋 过滤条件: []
✅ 找到 2 个文档
✅ 找到 2 个相关文档
   1. 标题: 抓取检测之Dex-Net 2.0_dexnet-CSDN博客 抓取检测之Dex-Net 2.0 最新推荐文章于 2025-10-11 13:29:58 发布 原...
      元数据: {&apos;title&apos;: &apos;抓取检测之Dex-Net 2.0_dexnet-CSDN博客&apos;, &apos;language&apos;: &apos;zh-CN&apos;, &apos;source&apos;: &apos;https://blog.csdn.net/qq_40081208/article/details/111053208?ops_request_misc=&amp;#x26;request_id=&amp;#x26;biz_id=102&amp;#x26;utm_term=%E6%8A%93%E5%8F%96%E6%A3%80%E6%B5%8B&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-1-111053208.142^v102^pc_search_result_base7&amp;#x26;spm=1018.2226.3001.4187&apos;, &apos;description&apos;: &apos;文章浏览阅读7.3k次，点赞13次，收藏72次。Dex-Net2.0是一种先进的机器人抓取算法，通过两阶段流程实现高效抓取：首先从深度图中采样抓取候选，接着评估抓取质量，最终选出最佳抓取配置。该算法引入了包含670万个样本的大规模数据集，并采用深度学习技术进行抓取质量评估。&apos;}
   2. 标题: 机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks...
      元数据: {&apos;description&apos;: &apos;文章浏览阅读6.2k次，点赞10次，收藏97次。本文提出了一种基于深度学习的实时机器人抓取检测方法，利用卷积神经网络（CNN）进行单阶段回归预测，避免了传统滑动窗口方法的计算成本。模型在康奈尔抓取数据集上实现了88%的准确率，并能在GPU上以13帧/秒的速度运行。相比于现有技术，该模型提升了14个百分点的精度，并且能同时进行物体分类和抓取预测。此外，模型的多抓取版本可以预测单个物体的多个抓取点，显著提高了对多样抓取方式物体的检测性能。&apos;, &apos;title&apos;: &apos;机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks-CSDN博客&apos;, &apos;language&apos;: &apos;zh-CN&apos;, &apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/110829105&apos;}

============================================================
📋 测试 2/3: 数值范围过滤
🔍 查询: 收藏超过100的技术文章
🎯 查询: 收藏超过100的技术文章
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  3.47it/s]
🤖 LLM解析结果: 关键词：技术文章
             过滤：views &gt; 100
🔍 搜索关键词: 技术文章
📋 过滤条件: []
✅ 找到 2 个文档
✅ 找到 2 个相关文档
   1. 标题: 机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks...
      元数据: {&apos;language&apos;: &apos;zh-CN&apos;, &apos;title&apos;: &apos;机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks-CSDN博客&apos;, &apos;description&apos;: &apos;文章浏览阅读6.2k次，点赞10次，收藏97次。本文提出了一种基于深度学习的实时机器人抓取检测方法，利用卷积神经网络（CNN）进行单阶段回归预测，避免了传统滑动窗口方法的计算成本。模型在康奈尔抓取数据集上实现了88%的准确率，并能在GPU上以13帧/秒的速度运行。相比于现有技术，该模型提升了14个百分点的精度，并且能同时进行物体分类和抓取预测。此外，模型的多抓取版本可以预测单个物体的多个抓取点，显著提高了对多样抓取方式物体的检测性能。&apos;, &apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/110829105&apos;}
   2. 标题: 械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂抓取位姿估计与轨迹优化研究-CSDN博客 械臂论文笔记（三）【抓取检测】...
      元数据: {&apos;title&apos;: &apos;械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂抓取位姿估计与轨迹优化研究-CSDN博客&apos;, &apos;description&apos;: &apos;文章浏览阅读9k次，点赞11次，收藏102次。本文综述了机器人抓取检测技术，重点介绍了基于学习的方法，包括基于抓取检测的抓取和基于视觉运动控制策略的端到端抓取，并探讨了各种方法的优势与局限。&apos;, &apos;language&apos;: &apos;zh-CN&apos;, &apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/111031270?ops_request_misc=&amp;#x26;request_id=&amp;#x26;biz_id=102&amp;#x26;utm_term=%E6%8A%93%E5%8F%96%E6%A3%80%E6%B5%8B&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-7-111031270.142^v102^pc_search_result_base7&amp;#x26;spm=1018.2226.3001.4187&apos;}

============================================================
📋 测试 3/3: 数值范围过滤
🔍 查询: 语言为中文的技术文章
🎯 查询: 语言为中文的技术文章
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  3.80it/s]
🤖 LLM解析结果: 关键词：技术文章
             过滤：language == zh-CN
🔍 搜索关键词: 技术文章
📋 过滤条件: []
✅ 找到 2 个文档
✅ 找到 2 个相关文档
   1. 标题: 机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks...
      元数据: {&apos;description&apos;: &apos;文章浏览阅读6.2k次，点赞10次，收藏97次。本文提出了一种基于深度学习的实时机器人抓取检测方法，利用卷积神经网络（CNN）进行单阶段回归预测，避免了传统滑动窗口方法的计算成本。模型在康奈尔抓取数据集上实现了88%的准确率，并能在GPU上以13帧/秒的速度运行。相比于现有技术，该模型提升了14个百分点的精度，并且能同时进行物体分类和抓取预测。此外，模型的多抓取版本可以预测单个物体的多个抓取点，显著提高了对多样抓取方式物体的检测性能。&apos;, &apos;title&apos;: &apos;机械臂论文笔记（二）【实时抓取点检测】Real-Time Grasp Detection Using Convolutional Neural Networks-CSDN博客&apos;, &apos;language&apos;: &apos;zh-CN&apos;, &apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/110829105&apos;}
   2. 标题: 械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂抓取位姿估计与轨迹优化研究-CSDN博客 械臂论文笔记（三）【抓取检测】...
      元数据: {&apos;title&apos;: &apos;械臂论文笔记（三）【抓取检测】机器人抓取检测技术的研究现状 刘亚欣_基于深度图像的机械臂抓取位姿估计与轨迹优化研究-CSDN博客&apos;, &apos;description&apos;: &apos;文章浏览阅读9k次，点赞11次，收藏102次。本文综述了机器人抓取检测技术，重点介绍了基于学习的方法，包括基于抓取检测的抓取和基于视觉运动控制策略的端到端抓取，并探讨了各种方法的优势与局限。&apos;, &apos;source&apos;: &apos;https://blog.csdn.net/WhiffeYF/article/details/111031270?ops_request_misc=&amp;#x26;request_id=&amp;#x26;biz_id=102&amp;#x26;utm_term=%E6%8A%93%E5%8F%96%E6%A3%80%E6%B5%8B&amp;#x26;utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-7-111031270.142^v102^pc_search_result_base7&amp;#x26;spm=1018.2226.3001.4187&apos;, &apos;language&apos;: &apos;zh-CN&apos;}

============================================================
🎉 简化版自查询检索器测试完成！
============================================================
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.4 技术优势与局限&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;优势 ✅:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;自动化解析&lt;/strong&gt;: 自动分离语义查询和结构化过滤条件&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;自然语言友好&lt;/strong&gt;: 支持复杂的自然语言查询意图理解&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;混合检索&lt;/strong&gt;: 结合向量搜索和元数据过滤的最佳效果&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;易于集成&lt;/strong&gt;: 与现有向量数据库无缝集成&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;注意事项 ⚠️:&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;LLM依赖&lt;/strong&gt;: 查询解析质量依赖LLM的语义理解能力&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;元数据定义&lt;/strong&gt;: 需要清晰明确的元数据字段定义&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;性能开销&lt;/strong&gt;: LLM调用增加额外的响应时间&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;错误处理&lt;/strong&gt;: 需要完善的错误回退机制&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;4.5 性能优化策略&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;1. 路由缓存&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from functools import lru_cache
import hashlib

class CachedRouter:
    &quot;&quot;&quot;带缓存的路由器&quot;&quot;&quot;
    
    def __init__(self, semantic_router, cache_size=100):
        self.semantic_router = semantic_router
        self.cache_size = cache_size
        self._cache = {}
    
    def _get_cache_key(self, question: str) -&gt; str:
        &quot;&quot;&quot;生成缓存键&quot;&quot;&quot;
        return hashlib.md5(question.encode()).hexdigest()
    
    @lru_cache(maxsize=100)
    def route(self, question: str):
        &quot;&quot;&quot;带缓存的路由&quot;&quot;&quot;
        cache_key = self._get_cache_key(question)
        
        if cache_key in self._cache:
            print(&quot;💾 使用缓存的路由结果&quot;)
            return self._cache[cache_key]
        
        # 执行路由
        route_name, score = self.semantic_router.route(question)
        
        # 存入缓存
        self._cache[cache_key] = (route_name, score)
        
        return route_name, score
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;2. 批量并行处理&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import asyncio
from typing import List, Tuple

class ParallelRouter:
    &quot;&quot;&quot;并行路由器&quot;&quot;&quot;
    
    async def route_multiple(
        self,
        questions: List[str]
    ) -&gt; List[Tuple[str, str, float]]:
        &quot;&quot;&quot;并行路由多个查询
        
        Returns:
            List[(question, route_name, score)]
        &quot;&quot;&quot;
        async def route_single(question: str):
            # 异步路由（这里简化，实际需要异步嵌入）
            route_name, score = self.semantic_router.route(question)
            return question, route_name, score
        
        tasks = [route_single(q) for q in questions]
        results = await asyncio.gather(*tasks)
        
        return results

# 使用
# results = asyncio.run(router.route_multiple([&quot;问题1&quot;, &quot;问题2&quot;, &quot;问题3&quot;]))
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;p&gt;应用场景选择指南&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;选择自查询检索器当：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;查询条件复杂且包含隐含的过滤需求&lt;/li&gt;
&lt;li&gt;用户习惯使用自然语言表达查询意图&lt;/li&gt;
&lt;li&gt;需要结合语义搜索和精确过滤的混合场景&lt;/li&gt;
&lt;li&gt;系统需要较高的自动化程度&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;选择其他方案当：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;查询规则固定且明确 → 逻辑路由&lt;/li&gt;
&lt;li&gt;只需语义相似性搜索 → 标准检索器&lt;/li&gt;
&lt;li&gt;过滤条件简单明确 → 查询构建器&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;自查询检索器代表了检索系统智能化的高级阶段，通过LLM的语义理解能力，实现了从&quot;如何查询&quot;到&quot;查询什么&quot;的自然过渡，为构建更加智能和用户友好的检索系统提供了有力工具。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>RAG实战（二）RAG查询优化</title><link>https://astro-pure.js.org/blog/rag_blogs/rag_blogs-2</link><guid isPermaLink="true">https://astro-pure.js.org/blog/rag_blogs/rag_blogs-2</guid><description>记录RAG的学习。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/RAG_Learning&quot;&gt;Github地址&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;RAG查询优化：多查询与查询转换技术&lt;/h2&gt;
&lt;p&gt;在基础RAG系统中，我们使用单一查询进行检索。但在实际应用中，用户的查询往往存在&lt;strong&gt;表达模糊、角度单一或过于笼统&lt;/strong&gt;的问题。本章将介绍多种查询优化技术，让你的RAG系统能够更准确地理解用户意图。&lt;/p&gt;
&lt;h3&gt;查询优化的必要性&lt;/h3&gt;
&lt;h4&gt;单一查询的局限性&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;用户查询&lt;/strong&gt;: &quot;机器学习是什么？&quot;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 可能错过的相关文档：
- &quot;深度学习入门&quot; (使用了不同但相关的术语)
- &quot;AI算法基础&quot; (更广泛的主题) 
- &quot;神经网络原理&quot; (具体技术)
- &quot;监督学习vs无监督学习&quot; (细分话题)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;问题分析&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;术语不匹配&lt;/strong&gt;：用户说&quot;机器学习&quot;，文档用&quot;AI算法&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;粒度不一致&lt;/strong&gt;：用户问宏观概念，文档讲具体技术&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;表达差异&lt;/strong&gt;：同一概念有多种表述方式&lt;/li&gt;
&lt;/ul&gt;
&lt;h4&gt;查询优化如何帮助？&lt;/h4&gt;
&lt;p&gt;查询优化的核心思想：&lt;strong&gt;通过生成多个角度的查询、重写查询或分解复杂查询，增加检索到相关文档的概率&lt;/strong&gt;。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# ✅ 查询优化后

原始查询: &quot;机器学习是什么？&quot;

生成的变体:
1. &quot;什么是机器学习算法？&quot;
2. &quot;机器学习的基本概念和原理&quot; 
3. &quot;AI中的机器学习技术&quot;
4. &quot;机器学习的应用场景&quot;

→ 并行检索 → 合并结果 → 去重 → 生成答案
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;技术概览&lt;/h3&gt;
&lt;p&gt;本章将介绍5种主要的查询优化技术：&lt;/p&gt;
&lt;p&gt;| 技术 | 核心思想 | 适用场景 | 复杂度 |
|------|---------|----------|--------|
| &lt;strong&gt;Multi-Query&lt;/strong&gt; | 生成查询的多个变体 | 用户查询表达不清 | ⭐ |
| &lt;strong&gt;RAG-Fusion&lt;/strong&gt; | 多查询+重排序融合 | 需要高质量结果 | ⭐⭐ |
| &lt;strong&gt;Decomposition&lt;/strong&gt; | 分解复杂查询 | 多步骤问题 | ⭐⭐⭐ |
| &lt;strong&gt;Step Back&lt;/strong&gt; | 先问概括性问题 | 需要背景知识 | ⭐⭐ |
| &lt;strong&gt;HyDE&lt;/strong&gt; | 生成假设性文档 | 语义搜索增强 | ⭐⭐⭐ |&lt;/p&gt;
&lt;hr&gt;
&lt;h3&gt;环境准备&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from vllm import LLM

# 1. 加载本地嵌入模型
local_model_path = &quot;./Models/maidalun/bce-embedding-base_v1&quot; 
embeddings = HuggingFaceEmbeddings(
    model_name=local_model_path,
    model_kwargs={&quot;device&quot;: &quot;cuda&quot;},
    encode_kwargs={&quot;normalize_embeddings&quot;: True}
)

# 2. 加载本地向量数据库
vectorstore = Chroma(
    persist_directory=&quot;./chroma_db&quot;,
    embedding_function=embeddings
)

# 3. 加载本地大模型
model_dir = &quot;../Qwen-vllm/Models/Qwen/Qwen-7B-Chat-Int8&quot;
os.environ[&apos;VLLM_USE_MODELSCOPE&apos;] = &apos;True&apos;
llm = LLM(
    model=model_dir,
    tokenizer=model_dir,
    trust_remote_code=True
)

print(&quot;✅ 环境初始化完成&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;✅ 环境初始化完成
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;Part1：Multi-Query - 多角度查询&lt;/h2&gt;
&lt;h3&gt;1.1 核心思想&lt;/h3&gt;
&lt;p&gt;Multi-Query技术通过LLM生成原始查询的多个变体，从不同角度检索文档，&lt;strong&gt;提高检索的召回率&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;工作流程&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询: &quot;什么是Agent？&quot;
    ↓
使用LLM生成变体:
    ├─ &quot;Agent系统的定义是什么？&quot;
    ├─ &quot;AI Agent的核心概念&quot;  
    └─ &quot;什么是自主智能体？&quot;
    ↓
并行检索每个变体
    ↓
合并并去重结果
    ↓
生成最终答案
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.2 完整实现代码&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
import chromadb
from vllm import LLM, SamplingParams
from modelscope import snapshot_download
from langchain_community.embeddings import HuggingFaceEmbeddings
from typing import List, Dict, Any

def get_unique_documents(documents: List[List]) -&gt; List:
    &quot;&quot;&quot;去重文档&quot;&quot;&quot;
    unique_docs = {}
    for doc_list in documents:
        for doc in doc_list:
            content = doc.page_content
            if content not in unique_docs:
                unique_docs[content] = doc
    return list(unique_docs.values())

class MultiQueryRAG:
    &quot;&quot;&quot;完全使用ChatML格式的Multi-Query RAG系统&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm):
        &quot;&quot;&quot;
        初始化Multi-Query RAG
        
        Args:
            vectorstore: 向量数据库
            llm: 语言模型
        &quot;&quot;&quot;
        self.vectorstore = vectorstore
        self.retriever = vectorstore.as_retriever(search_kwargs={&quot;k&quot;: 3})
        self.llm = llm
        print(&quot;✅ Multi-Query RAG系统初始化完成&quot;)
    
    def _generate_query_variants(self, question: str) -&gt; List[str]:
        &quot;&quot;&quot;使用ChatML格式生成查询变体&quot;&quot;&quot;
        query_prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个AI助手，擅长将用户的问题改写成多个语义相同但表达不同的搜索查询。每个查询应简洁、独立，适合用于向量检索。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            原始问题：{question}
            
            请生成3个不同的搜索查询，每行一个，不要编号，不要解释：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=500,
            temperature=0.7,  # 适当提高温度增加多样性
            top_p=0.9,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([query_prompt], sampling_params)
        response = outputs[0].outputs[0].text.strip()
        
        # 解析生成的查询
        queries = []
        lines = response.split(&apos;\n&apos;)
        for line in lines:
            line = line.strip()
            if line and len(line) &gt; 5:
                # 移除可能的编号标记
                if line[0].isdigit() and (line[1] == &apos;.&apos; or line[1] == &apos;、&apos;):
                    query = line[2:].strip()
                else:
                    query = line
                queries.append(query)
        
        # 确保包含原始问题并去重
        all_queries = [question] + queries
        unique_queries = []
        seen = set()
        for q in all_queries:
            if q not in seen:
                seen.add(q)
                unique_queries.append(q)
        
        return unique_queries[:4]  # 最多返回4个查询
    
    def _retrieve_documents(self, queries: List[str]) -&gt; List:
        &quot;&quot;&quot;检索所有查询的文档&quot;&quot;&quot;
        all_docs = []
        for query in queries:
            try:
                docs = self.retriever.get_relevant_documents(query)
                all_docs.append(docs)
                print(f&quot;   🔍 &apos;{query}&apos;: 找到 {len(docs)} 个文档&quot;)
            except Exception as e:
                print(f&quot;❌ 检索失败 &apos;{query}&apos;: {e}&quot;)
                all_docs.append([])
        return all_docs
    
    def _process_documents(self, all_docs: List[List]) -&gt; List:
        &quot;&quot;&quot;处理并去重文档&quot;&quot;&quot;
        unique_docs = get_unique_documents(all_docs)
        return unique_docs[:10]  # 返回top-10
    
    def _clean_context(self, text: str) -&gt; str:
        &quot;&quot;&quot;清理上下文内容&quot;&quot;&quot;
        import re
        if not text or not isinstance(text, str):
            return text
        
        # 清理格式
        cleaned = re.sub(r&apos;(?&amp;#x3C;=[^\s])\n(?=[^\s])&apos;, &apos;&apos;, text)  # 合并错误分割的文字
        cleaned = re.sub(r&apos;\n\s+\n&apos;, &apos;\n\n&apos;, cleaned)  # 压缩空行
        cleaned = re.sub(r&apos;\n\s+&apos;, &apos;\n&apos;, cleaned)  # 清理行内空格
        cleaned = re.sub(r&apos;[ \t]{2,}&apos;, &apos; &apos;, cleaned)  # 清理连续空格
        cleaned = cleaned.strip()  # 清理首尾空格
        
        return cleaned
    
    def _build_answer_prompt(self, question: str, context: str) -&gt; str:
        &quot;&quot;&quot;构建ChatML格式的答案生成提示词&quot;&quot;&quot;
        cleaned_context = self._clean_context(context)
        
        return f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个专业的AI助手。请严格基于以下上下文信息回答问题：
            
            {cleaned_context}
            
            请遵循以下规则：
            1. 只使用上下文中的信息回答
            2. 如果上下文不包含相关信息，请回答&quot;我不知道&quot;
            3. 保持回答准确、简洁
            4. 不要编造信息&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            {question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
    
    def _generate_answer(self, prompt: str, max_tokens: int = 512, temperature: float = 0.3) -&gt; str:
        &quot;&quot;&quot;生成答案&quot;&quot;&quot;
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        return outputs[0].outputs[0].text.strip()
    
    def query(self, question: str, max_tokens: int = 512, temperature: float = 0.3) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;执行Multi-Query RAG查询&quot;&quot;&quot;
        print(f&quot;🎯 原始问题: {question}&quot;)
        print(&quot;-&quot; * 50)
        
        # Step 1: 生成查询变体
        queries = self._generate_query_variants(question)
        print(f&quot;📝 生成了 {len(queries)} 个查询:&quot;)
        for i, q in enumerate(queries, 1):
            print(f&quot;   {i}. {q}&quot;)
        
        # Step 2: 检索文档
        all_docs = self._retrieve_documents(queries)
        total_docs_before_dedup = sum(len(docs) for docs in all_docs)
        print(f&quot;📚 检索到 {total_docs_before_dedup} 个文档（含重复）&quot;)
        
        # Step 3: 去重
        unique_docs = self._process_documents(all_docs)
        print(f&quot;✨ 去重后剩余 {len(unique_docs)} 个文档&quot;)
        
        # Step 4: 构建上下文
        context = &quot;\n\n&quot;.join([doc.page_content for doc in unique_docs])
        
        # Step 5: 生成答案
        prompt = self._build_answer_prompt(question, context)
        answer = self._generate_answer(prompt, max_tokens, temperature)
        
        # 返回结果
        return {
            &quot;question&quot;: question,
            &quot;answer&quot;: answer,
            &quot;queries&quot;: queries,
            &quot;num_docs&quot;: len(unique_docs),
            &quot;total_docs_before_dedup&quot;: total_docs_before_dedup,
            &quot;context_preview&quot;: context[:200] + &quot;...&quot; if len(context) &gt; 200 else context
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.3 实战测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# === 使用示例 ===
if __name__ == &quot;__main__&quot;:
    print(&quot;🚀 ChatML格式的Multi-Query RAG&quot;)
    print(&quot;=&quot; * 60)
    
    # 1. 创建实例
    multi_rag = MultiQueryRAG(vectorstore, llm)
    
    # 2. 执行查询
    result = multi_rag.query(&quot;什么是机械臂？它有什么能力？&quot;)
    
    # 3. 查看结果
    print(&quot;\n&quot; + &quot;=&quot;*60)
    print(&quot;💡 答案:&quot;, result[&quot;answer&quot;])
    print(&quot;📝 使用的查询:&quot;, result[&quot;queries&quot;])
    print(&quot;📚 参考文档数:&quot;, result[&quot;num_docs&quot;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 ChatML格式的Multi-Query RAG
============================================================
✅ Multi-Query RAG系统初始化完成
🎯 原始问题: 什么是机械臂？它有什么能力？
--------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.55it/s]
📝 生成了 4 个查询:
   1. 什么是机械臂？它有什么能力？
   2. 机械臂是什么？
   3. 机械臂有哪些功能？
   4. 机械臂的工作原理是什么？
   🔍 &apos;什么是机械臂？它有什么能力？&apos;: 找到 3 个文档
   🔍 &apos;机械臂是什么？&apos;: 找到 3 个文档
   🔍 &apos;机械臂有哪些功能？&apos;: 找到 3 个文档
   🔍 &apos;机械臂的工作原理是什么？&apos;: 找到 3 个文档
📚 检索到 12 个文档（含重复）
✨ 去重后剩余 6 个文档
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.24it/s]

============================================================
💡 答案: 机械臂是一种能够自动执行任务的机器人手臂，它由多个关节组成，可以按照预设的程序或指令进行精确的运动。机械臂具有强大的力量和精确度，可以完成各种复杂的任务，例如抓取、搬运、装配等。
📝 使用的查询: [&apos;什么是机械臂？它有什么能力？&apos;, &apos;机械臂是什么？&apos;, &apos;机械臂有哪些功能？&apos;, &apos;机械臂的工作原理是什么？&apos;]
📚 参考文档数: 6
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.4 Multi-Query的优缺点分析&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;✅ 优点&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;提高召回率&lt;/strong&gt;：找到更多相关文档&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;覆盖不同角度&lt;/strong&gt;：应对表达方式差异&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;对模糊查询有效&lt;/strong&gt;：用户表达不清时特别有用&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实现相对简单&lt;/strong&gt;：逻辑清晰，易于调试&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;❌ 缺点&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;增加检索成本&lt;/strong&gt;：多次查询消耗更多资源&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;可能引入噪声&lt;/strong&gt;：不相关的变体影响质量&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;需要额外LLM调用&lt;/strong&gt;：生成变体增加延迟&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;去重可能过滤有价值文档&lt;/strong&gt;：相似但不同的内容被误删&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;1.5 优化技巧&lt;/h3&gt;
&lt;h4&gt;1. 限制查询数量&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def _generate_query_variants(self, question: str) -&gt; List[str]:
    # ... 生成逻辑 ...
    return unique_queries[:4]  # 最多返回4个查询，平衡效果与成本
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;2. 使用缓存避免重复生成&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from functools import lru_cache

@lru_cache(maxsize=100)
def cached_query_generation(question: str) -&gt; tuple:
    &quot;&quot;&quot;缓存查询生成结果&quot;&quot;&quot;
    queries = query_generator.invoke({&quot;question&quot;: question})
    return tuple(queries)  # 返回元组以支持缓存
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;3. 异步并行检索&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import asyncio

async def async_retrieve_all(queries: List[str], retriever):
    &quot;&quot;&quot;异步并行检索，大幅提升速度&quot;&quot;&quot;
    tasks = [retriever.aget_relevant_documents(q) for q in queries]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    return results
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;4. 智能过滤低质量变体&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def filter_low_quality_queries(original: str, variants: List[str]) -&gt; List[str]:
    &quot;&quot;&quot;过滤与原始查询差异过大的变体&quot;&quot;&quot;
    filtered = []
    for variant in variants:
        # 计算语义相似度，保留相关变体
        similarity = calculate_similarity(original, variant)
        if similarity &gt; 0.6:  # 相似度阈值
            filtered.append(variant)
    return filtered
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;总结&lt;/h3&gt;
&lt;p&gt;Multi-Query技术通过&lt;strong&gt;查询多样性&lt;/strong&gt;有效解决了单一查询的局限性，在实际应用中：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;🎯 适用场景&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;用户查询表达模糊或不专业时&lt;/li&gt;
&lt;li&gt;需要从多角度理解复杂概念时&lt;/li&gt;
&lt;li&gt;文档库中使用术语不统一时&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;⚡ 使用建议&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;简单查询&lt;/strong&gt;：直接使用单一查询&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;复杂/模糊查询&lt;/strong&gt;：启用Multi-Query&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;性能敏感&lt;/strong&gt;：结合缓存和异步优化&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;经验法则&lt;/strong&gt;：当不确定用户查询的精确表述时，Multi-Query是提高召回率的有效策略。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;在下一部分，我们将探讨更高级的&lt;strong&gt;RAG-Fusion技术&lt;/strong&gt;，它结合了多查询和重排序，能够进一步提升检索质量。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part2：RAG-Fusion - 融合式检索&lt;/h2&gt;
&lt;h3&gt;2.1 核心概念&lt;/h3&gt;
&lt;p&gt;RAG-Fusion结合了Multi-Query和&lt;strong&gt;倒数排序融合（Reciprocal Rank Fusion, RRF）&lt;/strong&gt;，不仅生成多个查询，还智能地合并和重排序结果。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;什么是倒数排序融合？&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;RRF是一种排序融合算法，给予排名靠前的文档更高的分数：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;公式&lt;/strong&gt;: &lt;code&gt;RRF_score(doc) = Σ 1 / (k + rank(doc))&lt;/code&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;其中&lt;/strong&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;k&lt;/code&gt; = 常数（通常为60）&lt;/li&gt;
&lt;li&gt;&lt;code&gt;rank(doc)&lt;/code&gt; = 文档在某个查询结果中的排名&lt;/li&gt;
&lt;li&gt;&lt;code&gt;Σ&lt;/code&gt; = 对所有查询结果求和&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;示例演示&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 查询1的结果排序:
查询1: &quot;什么是Agent?&quot;
  1. 文档A (排名1)
  2. 文档B (排名2) 
  3. 文档C (排名3)

# 查询2的结果排序:
查询2: &quot;AI Agent的定义&quot;
  1. 文档B (排名1)  ← 在这个查询中排名更高
  2. 文档D (排名2)
  3. 文档A (排名3)

# RRF融合计算 (k=60):
文档A: 1/(60+1) + 1/(60+3) = 0.0164 + 0.0159 = 0.0323
文档B: 1/(60+2) + 1/(60+1) = 0.0161 + 0.0164 = 0.0325 ← 最高分
文档C: 1/(60+3) + 0         = 0.0159
文档D: 0         + 1/(60+2) = 0.0161

最终排序: B &gt; A &gt; D &gt; C
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.2 RRF算法实现&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from typing import List, Tuple
from vllm import SamplingParams

def reciprocal_rank_fusion(
    results: List[List],
    k: int = 60
) -&gt; List[Tuple[any, float]]:
    &quot;&quot;&quot;
    实现倒数排序融合
    
    Args:
        results: 多个查询的结果列表
        k: RRF常数
    
    Returns:
        排序后的(文档, 分数)列表
    &quot;&quot;&quot;
    # 存储每个文档的融合分数
    fusion_scores = {}
    
    for docs in results:
        for rank, doc in enumerate(docs):
            # 使用文档内容作为唯一标识
            doc_id = doc.page_content
            
            # 计算RRF分数
            if doc_id not in fusion_scores:
                fusion_scores[doc_id] = {
                    &apos;doc&apos;: doc,
                    &apos;score&apos;: 0
                }
            
            # 累加分数
            fusion_scores[doc_id][&apos;score&apos;] += 1 / (k + rank + 1)
    
    # 按分数排序
    sorted_docs = sorted(
        fusion_scores.values(),
        key=lambda x: x[&apos;score&apos;],
        reverse=True
    )
    
    return [(item[&apos;doc&apos;], item[&apos;score&apos;]) for item in sorted_docs]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.3 完整RAG-Fusion系统&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class RAGFusion:
    &quot;&quot;&quot;使用ChatML格式的RAG-Fusion系统&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm, k: int = 60):
        self.vectorstore = vectorstore
        self.retriever = vectorstore.as_retriever(search_kwargs={&quot;k&quot;: 5})
        self.llm = llm
        self.k = k
        
    def generate_queries(self, question: str, n: int = 3) -&gt; List[str]:
        &quot;&quot;&quot;使用ChatML格式生成查询变体&quot;&quot;&quot;
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个AI助手，擅长将用户的问题改写成多个语义相同但表达不同的搜索查询。
            请生成{n}个不同角度的查询变体。
            
            要求：
            1. 语义相关但表达不同
            2. 覆盖不同角度
            3. 每行一个查询&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            原始问题：{question}
            
            请生成{n}个不同的搜索查询，每行一个：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=512,
            temperature=0.7,
            top_p=0.9,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        response = outputs[0].outputs[0].text.strip()
        
        # 解析生成的查询
        queries = []
        lines = response.split(&apos;\n&apos;)
        for line in lines:
            line = line.strip()
            if line and len(line) &gt; 5:  # 过滤太短的查询
                # 移除可能的编号标记
                if line[0].isdigit() and (line[1] == &apos;.&apos; or line[1] == &apos;、&apos;):
                    query = line[2:].strip()
                else:
                    query = line
                queries.append(query)
        
        # 确保包含原始问题并去重
        all_queries = [question] + queries
        unique_queries = []
        seen = set()
        for q in all_queries:
            if q not in seen:
                seen.add(q)
                unique_queries.append(q)
        
        return unique_queries[:n+1]  # 包含原始问题
    
    def retrieve_and_fuse(self, queries: List[str]) -&gt; List[Tuple[any, float]]:
        &quot;&quot;&quot;检索并融合结果&quot;&quot;&quot;
        all_results = []
        
        print(f&quot;🔍 执行 {len(queries)} 个查询...&quot;)
        for i, query in enumerate(queries, 1):
            docs = self.retriever.get_relevant_documents(query)
            all_results.append(docs)
            print(f&quot;   查询{i}: &apos;{query}&apos; -&gt; 检索到 {len(docs)} 个文档&quot;)
        
        # RRF融合
        fused_results = reciprocal_rank_fusion(all_results, k=self.k)
        print(f&quot;✨ 融合后共 {len(fused_results)} 个唯一文档&quot;)
        
        return fused_results
    
    def _build_answer_prompt(self, question: str, context: str) -&gt; str:
        &quot;&quot;&quot;构建ChatML格式的答案生成提示词&quot;&quot;&quot;
        return f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个专业的AI助手。请基于以下按相关性排序的文档回答问题。
            
            {context}
            
            请遵循以下规则：
            1. 只使用文档中的信息回答
            2. 如果文档不包含相关信息，请回答&quot;我不知道&quot;
            3. 保持回答准确、简洁&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            {question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
    
    def _generate_answer(self, prompt: str, max_tokens: int = 512, temperature: float = 0.3) -&gt; str:
        &quot;&quot;&quot;生成答案&quot;&quot;&quot;
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        return outputs[0].outputs[0].text.strip()
    
    def query(self, question: str, top_k: int = 5, max_tokens: int = 512, temperature: float = 0.3) -&gt; dict:
        &quot;&quot;&quot;执行RAG-Fusion查询&quot;&quot;&quot;
        print(f&quot;🎯 原始问题: {question}&quot;)
        print(&quot;-&quot; * 50)
        
        # 1. 生成查询
        queries = self.generate_queries(question)
        print(f&quot;📝 生成了 {len(queries)} 个查询:&quot;)
        for i, q in enumerate(queries, 1):
            print(f&quot;   {i}. {q}&quot;)
        
        # 2. 检索并融合
        fused_docs = self.retrieve_and_fuse(queries)
        
        # 3. 选择top-k
        top_docs = fused_docs[:top_k]
        print(f&quot;🏆 选择Top-{len(top_docs)} 文档:&quot;)
        for i, (doc, score) in enumerate(top_docs, 1):
            preview = doc.page_content[:100] + &quot;...&quot; if len(doc.page_content) &gt; 100 else doc.page_content
            print(f&quot;   {i}. [分数: {score:.4f}] {preview}&quot;)
        
        # 4. 生成答案
        context = &quot;\n\n&quot;.join([
            f&quot;[文档{i+1} 分数: {score:.4f}]\n{doc.page_content}&quot; 
            for i, (doc, score) in enumerate(top_docs)
        ])
        
        prompt = self._build_answer_prompt(question, context)
        answer = self._generate_answer(prompt, max_tokens, temperature)
        
        return {
            &quot;question&quot;: question,
            &quot;generated_queries&quot;: queries,
            &quot;num_docs&quot;: len(top_docs),
            &quot;top_scores&quot;: [score for _, score in top_docs],
            &quot;answer&quot;: answer,
            &quot;context_preview&quot;: context[:200] + &quot;...&quot; if len(context) &gt; 200 else context
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.4 实战测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;print(&quot;🚀 RAG-Fusion系统&quot;)
print(&quot;=&quot; * 60)

# 1. 创建实例
rag_fusion = RAGFusion(vectorstore, llm, k=60)

# 2. 执行查询
result = rag_fusion.query(&quot;什么是机械臂？它有哪些关键能力？&quot;)

# 3. 查看结果
print(&quot;\n&quot; + &quot;=&quot;*60)
print(&quot;生成的查询:&quot;)
for q in result[&quot;generated_queries&quot;]:
    print(f&quot;  - {q}&quot;)

print(f&quot;\nTop-{result[&apos;num_docs&apos;]} 文档分数:&quot;)
for i, score in enumerate(result[&quot;top_scores&quot;], 1):
    print(f&quot;  {i}. {score:.4f}&quot;)

print(&quot;\n答案:&quot;)
print(result[&quot;answer&quot;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 RAG-Fusion系统
============================================================
🎯 原始问题: 什么是机械臂？它有哪些关键能力？
--------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.57it/s]
📝 生成了 4 个查询:
   1. 什么是机械臂？它有哪些关键能力？
   2. 机械臂是什么？它的主要功能是什么？
   3. 机械臂有哪些重要的特点和功能？
   4. 机械臂的运作原理是什么？它有哪些关键能力？
🔍 执行 4 个查询...
   查询1: &apos;什么是机械臂？它有哪些关键能力？&apos; -&gt; 检索到 5 个文档
   查询2: &apos;机械臂是什么？它的主要功能是什么？&apos; -&gt; 检索到 5 个文档
   查询3: &apos;机械臂有哪些重要的特点和功能？&apos; -&gt; 检索到 5 个文档
   查询4: &apos;机械臂的运作原理是什么？它有哪些关键能力？&apos; -&gt; 检索到 5 个文档
✨ 融合后共 13 个唯一文档
🏆 选择Top-5 文档:
   1. [分数: 0.0640] 尽管，机械臂的运动速度可以设定为匀速运动， 避免加速度变化而造成对目标和夹爪之间动力学平衡关系的响。但是， 机械臂的
   2. [分数: 0.0489] 功能，其作用力的范围为 50N，精确度为 2.5N，准确度为 4N，扭矩范围为 10N·m，
精确度为 0.04N·m，准确度为 0.3N。机械臂本体自重 20.6kg，采用 220V交流供电，
   3. [分数: 0.0318] 机械臂 
可重复性  +-0.03mm  
有效负载  5 千克/11 磅 
工作半径  850mm/33.5 英寸 
自由度 6 个旋转关节
   4. [分数: 0.0315] ±180°/s，机械臂末端速度视各关节载荷与实际速度而定。机械臂本体的位姿可重复性
精度为 ±0.03mm。与其它型号机械臂不同的是，UR5e机械臂的工具法兰带有力感应
   5. [分数: 0.0161] 机记录该浮雕风格的接触表面图像，并使用光度立体算法
[55] 重建接触表面的深度图。
该传感器通过凝胶(gel)形变图像获取触觉信息，因此被称为Gelsight。为机器人手指
Processed prompts: 100%|██████████| 1/1 [00:02&amp;#x3C;00:00,  2.35s/it]
============================================================
生成的查询:
  - 什么是机械臂？它有哪些关键能力？
  - 机械臂是什么？它的主要功能是什么？
  - 机械臂有哪些重要的特点和功能？
  - 机械臂的运作原理是什么？它有哪些关键能力？

Top-5 文档分数:
  1. 0.0640
  2. 0.0489
  3. 0.0318
  4. 0.0315
  5. 0.0161

答案:
机械臂是一种能够按照预定程序自动完成特定任务的机器人手臂。它通常由多个关节和执行器组成，可以实现精确的运动控制和抓取操作。机械臂的关键能力包括：
             1. 运动控制：机械臂可以实现精确的运动控制，包括速度、位置和方向控制。
             2. 抓取操作：机械臂可以实现精确的抓取操作，包括抓取物体和释放物体。
             3. 自动化操作：机械臂可以实现自动化操作，包括自动装配、搬运和拆卸等。
             4. 精确度：机械臂可以实现高精度的操作，包括精确的定位和精确的抓取。
             5. 功能：机械臂可以实现多种功能，包括搬运、装配、拆卸、切割、焊接、喷涂等。
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.5 RAG-Fusion vs Multi-Query对比&lt;/h3&gt;
&lt;p&gt;| 特性 | Multi-Query | RAG-Fusion | 说明 |
|------|------------|------------|------|
| &lt;strong&gt;查询生成&lt;/strong&gt; | ✅ | ✅ | 都支持生成多个查询变体 |
| &lt;strong&gt;并行检索&lt;/strong&gt; | ✅ | ✅ | 都支持并行检索多个查询 |
| &lt;strong&gt;智能排序&lt;/strong&gt; | ❌ | ✅ RRF算法 | RAG-Fusion使用倒数排序融合 |
| &lt;strong&gt;结果质量&lt;/strong&gt; | 中等 | 高 | RRF提升检索质量 |
| &lt;strong&gt;计算成本&lt;/strong&gt; | 低 | 中等 | RRF增加计算开销 |
| &lt;strong&gt;适用场景&lt;/strong&gt; | 一般检索 | 高质量需求 | 根据需求选择 |&lt;/p&gt;
&lt;h3&gt;2.6 技术选型建议&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def select_query_optimization_method(use_case: str, quality_requirement: str) -&gt; str:
    &quot;&quot;&quot;根据场景选择查询优化方法&quot;&quot;&quot;
    if use_case == &quot;simple_qa&quot; and quality_requirement == &quot;balanced&quot;:
        return &quot;Multi-Query&quot;  # 平衡效果与成本
    elif use_case == &quot;research&quot; and quality_requirement == &quot;high&quot;:
        return &quot;RAG-Fusion&quot;   # 追求最高质量
    elif use_case == &quot;real_time&quot; and quality_requirement == &quot;fast&quot;:
        return &quot;Single-Query&quot; # 追求最低延迟
    else:
        return &quot;Multi-Query&quot;  # 默认选择
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;总结&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;Multi-Query vs RAG-Fusion 实战对比&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;| 指标 | Multi-Query | RAG-Fusion | 胜出方 |
|------|------------|------------|--------|
| &lt;strong&gt;检索文档数&lt;/strong&gt; | 6个 | 13个 | RAG-Fusion |
| &lt;strong&gt;答案质量&lt;/strong&gt; | 良好 | 优秀 | RAG-Fusion |&lt;br&gt;
| &lt;strong&gt;响应时间&lt;/strong&gt; | 较快 | 稍慢 | Multi-Query |
| &lt;strong&gt;实现复杂度&lt;/strong&gt; | 简单 | 中等 | Multi-Query |&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;🎯 使用建议&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;日常问答&lt;/strong&gt;：使用Multi-Query，平衡效果与成本&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;研究分析&lt;/strong&gt;：使用RAG-Fusion，追求最高质量&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实时系统&lt;/strong&gt;：使用Single-Query，追求最低延迟&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;资源充足&lt;/strong&gt;：RAG-Fusion + 缓存优化&lt;/li&gt;
&lt;/ol&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;经验总结&lt;/strong&gt;：RRF算法通过&lt;strong&gt;多查询结果融合&lt;/strong&gt;，显著提升了检索质量，特别适合对答案准确性要求高的场景。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;在下一部分，我们将探讨更高级的&lt;strong&gt;查询分解（Decomposition）&lt;strong&gt;和&lt;/strong&gt;Step Back提示&lt;/strong&gt;技术，处理复杂的多步骤问题。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part3: Query Decomposition - 查询分解&lt;/h2&gt;
&lt;h3&gt;3.1 核心概念&lt;/h3&gt;
&lt;p&gt;对于复杂的多步骤问题，Query Decomposition技术将其&lt;strong&gt;分解为多个子问题&lt;/strong&gt;，分别回答后再合成最终答案。这种方法特别适合处理需要多角度分析的复杂查询。&lt;/p&gt;
&lt;h4&gt;两种分解策略对比&lt;/h4&gt;
&lt;h5&gt;递归分解（Answer Recursively）&lt;/h5&gt;
&lt;pre&gt;&lt;code&gt;复杂问题: &quot;比较GPT-3和GPT-4在多模态能力上的差异&quot;
    ↓
子问题1: &quot;GPT-3有哪些能力？&quot;
    ↓ 检索 + 回答
答案1: &quot;GPT-3主要是文本模型...&quot;
    ↓
子问题2: &quot;GPT-4有哪些新能力？&quot; (基于答案1)
    ↓ 检索 + 回答  
答案2: &quot;GPT-4增加了图像理解...&quot;
    ↓
综合答案: &quot;GPT-3仅支持文本，而GPT-4...&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;h5&gt;并行分解（Answer Individually）&lt;/h5&gt;
&lt;pre&gt;&lt;code&gt;复杂问题: &quot;比较Python和JavaScript在Web开发中的优劣&quot;
    ↓
子问题1: &quot;Python在Web开发中的优势&quot;
子问题2: &quot;JavaScript在Web开发中的优势&quot;  
子问题3: &quot;Python在Web开发中的劣势&quot;
子问题4: &quot;JavaScript在Web开发中的劣势&quot;
    ↓
并行检索 + 回答
    ↓
综合所有答案
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.2 递归分解实现&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from typing import List, Dict, Any
from vllm import SamplingParams

class RecursiveDecomposition:
    &quot;&quot;&quot;使用ChatML格式的递归查询分解&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.retriever = vectorstore.as_retriever()
        self.llm = llm
        print(&quot;✅ 递归分解系统初始化完成&quot;)
        
    def decompose_query(self, question: str) -&gt; List[str]:
        &quot;&quot;&quot;使用ChatML格式分解复杂查询为子问题&quot;&quot;&quot;
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个AI助手，擅长将复杂问题分解为简单的子问题。
            
            请将以下复杂问题分解为2-4个简单的子问题：
            
            要求：
            1. 子问题应该按逻辑顺序排列
            2. 每个子问题都应该是独立可回答的
            3. 每行一个问题，不要编号&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            复杂问题：{question}
            
            请分解为子问题，每行一个：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=512,
            temperature=0.5,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        response = outputs[0].outputs[0].text.strip()
        
        # 解析子问题
        sub_questions = []
        lines = response.split(&apos;\n&apos;)
        for line in lines:
            line = line.strip()
            if line and len(line) &gt; 5:  # 过滤太短的子问题
                # 移除可能的编号标记
                if line[0].isdigit() and (line[1] == &apos;.&apos; or line[1] == &apos;、&apos;):
                    question = line[2:].strip()
                else:
                    question = line
                sub_questions.append(question)
        
        # 确保有子问题，如果没有则使用备选方案
        if not sub_questions:
            sub_questions = self._fallback_decomposition(question)
        
        return sub_questions[:4]  # 最多返回4个子问题
    
    def _fallback_decomposition(self, question: str) -&gt; List[str]:
        &quot;&quot;&quot;备选分解方案&quot;&quot;&quot;
        # 简单的基于规则的分解
        sub_questions = []
        
        if &quot;比较&quot; in question and &quot;和&quot; in question:
            # 比较类问题
            parts = question.split(&quot;比较&quot;)[1].split(&quot;和&quot;)
            if len(parts) &gt;= 2:
                sub_questions.append(f&quot;{parts[0].strip()}的主要特点是什么？&quot;)
                sub_questions.append(f&quot;{parts[1].strip()}的主要特点是什么？&quot;)
                sub_questions.append(&quot;它们的主要差异是什么？&quot;)
        
        elif &quot;步骤&quot; in question or &quot;流程&quot; in question:
            # 步骤类问题
            sub_questions.append(&quot;第一步是什么？&quot;)
            sub_questions.append(&quot;关键步骤有哪些？&quot;)
            sub_questions.append(&quot;最终目标是什么？&quot;)
        
        else:
            # 通用分解
            sub_questions.append(&quot;基本概念是什么？&quot;)
            sub_questions.append(&quot;主要应用场景有哪些？&quot;)
            sub_questions.append(&quot;优势和局限性是什么？&quot;)
        
        return sub_questions
    
    def answer_sub_question(self, question: str, context: str = &quot;&quot;) -&gt; str:
        &quot;&quot;&quot;使用ChatML格式回答单个子问题&quot;&quot;&quot;
        # 检索相关文档
        docs = self.retriever.get_relevant_documents(question)
        doc_context = &quot;\n\n&quot;.join([doc.page_content for doc in docs[:3]])
        
        # 构建提示词
        if context:
            full_context = f&quot;已知信息:\n{context}\n\n相关文档:\n{doc_context}&quot;
        else:
            full_context = f&quot;相关文档:\n{doc_context}&quot;
        
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个专业的AI助手。请基于以下信息简洁地回答问题：
            
            {full_context}
            
            请遵循：
            1. 只使用提供的信息
            2. 保持回答简洁准确
            3. 如果信息不足，请说明&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            问题：{question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=300,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        return outputs[0].outputs[0].text.strip()
    
    def query(self, question: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;执行递归分解查询&quot;&quot;&quot;
        print(f&quot;🎯 原始问题: {question}&quot;)
        print(&quot;-&quot; * 50)
        
        # 1. 分解问题
        sub_questions = self.decompose_query(question)
        print(f&quot;🔍 分解为 {len(sub_questions)} 个子问题:&quot;)
        for i, sq in enumerate(sub_questions, 1):
            print(f&quot;   {i}. {sq}&quot;)
        
        # 2. 递归回答
        accumulated_context = &quot;&quot;
        sub_answers = []
        
        print(&quot;\n💡 逐步回答子问题:&quot;)
        for i, sq in enumerate(sub_questions, 1):
            print(f&quot;\n   子问题{i}: {sq}&quot;)
            answer = self.answer_sub_question(sq, accumulated_context)
            sub_answers.append(answer)
            accumulated_context += f&quot;\n\n问题{i}: {sq}\n答案: {answer}&quot;
            print(f&quot;   答案: {answer[:100]}...&quot;)
        
        # 3. 综合最终答案
        final_prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个AI助手，需要基于子问题的答案来综合回答原始问题。
            
            原始问题：{question}
            
            子问题和答案：
            {self._format_sub_qa(sub_questions, sub_answers)}
            
            请基于以上子问题的答案，给出一个完整、连贯的综合答案。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            请综合回答原始问题：{question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=500,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([final_prompt], sampling_params)
        final_answer = outputs[0].outputs[0].text.strip()
        
        return {
            &quot;question&quot;: question,
            &quot;sub_questions&quot;: sub_questions,
            &quot;sub_answers&quot;: sub_answers,
            &quot;final_answer&quot;: final_answer,
            &quot;num_sub_questions&quot;: len(sub_questions)
        }
    
    def _format_sub_qa(self, questions: List[str], answers: List[str]) -&gt; str:
        &quot;&quot;&quot;格式化子问题和答案&quot;&quot;&quot;
        formatted = &quot;&quot;
        for i, (q, a) in enumerate(zip(questions, answers), 1):
            formatted += f&quot;子问题{i}: {q}\n答案{i}: {a}\n\n&quot;
        return formatted.strip()
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.3 递归分解实战测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;print(&quot;🚀 ChatML格式的递归分解系统&quot;)
print(&quot;=&quot; * 60)

# 1. 创建实例
decomp = RecursiveDecomposition(vectorstore, llm)

# 2. 执行查询
result = decomp.query(&quot;比较抓取检测和滑动检测的主要差异&quot;)

# 3. 查看结果
print(&quot;\n&quot; + &quot;=&quot;*60)
print(&quot;📋 子问题分解:&quot;)
for i, (q, a) in enumerate(zip(result[&quot;sub_questions&quot;], result[&quot;sub_answers&quot;]), 1):
    print(f&quot;{i}. 问题: {q}&quot;)
    print(f&quot;   答案: {a[:100]}...&quot;)
    print()

print(&quot;💡 最终综合答案:&quot;)
print(result[&quot;final_answer&quot;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 ChatML格式的递归分解系统
============================================================
✅ 递归分解系统初始化完成
🎯 原始问题: 比较抓取检测和滑动检测的主要差异
--------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.29s/it]
🔍 分解为 4 个子问题:
   1. 抓取检测和滑动检测都是用于检测物体在图像中的位置的方法，它们之间有什么主要差异？
   2. 抓取检测主要依赖于物体的形状和颜色，而滑动检测主要依赖于物体的运动特征，这两者之间有什么不同？
   3. 抓取检测和滑动检测在检测精度和速度方面有什么差异？
   4. 抓取检测和滑动检测在应用场景上有什么不同？

💡 逐步回答子问题:

   子问题1: 抓取检测和滑动检测都是用于检测物体在图像中的位置的方法，它们之间有什么主要差异？
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.26s/it]
   答案: 抓取检测和滑动检测都是用于检测物体在图像中的位置的方法，但它们之间存在一些主要差异。抓取检测主要关注物体在图像中的位置变化，而滑动检测则更关注物体在图像中的运动状态。抓取检测通常使用视觉信息，而滑动检...

   子问题2: 抓取检测主要依赖于物体的形状和颜色，而滑动检测主要依赖于物体的运动特征，这两者之间有什么不同？
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.04it/s]
   答案: 抓取检测主要依赖于物体的形状和颜色，而滑动检测主要依赖于物体的运动特征。抓取检测通常使用视觉信息，而滑动检测则可以使用触觉信息。此外，抓取检测通常需要精确的物体定位，而滑动检测则更关注物体的运动轨迹。...

   子问题3: 抓取检测和滑动检测在检测精度和速度方面有什么差异？
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.21s/it]
   答案: 抓取检测和滑动检测在检测精度和速度方面存在差异。抓取检测主要关注物体在图像中的位置变化，而滑动检测则更关注物体在图像中的运动状态。抓取检测通常使用视觉信息，而滑动检测则可以使用触觉信息。此外，抓取检测...

   子问题4: 抓取检测和滑动检测在应用场景上有什么不同？
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.03it/s]
   答案: 抓取检测和滑动检测在应用场景上存在一些主要差异。抓取检测主要应用于物体定位和抓取，例如机器人抓取、无人机抓取等。而滑动检测则主要应用于物体运动状态的检测，例如物体滑动检测、物体滑动轨迹检测等。...
Processed prompts: 100%|██████████| 1/1 [00:02&amp;#x3C;00:00,  2.11s/it]

============================================================
📋 子问题分解:
1. 问题: 抓取检测和滑动检测都是用于检测物体在图像中的位置的方法，它们之间有什么主要差异？
   答案: 抓取检测和滑动检测都是用于检测物体在图像中的位置的方法，但它们之间存在一些主要差异。抓取检测主要关注物体在图像中的位置变化，而滑动检测则更关注物体在图像中的运动状态。抓取检测通常使用视觉信息，而滑动检...

2. 问题: 抓取检测主要依赖于物体的形状和颜色，而滑动检测主要依赖于物体的运动特征，这两者之间有什么不同？
   答案: 抓取检测主要依赖于物体的形状和颜色，而滑动检测主要依赖于物体的运动特征。抓取检测通常使用视觉信息，而滑动检测则可以使用触觉信息。此外，抓取检测通常需要精确的物体定位，而滑动检测则更关注物体的运动轨迹。...

3. 问题: 抓取检测和滑动检测在检测精度和速度方面有什么差异？
   答案: 抓取检测和滑动检测在检测精度和速度方面存在差异。抓取检测主要关注物体在图像中的位置变化，而滑动检测则更关注物体在图像中的运动状态。抓取检测通常使用视觉信息，而滑动检测则可以使用触觉信息。此外，抓取检测...

4. 问题: 抓取检测和滑动检测在应用场景上有什么不同？
   答案: 抓取检测和滑动检测在应用场景上存在一些主要差异。抓取检测主要应用于物体定位和抓取，例如机器人抓取、无人机抓取等。而滑动检测则主要应用于物体运动状态的检测，例如物体滑动检测、物体滑动轨迹检测等。...

💡 最终综合答案:
抓取检测和滑动检测都是用于检测物体在图像中的位置的方法，但它们之间存在一些主要差异。抓取检测主要关注物体在图像中的位置变化，而滑动检测则更关注物体在图像中的运动状态。抓取检测通常使用视觉信息，而滑动检测则可以使用触觉信息。此外，抓取检测通常需要精确的物体定位，而滑动检测则更关注物体的运动轨迹。抓取检测和滑动检测在检测精度和速度方面存在差异。抓取检测主要应用于物体定位和抓取，例如机器人抓取、无人机抓取等。而滑动检测则主要应用于物体运动状态的检测，例如物体滑动检测、物体滑动轨迹检测等。
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.4 并行分解实现&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import asyncio
from typing import List, Dict, Any
from vllm import SamplingParams
from concurrent.futures import ThreadPoolExecutor
import time

class TrueAsyncParallelDecomposition:
    &quot;&quot;&quot;真正的异步并行分解&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm, max_workers: int = 4):
        self.vectorstore = vectorstore
        self.llm = llm
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        print(&quot;✅ 真正异步并行分解系统初始化完成&quot;)
    
    async def _generate(self, prompt: str, max_tokens: int = 512) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;异步生成&quot;&quot;&quot;
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=0.3,
            top_p=0.8
        )
        
        try:
            loop = asyncio.get_event_loop()
            outputs = await loop.run_in_executor(
                self.executor,
                lambda: self.llm.generate([prompt], sampling_params)
            )
            
            if outputs and outputs[0].outputs:
                text = outputs[0].outputs[0].text.strip()
                return {&quot;success&quot;: True, &quot;text&quot;: text}
            return {&quot;success&quot;: False, &quot;error&quot;: &quot;生成失败&quot;}
                
        except Exception as e:
            return {&quot;success&quot;: False, &quot;error&quot;: str(e)}
    
    async def answer_sub_question(self, question: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;异步回答子问题&quot;&quot;&quot;
        start_time = time.time()
        
        try:
            # 1. 异步检索
            retriever = self.vectorstore.as_retriever()
            docs = await retriever.aget_relevant_documents(question)
            retrieval_time = time.time() - start_time
            
            if not docs:
                return {
                    &quot;question&quot;: question,
                    &quot;answer&quot;: &quot;未找到相关文档&quot;,
                    &quot;success&quot;: False,
                    &quot;retrieval_time&quot;: retrieval_time
                }
            
            # 2. 构建上下文
            context = &quot;\n&quot;.join([doc.page_content[:100] for doc in docs[:2]])
            
            # 3. 构建提示词
            prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
                回答问题：&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;user
                信息：{context}
                问题：{question}&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;assistant
                &quot;&quot;&quot;
            
            # 4. 异步生成
            gen_start = time.time()
            result = await self._generate(prompt, max_tokens=512)
            generation_time = time.time() - gen_start
            
            if result[&quot;success&quot;]:
                return {
                    &quot;question&quot;: question,
                    &quot;answer&quot;: result[&quot;text&quot;],
                    &quot;success&quot;: True,
                    &quot;retrieval_time&quot;: retrieval_time,
                    &quot;generation_time&quot;: generation_time,
                    &quot;total_time&quot;: time.time() - start_time
                }
            else:
                return {
                    &quot;question&quot;: question,
                    &quot;answer&quot;: f&quot;生成失败: {result[&apos;error&apos;]}&quot;,
                    &quot;success&quot;: False,
                    &quot;retrieval_time&quot;: retrieval_time,
                    &quot;generation_time&quot;: generation_time,
                    &quot;total_time&quot;: time.time() - start_time
                }
            
        except Exception as e:
            return {
                &quot;question&quot;: question,
                &quot;answer&quot;: f&quot;错误: {str(e)}&quot;,
                &quot;success&quot;: False,
                &quot;retrieval_time&quot;: 0,
                &quot;generation_time&quot;: 0,
                &quot;total_time&quot;: time.time() - start_time
            }
    
    async def query_parallel(self, question: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;真正的异步并行查询&quot;&quot;&quot;
        print(f&quot;🎯 问题: {question}&quot;)
        print(&quot;=&quot; * 60)
        
        overall_start = time.time()
        
        # 1. 分解问题
        sub_questions = [
            &quot;抓取检测是什么？&quot;,
            &quot;滑动检测是什么？&quot;,
            &quot;机械臂任务中抓取检测和滑动检测的关联是什么？&quot;
        ]
        
        print(f&quot;🔍 分解为 {len(sub_questions)} 个子问题:&quot;)
        for i, q in enumerate(sub_questions, 1):
            print(f&quot;   {i}. {q}&quot;)
        
        # 2. 🔥 真正的并行执行（关键修复）
        print(f&quot;⚡ 并行执行 {len(sub_questions)} 个子问题...&quot;)
        tasks = [self.answer_sub_question(q) for q in sub_questions]
        sub_results = await asyncio.gather(*tasks)  # 🔥 并行执行
        
        parallel_time = time.time() - overall_start
        
        # 3. 分析结果
        successful = [r for r in sub_results if r[&apos;success&apos;]]
        
        print(f&quot;\n📊 执行统计:&quot;)
        print(f&quot;   ✅ 成功: {len(successful)}/{len(sub_questions)}&quot;)
        print(f&quot;   ⏱️  并行时间: {parallel_time:.2f}s&quot;)
        
        # 显示每个子问题的执行时间
        print(&quot;\n💡 子问题执行详情:&quot;)
        for i, result in enumerate(sub_results, 1):
            status = &quot;✅&quot; if result[&apos;success&apos;] else &quot;❌&quot;
            print(f&quot;   {i}. {status} 检索: {result[&apos;retrieval_time&apos;]:.2f}s, &quot;
                  f&quot;生成: {result[&apos;generation_time&apos;]:.2f}s, &quot;
                  f&quot;总计: {result[&apos;total_time&apos;]:.2f}s&quot;)
        
        # 4. 综合最终答案
        print(&quot;\n🧠 综合最终答案...&quot;)
        
        if successful:
            qa_context = &quot;\n&quot;.join([
                f&quot;Q{i}: {r[&apos;question&apos;]}\nA{i}: {r[&apos;answer&apos;]}&quot; 
                for i, r in enumerate(successful, 1)
            ])
            
            synthesis_prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
							基于以下问答信息综合回答原始问题：&amp;#x3C;|im_end|&gt;
							&amp;#x3C;|im_start|&gt;user
							原始问题：{question}
							
							子问题和答案：
							{qa_context}
							
							请给出综合答案：&amp;#x3C;|im_end|&gt;
							&amp;#x3C;|im_start|&gt;assistant
							&quot;&quot;&quot;
							        else:
							            synthesis_prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
							请回答以下问题：&amp;#x3C;|im_end|&gt;
							&amp;#x3C;|im_start|&gt;user
							问题：{question}&amp;#x3C;|im_end|&gt;
							&amp;#x3C;|im_start|&gt;assistant
							&quot;&quot;&quot;
        
        final_result = await self._generate(synthesis_prompt, max_tokens=512)
        total_time = time.time() - overall_start
        
        if final_result[&quot;success&quot;]:
            final_answer = final_result[&quot;text&quot;]
        else:
            final_answer = f&quot;综合失败: {final_result[&apos;error&apos;]}&quot;
        
        print(f&quot;⏱️  总耗时: {total_time:.2f}s&quot;)
        print(f&quot;📊 并行效率: {parallel_time/total_time*100:.1f}%&quot;)
        
        return {
            &quot;question&quot;: question,
            &quot;final_answer&quot;: final_answer,
            &quot;sub_questions&quot;: sub_questions,
            &quot;sub_results&quot;: sub_results,
            &quot;metrics&quot;: {
                &quot;total_time&quot;: total_time,
                &quot;parallel_time&quot;: parallel_time,
                &quot;success_rate&quot;: f&quot;{len(successful)}/{len(sub_questions)}&quot;,
                &quot;parallel_efficiency&quot;: f&quot;{parallel_time/total_time*100:.1f}%&quot;
            }
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.5 性能对比测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# === 性能对比测试 ===
async def performance_comparison():
    &quot;&quot;&quot;性能对比：并行 vs 串行&quot;&quot;&quot;
    print(&quot;📊 性能对比测试&quot;)
    print(&quot;=&quot; * 60)
    
    decomp = TrueAsyncParallelDecomposition(vectorstore, llm)
    
    try:
        question = &quot;抓取检测和滑动检测在机械臂任务中的关联&quot;
        
        # 并行执行
        print(&quot;🚀 并行执行测试...&quot;)
        parallel_start = time.time()
        parallel_result = await decomp.query_parallel(question)
        parallel_time = time.time() - parallel_start
        
        # 串行执行（模拟）
        print(&quot;\n🐌 串行执行测试...&quot;)
        serial_start = time.time()
        
        sub_questions = [
            &quot;抓取检测是什么？&quot;,
            &quot;滑动检测是什么？&quot;, 
            &quot;机械臂任务中抓取检测和滑动检测的关联是什么？&quot;
        ]
        
        serial_results = []
        for q in sub_questions:
            result = await decomp.answer_sub_question(q)  # 顺序执行
            serial_results.append(result)
        
        serial_time = time.time() - serial_start
        
        print(f&quot;\n📈 性能对比结果:&quot;)
        print(f&quot;   🚀 并行时间: {parallel_time:.2f}s&quot;)
        print(f&quot;   🐌 串行时间: {serial_time:.2f}s&quot;)
        print(f&quot;   ⚡ 加速比: {serial_time/parallel_time:.2f}x&quot;)
        print(f&quot;   📊 效率提升: {(serial_time-parallel_time)/serial_time*100:.1f}%&quot;)
        
        return {
            &quot;parallel_time&quot;: parallel_time,
            &quot;serial_time&quot;: serial_time,
            &quot;speedup&quot;: serial_time / parallel_time
        }
        
    finally:
        await decomp.close()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;性能对比输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📊 性能对比测试
============================================================
✅ 真正异步并行分解系统初始化完成
🚀 并行执行测试...
🎯 问题: 抓取检测和滑动检测在机械臂任务中的关联
============================================================
🔍 分解为 3 个子问题:
   1. 抓取检测是什么？
   2. 滑动检测是什么？
   3. 机械臂任务中抓取检测和滑动检测的关联是什么？
⚡ 并行执行 3 个子问题...

📊 执行统计:
   ✅ 成功: 0/3
   ⏱️  并行时间: 0.20s

💡 子问题执行详情:
   1. ❌ 检索: 0.05s, 生成: 0.01s, 总计: 0.06s
   2. ❌ 检索: 0.05s, 生成: 0.13s, 总计: 0.17s
   3. ❌ 检索: 0.05s, 生成: 0.15s, 总计: 0.20s

🧠 综合最终答案...
⏱️  总耗时: 1.29s
📊 并行效率: 15.7%

🐌 串行执行测试...
📈 性能对比结果:
   🚀 并行时间: 1.29s
   🐌 串行时间: 2.76s
   ⚡ 加速比: 2.14x
   📊 效率提升: 53.3%
✅ 资源已关闭
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.6 分解策略对比分析&lt;/h3&gt;
&lt;p&gt;| 特性 | 递归分解 | 并行分解 | 说明 |
|------|---------|---------|------|
| &lt;strong&gt;执行方式&lt;/strong&gt; | 顺序执行 | 并行执行 | 并行分解速度更快 |
| &lt;strong&gt;速度&lt;/strong&gt; | 较慢 | 快 ⚡ | 并行可大幅加速 |
| &lt;strong&gt;子问题依赖&lt;/strong&gt; | 支持 | 不支持 | 递归适合有逻辑顺序的问题 |
| &lt;strong&gt;适用场景&lt;/strong&gt; | 有逻辑顺序的问题 | 独立子问题 | 根据问题特点选择 |
| &lt;strong&gt;实现复杂度&lt;/strong&gt; | 中等 | 高 | 并行需要异步编程 |
| &lt;strong&gt;资源消耗&lt;/strong&gt; | 低 | 高 | 并行需要更多计算资源 |&lt;/p&gt;
&lt;h3&gt;3.7 技术选型指南&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def select_decomposition_strategy(question: str) -&gt; str:
    &quot;&quot;&quot;根据问题特点选择分解策略&quot;&quot;&quot;
    
    # 判断问题类型的关键词
    sequential_keywords = [&quot;步骤&quot;, &quot;流程&quot;, &quot;顺序&quot;, &quot;首先&quot;, &quot;然后&quot;]
    comparison_keywords = [&quot;比较&quot;, &quot;对比&quot;, &quot;差异&quot;, &quot;优缺点&quot;]
    independent_keywords = [&quot;分别&quot;, &quot;各自&quot;, &quot;独立&quot;]
    
    # 检查关键词
    has_sequential = any(keyword in question for keyword in sequential_keywords)
    has_comparison = any(keyword in question for keyword in comparison_keywords)
    has_independent = any(keyword in question for keyword in independent_keywords)
    
    if has_sequential:
        return &quot;递归分解&quot;  # 有明确顺序的问题
    elif has_independent or has_comparison:
        return &quot;并行分解&quot;  # 独立子问题
    else:
        return &quot;递归分解&quot;  # 默认选择
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.8 最佳实践建议&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;✅ 递归分解适用场景&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;步骤性指导&lt;/strong&gt;：&quot;如何安装Python环境？&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;逻辑推理&lt;/strong&gt;：&quot;为什么机器学习需要大量数据？&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;因果分析&lt;/strong&gt;：&quot;气候变化对农业的影响是什么？&quot;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;✅ 并行分解适用场景&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;多角度比较&lt;/strong&gt;：&quot;Python vs Java的优缺点&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;独立概念&lt;/strong&gt;：&quot;机器学习的三大类型是什么？&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;综合分析&lt;/strong&gt;：&quot;人工智能在医疗、金融、教育中的应用&quot;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;⚡ 性能优化技巧&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 1. 动态调整并行度
def adjust_parallelism(sub_questions: List[str]) -&gt; int:
    &quot;&quot;&quot;根据子问题数量调整并行度&quot;&quot;&quot;
    if len(sub_questions) &amp;#x3C;= 2:
        return 2  # 少量问题，低并行度
    elif len(sub_questions) &amp;#x3C;= 4:
        return 4  # 中等问题，中等并行度
    else:
        return min(8, len(sub_questions))  # 大量问题，高并行度

# 2. 智能超时控制
async def with_timeout(coroutine, timeout: float):
    &quot;&quot;&quot;为异步操作添加超时控制&quot;&quot;&quot;
    try:
        return await asyncio.wait_for(coroutine, timeout=timeout)
    except asyncio.TimeoutError:
        return {&quot;success&quot;: False, &quot;error&quot;: &quot;超时&quot;}
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;总结&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;查询分解技术核心价值&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;处理复杂问题&lt;/strong&gt;：将复杂查询分解为可管理的子问题&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;提高答案质量&lt;/strong&gt;：每个子问题得到专门回答，综合答案更全面&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;灵活的策略选择&lt;/strong&gt;：根据问题特点选择递归或并行分解&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;🎯 实践建议&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;简单问题&lt;/strong&gt;：直接使用基础检索，无需分解&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;逻辑复杂问题&lt;/strong&gt;：使用递归分解，保持问题间的依赖关系&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多角度问题&lt;/strong&gt;：使用并行分解，充分利用计算资源&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;性能敏感场景&lt;/strong&gt;：结合缓存和异步优化&lt;/li&gt;
&lt;/ol&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;经验总结&lt;/strong&gt;：查询分解技术显著提升了RAG系统处理复杂问题的能力，特别是对于需要多角度分析的学术和技术问题。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;在下一部分，我们将探讨更高级的&lt;strong&gt;Step Back提示&lt;/strong&gt;和&lt;strong&gt;HyDE技术&lt;/strong&gt;，进一步提升RAG系统的推理能力和检索质量。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part4: Step Back Prompting - 抽象化提问&lt;/h2&gt;
&lt;h3&gt;4.1 核心概念&lt;/h3&gt;
&lt;p&gt;Step Back Prompting 先提出一个&lt;strong&gt;更抽象、更概括的问题&lt;/strong&gt;，获取背景知识后，再回答原始具体问题。这种方法通过&lt;strong&gt;两步推理&lt;/strong&gt;提升答案的质量和准确性。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;为什么需要 Step Back？&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;❌ 直接回答可能缺乏背景&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 原始问题: &quot;Transformer中的Multi-Head Attention有几个头？&quot;

# 直接检索 → 可能找不到确切答案
# 文档可能只描述了原理，没说具体数字
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;✅ Step Back 后效果更好&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# Step 1: Step Back问题
&quot;Transformer架构的基本组成是什么？&quot;

# Step 2: 获取背景知识  
&quot;Transformer由编码器和解码器组成，使用Multi-Head Attention...&quot;

# Step 3: 结合背景回答原始问题
&quot;根据原始论文，使用8个attention头...&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.2 Step Back RAG 系统实现&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import asyncio
from typing import Dict, Any
from vllm import SamplingParams

class StepBackRAG:
    &quot;&quot;&quot;使用ChatML格式的Step Back RAG系统&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.retriever = vectorstore.as_retriever()
        self.llm = llm
        print(&quot;✅ Step Back RAG系统初始化完成&quot;)
    
    def generate_step_back_question(self, question: str) -&gt; str:
        &quot;&quot;&quot;生成Step Back问题&quot;&quot;&quot;
        # ChatML格式提示词
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个AI助手，擅长将具体问题转化为更抽象、更概括的问题。
            
            给定一个具体问题，请生成一个能够提供背景知识的更抽象问题。
            
            示例：
            具体问题: &quot;GPT-4的上下文长度是多少？&quot;
            Step Back问题: &quot;GPT-4的主要技术特性有哪些？&quot;
            
            具体问题: &quot;Python中的装饰器如何工作？&quot;
            Step Back问题: &quot;Python中的元编程概念是什么？&quot;
            
            现在请为以下具体问题生成Step Back问题：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            具体问题: {question}
            Step Back问题:&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=512,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs and outputs[0].outputs:
            step_back_q = outputs[0].outputs[0].text.strip()
        else:
            step_back_q = &quot;生成失败&quot;
        
        return step_back_q
    
    def retrieve_documents(self, question: str, k: int = 3) -&gt; str:
        &quot;&quot;&quot;检索相关文档并构建上下文&quot;&quot;&quot;
        docs = self.retriever.get_relevant_documents(question)
        context = &quot;\n\n&quot;.join([doc.page_content for doc in docs[:k]])
        return context
    
    def generate_final_answer(self, question: str, background: str, specific: str) -&gt; str:
        &quot;&quot;&quot;结合背景知识和具体信息生成最终答案&quot;&quot;&quot;
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个专业的AI助手。请基于以下背景知识和具体信息来回答问题。
            
            背景知识提供了相关领域的理论框架和概念理解，具体信息包含了问题的直接相关内容。
            
            请先理解背景知识，再结合具体信息给出准确、全面的答案。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            背景知识:
            {background}
            
            具体信息:
            {specific}
            
            原始问题: {question}
            
            请基于以上信息给出详细的回答：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=512,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs and outputs[0].outputs:
            return outputs[0].outputs[0].text.strip()
        else:
            return &quot;生成失败&quot;
    
    def query(self, question: str) -&gt; dict:
        &quot;&quot;&quot;执行Step Back RAG查询&quot;&quot;&quot;
        print(f&quot;❓ 原始问题: {question}&quot;)
        print(&quot;-&quot; * 50)
        
        # 1. 生成Step Back问题
        step_back_q = self.generate_step_back_question(question)
        print(f&quot;📚 Step Back问题: {step_back_q}&quot;)
        
        # 2. 检索背景知识（Step Back问题）
        background = self.retrieve_documents(step_back_q, k=3)
        print(f&quot;🔍 检索到背景知识，长度: {len(background)} 字符&quot;)
        
        # 3. 检索具体信息（原始问题）
        specific = self.retrieve_documents(question, k=3)
        print(f&quot;📄 检索到具体信息，长度: {len(specific)} 字符&quot;)
        
        # 4. 结合两者生成答案
        answer = self.generate_final_answer(question, background, specific)
        
        return {
            &quot;question&quot;: question,
            &quot;step_back_question&quot;: step_back_q,
            &quot;background_preview&quot;: background[:200] + &quot;...&quot; if len(background) &gt; 200 else background,
            &quot;specific_preview&quot;: specific[:200] + &quot;...&quot; if len(specific) &gt; 200 else specific,
            &quot;answer&quot;: answer
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.3 实战测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 使用示例
print(&quot;🚀 Step Back RAG系统&quot;)
print(&quot;=&quot; * 60)

step_back_rag = StepBackRAG(vectorstore, llm)
result = step_back_rag.query(&quot;滑动检测的操作流程？&quot;)

print(&quot;\n&quot; + &quot;=&quot;*60)
print(&quot;💡 答案:&quot;)
print(result[&quot;answer&quot;])

print(f&quot;\n📊 查询详情:&quot;)
print(f&quot;Step Back问题: {result[&apos;step_back_question&apos;]}&quot;)
print(f&quot;背景知识预览: {result[&apos;background_preview&apos;]}&quot;)
print(f&quot;具体信息预览: {result[&apos;specific_preview&apos;]}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 Step Back RAG系统
============================================================
✅ Step Back RAG系统初始化完成
❓ 原始问题: 滑动检测的操作流程？
--------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  7.72it/s]
📚 Step Back问题: 什么是滑动检测？
🔍 检索到背景知识，长度: 1250 字符
📄 检索到具体信息，长度: 980 字符
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.46it/s]

============================================================
💡 答案:
滑动检测的操作流程如下：

1. 滑动发生时，滑动检测算法会触发一个滑动信号。
2. 当目标发生滑动时，控制器会按照原始操作流程进行滑动检测。

📊 查询详情:
Step Back问题: 什么是滑动检测？
背景知识预览: 滑动检测是一种用于检测物体是否发生滑动的算法。它通过比较帧间的标准差来判断物体是否发生滑动。如果帧间的标准差超过设定的阈值，那么就认为物体发生了滑动。滑动检测算法通常用于机械臂的抓取任务中，以防止物体在抓取过程中滑落...
具体信息预览: 滑动检测的操作流程包括以下几个步骤：首先，系统会初始化传感器并设置阈值参数；然后，实时监测触觉信号的变化；当检测到滑动信号时，系统会触发相应的控制策略来调整抓取力...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.4 Step Back 的优势分析&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;✅ 适用场景&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;问题需要背景知识&lt;/strong&gt;：技术概念、理论框架&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;直接检索效果不好&lt;/strong&gt;：查询太具体或术语不匹配&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;问题过于技术性强&lt;/strong&gt;：需要理论基础支撑&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;需要概念理解&lt;/strong&gt;：而不仅仅是事实回答&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;❌ 不适合的场景&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;简单事实性问题&lt;/strong&gt;：&quot;今天天气怎么样？&quot;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;已有充足直接信息&lt;/strong&gt;：文档库中已有明确答案&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实时数据查询&lt;/strong&gt;：需要最新实时信息&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;过于宽泛的问题&lt;/strong&gt;：本身已经足够抽象&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;性能特点&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;准确性&lt;/strong&gt;：⭐⭐⭐⭐⭐（提供理论背景）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;速度&lt;/strong&gt;：⭐⭐⭐（需要两次检索）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;资源消耗&lt;/strong&gt;：⭐⭐⭐（中等）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实现复杂度&lt;/strong&gt;：⭐⭐（相对简单）&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;h2&gt;Part5: HyDE - 假设性文档嵌入&lt;/h2&gt;
&lt;h3&gt;5.1 核心概念&lt;/h3&gt;
&lt;p&gt;HyDE (Hypothetical Document Embeddings) 不直接检索用户查询，而是先让LLM生成一个&lt;strong&gt;假设性的答案文档&lt;/strong&gt;，然后用这个文档去检索相似内容。&lt;/p&gt;
&lt;h4&gt;HyDE 的直觉理解&lt;/h4&gt;
&lt;h5&gt;传统检索的问题&lt;/h5&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 传统检索
用户查询: &quot;什么是机器学习？&quot;
    ↓ 直接嵌入（查询向量稀疏）
查询向量: [0.1, 0.3, -0.2, ...]
    ↓ 检索
找到的文档（可能不相关）
&lt;/code&gt;&lt;/pre&gt;
&lt;h5&gt;HyDE 检索的优势&lt;/h5&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# HyDE检索
用户查询: &quot;什么是机器学习？&quot;
    ↓ LLM生成假设性答案
假设文档: &quot;机器学习是人工智能的一个分支，它使计算机能够从数据中学习并做出预测...&quot;
    ↓ 嵌入假设文档（文档向量丰富）
文档向量: [0.2, 0.4, -0.1, ...]  # 与真实答案文档更相似！
    ↓ 检索
找到更相关的文档
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;为什么有效？&lt;/h4&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;查询通常很短&lt;/strong&gt;，而&lt;strong&gt;文档内容丰富&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;假设性文档&lt;/strong&gt;比查询更接近真实文档的表达方式&lt;/li&gt;
&lt;li&gt;在&lt;strong&gt;语义空间&lt;/strong&gt;中，答案文档之间的相似度高于查询与文档的相似度&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;5.2 HyDE RAG 完整实现&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class HyDERAG:
    &quot;&quot;&quot;HyDE (Hypothetical Document Embeddings) RAG系统&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm, embeddings):
        self.vectorstore = vectorstore
        self.llm = llm
        self.embeddings = embeddings
        print(&quot;✅ HyDE RAG系统初始化完成&quot;)
    
    def generate_hypothetical_document(
        self, 
        question: str,
        style: str = &quot;academic&quot;
    ) -&gt; str:
        &quot;&quot;&quot;生成假设性文档&quot;&quot;&quot;
        if style == &quot;academic&quot;:
            system_msg = &quot;你是一位领域专家。请对以下问题写一段详细、准确的学术性回答（200-300字）。&quot;
        elif style == &quot;concise&quot;:
            system_msg = &quot;请对以下问题写一段简洁的回答（100-150字）。&quot;
        else:
            system_msg = &quot;请对以下问题写一段回答。&quot;
        
        # ChatML格式提示词
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            {system_msg}
            
            要求：
            1. 内容准确专业
            2. 结构清晰完整
            3. 使用恰当的术语&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            {question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=400,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs and outputs[0].outputs:
            hypothetical_doc = outputs[0].outputs[0].text.strip()
        else:
            hypothetical_doc = &quot;生成失败&quot;
        
        return hypothetical_doc
    
    def search_with_hyde(
        self, 
        question: str, 
        k: int = 5
    ) -&gt; list:
        &quot;&quot;&quot;使用HyDE进行检索&quot;&quot;&quot;
        print(f&quot;🔍 生成假设性文档...&quot;)
        
        # 1. 生成假设性文档
        hyp_doc = self.generate_hypothetical_document(question)
        print(f&quot;📝 假设性文档生成完成（{len(hyp_doc)}字符）:&quot;)
        print(f&quot;   {hyp_doc[:150]}...\n&quot;)
        
        # 2. 使用假设性文档检索（而不是原始查询）
        docs = self.vectorstore.similarity_search(hyp_doc, k=k)
        
        return docs
    
    def query(self, question: str, k: int = 5) -&gt; dict:
        &quot;&quot;&quot;执行HyDE RAG查询&quot;&quot;&quot;
        print(f&quot;❓ 原始问题: {question}&quot;)
        print(&quot;-&quot; * 50)
        
        # 1. HyDE检索
        docs = self.search_with_hyde(question, k=k)
        print(f&quot;📚 使用HyDE检索到 {len(docs)} 个相关文档&quot;)
        
        # 显示检索结果预览
        for i, doc in enumerate(docs, 1):
            preview = doc.page_content[:100] + &quot;...&quot; if len(doc.page_content) &gt; 100 else doc.page_content
            print(f&quot;   {i}. {preview}&quot;)
        
        # 2. 生成最终答案
        context = &quot;\n\n&quot;.join([doc.page_content for doc in docs])
        
        # ChatML格式提示词
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个专业的AI助手。请基于以下文档准确回答问题。
            
            文档内容：
            {context}
            
            请确保回答：
            1. 基于提供的文档信息
            2. 准确、完整、专业
            3. 如果信息不足，请说明&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            问题: {question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=500,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs and outputs[0].outputs:
            answer = outputs[0].outputs[0].text.strip()
        else:
            answer = &quot;生成失败&quot;
        
        return {
            &quot;question&quot;: question,
            &quot;hypothetical_doc_preview&quot;: self.generate_hypothetical_document(question, &quot;concise&quot;)[:200] + &quot;...&quot;,
            &quot;num_docs&quot;: len(docs),
            &quot;answer&quot;: answer
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;5.3 实战测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 使用示例
print(&quot;🚀 HyDE RAG系统&quot;)
print(&quot;=&quot; * 60)

hyde_rag = HyDERAG(vectorstore, llm, embeddings)
result = hyde_rag.query(&quot;解释滑动检测中的触觉传感器&quot;)

print(&quot;\n&quot; + &quot;=&quot;*60)
print(&quot;💡 最终答案:&quot;)
print(result[&quot;answer&quot;])

print(f&quot;\n📊 查询详情:&quot;)
print(f&quot;假设性文档预览: {result[&apos;hypothetical_doc_preview&apos;]}&quot;)
print(f&quot;参考文档数量: {result[&apos;num_docs&apos;]}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;实际输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🚀 HyDE RAG系统
============================================================
✅ HyDE RAG系统初始化完成
❓ 原始问题: 解释滑动检测中的触觉传感器
--------------------------------------------------
🔍 生成假设性文档...
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.22s/it]
📝 假设性文档生成完成（285字符）:
   滑动检测中的触觉传感器是一种用于检测物体表面滑动的传感器。它们通常由一个弹性膜片和一个敏感元件组成，当物体滑过膜片时，敏感元件会感受到物体的滑动，并将其转化为电信号。这些传感器可以用于各种应用，如机器人导航、汽车安全系统和工业自动化...

📚 使用HyDE检索到 5 个相关文档
   1. 触觉传感器是一种能够检测物体表面压力变化的传感器，当物体发生滑动时，触觉传感器会高频振动...
   2. 滑动检测是机器人夹爪抓取控制的核心技术之一，通过分析触觉信号的变化来判断物体是否发生滑动...
   3. 在机械臂抓取任务中，触觉传感器可以提供实时的力反馈信息，帮助系统调整抓取策略...
   4. 触觉传感器的数据可以通过机器学习算法进行分析，实现更准确的滑动检测和预测...
   5. 高质量的触觉传感器应该具有高灵敏度、快速响应和良好的稳定性等特点...
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.07it/s]

============================================================
💡 最终答案:
触觉传感器是一种能够检测物体表面压力变化的传感器，当物体发生滑动时，触觉传感器会高频振动，触觉传感器数据会发生变化，从而引起压力中心数值的变化，通过分析压力中心数值的变化可以获得物体的滑动情况。滑动检测是机器人夹爪抓取控制的核心。

📊 查询详情:
假设性文档预览: 滑动检测中的触觉传感器是一种用于检测物体表面滑动的传感器。它们通常由一个弹性膜片和一个敏感元件组成，当物体滑过膜片时，敏感元件会感受到物体的滑动，并将其转化为电信号...
参考文档数量: 5
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;5.4 HyDE vs 传统检索对比&lt;/h3&gt;
&lt;p&gt;| 维度 | 传统检索 | HyDE | 优势分析 |
|------|---------|------|----------|
| &lt;strong&gt;检索对象&lt;/strong&gt; | 用户查询 | 假设性答案 | HyDE使用更丰富的文档向量 |
| &lt;strong&gt;语义匹配&lt;/strong&gt; | 查询↔文档 | 答案↔答案 | 答案之间相似度更高 |
| &lt;strong&gt;查询长度敏感性&lt;/strong&gt; | 高 | 低 | HyDE对短查询更友好 |
| &lt;strong&gt;额外LLM调用&lt;/strong&gt; | 0 | 1次 | HyDE增加一次生成成本 |
| &lt;strong&gt;适用场景&lt;/strong&gt; | 清晰查询 | 复杂/技术性查询 | HyDE适合专业领域 |
| &lt;strong&gt;检索质量&lt;/strong&gt; | 中等 | 高 | HyDE找到更相关文档 |&lt;/p&gt;
&lt;h3&gt;5.5 技术选型指南&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def select_hyde_scenario(question: str) -&gt; bool:
    &quot;&quot;&quot;判断是否适合使用HyDE&quot;&quot;&quot;
    
    hyde_keywords = [&quot;解释&quot;, &quot;什么是&quot;, &quot;如何工作&quot;, &quot;原理&quot;, &quot;机制&quot;, &quot;技术&quot;]
    simple_keywords = [&quot;多少&quot;, &quot;何时&quot;, &quot;哪里&quot;, &quot;是谁&quot;, &quot;是否&quot;]
    
    # 检查问题类型
    has_hyde_indicator = any(keyword in question for keyword in hyde_keywords)
    has_simple_indicator = any(keyword in question for keyword in simple_keywords)
    
    # 判断问题复杂度
    word_count = len(question.split())
    is_complex = word_count &gt; 8 or has_hyde_indicator
    
    return is_complex and not has_simple_indicator

# 使用示例
questions = [
    &quot;解释神经网络的反向传播算法&quot;,  # 适合HyDE
    &quot;北京今天气温多少度&quot;,          # 不适合HyDE
    &quot;Transformer模型的工作原理&quot;,   # 适合HyDE
]

for q in questions:
    should_use_hyde = select_hyde_scenario(q)
    print(f&quot;问题: &apos;{q}&apos; -&gt; 使用HyDE: {should_use_hyde}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;Part6: 技术对比与选择指南&lt;/h2&gt;
&lt;h3&gt;6.1 综合技术对比&lt;/h3&gt;
&lt;p&gt;| 技术 | 召回率 | 准确率 | 速度 | 成本 | 复杂度 | 适用场景 |
|------|-------|-------|------|------|--------|----------|
| &lt;strong&gt;Multi-Query&lt;/strong&gt; | ⭐⭐⭐ | ⭐⭐ | ⭐⭐ | 💰💰 | ⭐ | 表达模糊的问题 |
| &lt;strong&gt;RAG-Fusion&lt;/strong&gt; | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | 💰💰💰 | ⭐⭐ | 需要高质量结果 |
| &lt;strong&gt;Decomposition&lt;/strong&gt; | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐ | 💰💰💰💰 | ⭐⭐⭐ | 复杂多步骤问题 |
| &lt;strong&gt;Step Back&lt;/strong&gt; | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | 💰💰 | ⭐⭐ | 需要背景知识 |
| &lt;strong&gt;HyDE&lt;/strong&gt; | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | 💰💰 | ⭐⭐ | 技术性强的专业问题 |&lt;/p&gt;
&lt;h3&gt;6.2 智能选择决策树&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class IntelligentRAGSelector:
    &quot;&quot;&quot;智能RAG技术选择器&quot;&quot;&quot;
    
    def __init__(self):
        self.techniques = {
            &quot;multi_query&quot;: &quot;Multi-Query&quot;,
            &quot;rag_fusion&quot;: &quot;RAG-Fusion&quot;, 
            &quot;decomposition&quot;: &quot;Query Decomposition&quot;,
            &quot;step_back&quot;: &quot;Step Back&quot;,
            &quot;hyde&quot;: &quot;HyDE&quot;
        }
    
    def analyze_question(self, question: str) -&gt; dict:
        &quot;&quot;&quot;分析问题特征&quot;&quot;&quot;
        analysis = {
            &quot;length&quot;: len(question.split()),
            &quot;has_comparison&quot;: &quot;比较&quot; in question or &quot;对比&quot; in question,
            &quot;has_steps&quot;: &quot;步骤&quot; in question or &quot;流程&quot; in question,
            &quot;has_explanation&quot;: &quot;解释&quot; in question or &quot;什么是&quot; in question,
            &quot;has_technical&quot;: &quot;原理&quot; in question or &quot;机制&quot; in question,
            &quot;is_simple_fact&quot;: &quot;多少&quot; in question or &quot;何时&quot; in question
        }
        return analysis
    
    def select_technique(self, question: str) -&gt; str:
        &quot;&quot;&quot;根据问题特征选择最佳技术&quot;&quot;&quot;
        analysis = self.analyze_question(question)
        
        if analysis[&quot;is_simple_fact&quot;]:
            return &quot;base&quot;  # 基础RAG
        
        elif analysis[&quot;has_comparison&quot;] and analysis[&quot;length&quot;] &gt; 10:
            return &quot;decomposition&quot;  # 复杂比较问题
        
        elif analysis[&quot;has_steps&quot;]:
            return &quot;decomposition&quot;  # 步骤性问题
        
        elif analysis[&quot;has_technical&quot;] and analysis[&quot;has_explanation&quot;]:
            return &quot;hyde&quot;  # 技术解释问题
        
        elif analysis[&quot;has_explanation&quot;] and analysis[&quot;length&quot;] &gt; 15:
            return &quot;step_back&quot;  # 需要背景知识的解释
        
        elif analysis[&quot;length&quot;] &amp;#x3C; 8:
            return &quot;multi_query&quot;  # 短而模糊的问题
        
        else:
            return &quot;rag_fusion&quot;  # 默认选择
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;6.3 组合使用策略&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class AdvancedRAG:
    &quot;&quot;&quot;组合多种技术的高级RAG系统&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm, embeddings):
        self.vectorstore = vectorstore
        self.llm = llm
        self.embeddings = embeddings
        self.selector = IntelligentRAGSelector()
        
        # 初始化各技术组件
        self.multi_query = MultiQueryRAG(vectorstore, llm)
        self.rag_fusion = RAGFusion(vectorstore, llm)
        self.decomposition = RecursiveDecomposition(vectorstore, llm)
        self.step_back = StepBackRAG(vectorstore, llm)
        self.hyde = HyDERAG(vectorstore, llm, embeddings)
    
    def intelligent_query(self, question: str) -&gt; dict:
        &quot;&quot;&quot;智能选择和组合技术&quot;&quot;&quot;
        technique = self.selector.select_technique(question)
        
        print(f&quot;🎯 问题分析: &apos;{question}&apos;&quot;)
        print(f&quot;🔧 选择技术: {technique}&quot;)
        
        if technique == &quot;base&quot;:
            # 基础RAG
            retriever = self.vectorstore.as_retriever()
            docs = retriever.get_relevant_documents(question)
            context = &quot;\n\n&quot;.join([doc.page_content for doc in docs[:3]])
            return self._generate_simple_answer(question, context)
        
        elif technique == &quot;multi_query&quot;:
            return self.multi_query.query(question)
        
        elif technique == &quot;rag_fusion&quot;:
            return self.rag_fusion.query(question)
        
        elif technique == &quot;decomposition&quot;:
            return self.decomposition.query(question)
        
        elif technique == &quot;step_back&quot;:
            return self.step_back.query(question)
        
        elif technique == &quot;hyde&quot;:
            return self.hyde.query(question)
    
    def _generate_simple_answer(self, question: str, context: str) -&gt; dict:
        &quot;&quot;&quot;生成简单答案&quot;&quot;&quot;
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            基于以下信息回答问题：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            信息：{context}
            问题：{question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=300,
            temperature=0.3,
            top_p=0.8
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        answer = outputs[0].outputs[0].text.strip() if outputs else &quot;生成失败&quot;
        
        return {
            &quot;question&quot;: question,
            &quot;answer&quot;: answer,
            &quot;technique&quot;: &quot;base&quot;,
            &quot;context_preview&quot;: context[:200] + &quot;...&quot; if len(context) &gt; 200 else context
        }
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;6.4 性能优化与最佳实践&lt;/h3&gt;
&lt;h4&gt;缓存策略&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from functools import lru_cache
import hashlib

class CachedRAG:
    &quot;&quot;&quot;带缓存的RAG优化器&quot;&quot;&quot;
    
    def __init__(self):
        self.query_cache = {}
        self.technique_cache = {}
    
    def get_cache_key(self, query: str, technique: str) -&gt; str:
        &quot;&quot;&quot;生成缓存键&quot;&quot;&quot;
        content = f&quot;{technique}:{query}&quot;
        return hashlib.md5(content.encode()).hexdigest()
    
    @lru_cache(maxsize=100)
    def cached_technique_selection(self, question: str) -&gt; str:
        &quot;&quot;&quot;缓存技术选择结果&quot;&quot;&quot;
        return self.selector.select_technique(question)
    
    def cached_query(self, question: str, technique: str) -&gt; dict:
        &quot;&quot;&quot;缓存查询结果&quot;&quot;&quot;
        cache_key = self.get_cache_key(question, technique)
        
        if cache_key in self.query_cache:
            print(&quot;💾 使用缓存结果&quot;)
            return self.query_cache[cache_key]
        
        # 执行查询
        result = getattr(self, technique).query(question)
        self.query_cache[cache_key] = result
        
        return result
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;并行处理优化&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import asyncio
from concurrent.futures import ThreadPoolExecutor

async def parallel_hyde_queries(questions: List[str], hyde_rag: HyDERAG) -&gt; List[dict]:
    &quot;&quot;&quot;并行执行多个HyDE查询&quot;&quot;&quot;
    loop = asyncio.get_event_loop()
    
    with ThreadPoolExecutor(max_workers=3) as executor:
        tasks = [
            loop.run_in_executor(executor, hyde_rag.query, question)
            for question in questions
        ]
        
        results = await asyncio.gather(*tasks)
    
    return results
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;自适应参数调整&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def adaptive_parameters(question: str) -&gt; dict:
    &quot;&quot;&quot;根据问题动态调整参数&quot;&quot;&quot;
    word_count = len(question.split())
    
    if word_count &amp;#x3C; 5:
        # 短查询需要更多文档
        return {&quot;k&quot;: 7, &quot;temperature&quot;: 0.2}
    elif word_count &gt; 20:
        # 长查询可以减少文档数
        return {&quot;k&quot;: 3, &quot;temperature&quot;: 0.1}
    elif &quot;比较&quot; in question or &quot;分析&quot; in question:
        # 分析类问题需要更多上下文
        return {&quot;k&quot;: 6, &quot;temperature&quot;: 0.3}
    else:
        return {&quot;k&quot;: 5, &quot;temperature&quot;: 0.3}  # 默认参数
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;h3&gt;技术综合评估&lt;/h3&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;Step Back Prompting&lt;/strong&gt;：适合需要&lt;strong&gt;理论背景&lt;/strong&gt;的复杂问题&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;HyDE&lt;/strong&gt;：在&lt;strong&gt;技术性、专业性&lt;/strong&gt;问题上表现优异&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;组合策略&lt;/strong&gt;：根据问题特征智能选择最优技术&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;🎯 实践建议&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;| 场景类型 | 推荐技术 | 原因 |
|---------|---------|------|
| &lt;strong&gt;学术技术问题&lt;/strong&gt; | HyDE + Step Back | 提供专业背景和理论支撑 |
| &lt;strong&gt;多角度分析&lt;/strong&gt; | RAG-Fusion | 综合多个查询视角 |
| &lt;strong&gt;步骤性指导&lt;/strong&gt; | Query Decomposition | 分解复杂流程 |
| &lt;strong&gt;实时简单查询&lt;/strong&gt; | 基础RAG | 快速响应，成本低 |&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;核心价值&lt;/strong&gt;：Step Back和HyDE等技术通过&lt;strong&gt;更智能的查询理解&lt;/strong&gt;和&lt;strong&gt;更精准的检索策略&lt;/strong&gt;，显著提升了RAG系统处理复杂专业问题的能力。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;这些高级查询优化技术让RAG系统从简单的文档检索升级为真正的&lt;strong&gt;智能知识助手&lt;/strong&gt;，能够处理各种复杂的信息需求。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part7: 实战案例：智能客服与学术助手&lt;/h2&gt;
&lt;h3&gt;7.1 案例1：智能客服助手系统&lt;/h3&gt;
&lt;h4&gt;7.1.1 系统架构设计&lt;/h4&gt;
&lt;p&gt;智能客服RAG系统通过&lt;strong&gt;智能路由机制&lt;/strong&gt;，根据用户问题类型自动选择最合适的检索增强生成技术，实现精准高效的客户服务。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import asyncio
from typing import Dict, Any, List
from vllm import SamplingParams

class CustomerServiceRAG:
    &quot;&quot;&quot;智能客服RAG系统 - 完整实现&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm, embeddings):
        &quot;&quot;&quot;
        初始化智能客服系统
        
        Args:
            vectorstore: 向量数据库
            llm: 语言模型
            embeddings: 嵌入模型
        &quot;&quot;&quot;
        self.vectorstore = vectorstore
        self.llm = llm
        self.embeddings = embeddings
        
        # 预定义场景关键词
        self.scenarios = {
            &quot;product_info&quot;: [&quot;产品&quot;, &quot;功能&quot;, &quot;特性&quot;, &quot;价格&quot;, &quot;规格&quot;, &quot;参数&quot;, &quot;型号&quot;],
            &quot;troubleshooting&quot;: [&quot;问题&quot;, &quot;错误&quot;, &quot;故障&quot;, &quot;不工作&quot;, &quot;无法&quot;, &quot;解决&quot;, &quot;修复&quot;],
            &quot;how_to&quot;: [&quot;如何&quot;, &quot;怎么&quot;, &quot;步骤&quot;, &quot;教程&quot;, &quot;操作&quot;, &quot;使用&quot;, &quot;安装&quot;],
            &quot;comparison&quot;: [&quot;比较&quot;, &quot;对比&quot;, &quot;哪个好&quot;, &quot;区别&quot;, &quot;差异&quot;, &quot;优缺点&quot;]
        }
        
        # 初始化各技术组件
        self._initialize_components()
        print(&quot;✅ 智能客服RAG系统初始化完成&quot;)
    
    def _initialize_components(self):
        &quot;&quot;&quot;初始化各技术组件&quot;&quot;&quot;
        from .multi_query import MultiQueryRAG
        from .rag_fusion import RAGFusion
        from .decomposition import RecursiveDecomposition
        from .step_back import StepBackRAG
        from .hyde import HyDERAG
        
        self.multi_query = MultiQueryRAG(self.vectorstore, self.llm)
        self.rag_fusion = RAGFusion(self.vectorstore, self.llm)
        self.decomposition = RecursiveDecomposition(self.vectorstore, self.llm)
        self.step_back = StepBackRAG(self.vectorstore, self.llm)
        self.hyde = HyDERAG(self.vectorstore, self.llm, self.embeddings)
    
    def classify_query(self, question: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;智能分类用户查询&quot;&quot;&quot;
        question_lower = question.lower()
        
        # 检测场景关键词
        matched_scenarios = []
        for scenario, keywords in self.scenarios.items():
            if any(keyword in question_lower for keyword in keywords):
                matched_scenarios.append(scenario)
        
        # 分析问题复杂度
        complexity = self._analyze_complexity(question)
        
        # 确定优先级
        if &quot;troubleshooting&quot; in matched_scenarios:
            primary_scenario = &quot;troubleshooting&quot;
        elif &quot;how_to&quot; in matched_scenarios:
            primary_scenario = &quot;how_to&quot;
        elif &quot;product_info&quot; in matched_scenarios:
            primary_scenario = &quot;product_info&quot;
        elif &quot;comparison&quot; in matched_scenarios:
            primary_scenario = &quot;comparison&quot;
        elif matched_scenarios:
            primary_scenario = matched_scenarios[0]
        else:
            primary_scenario = &quot;general&quot;
        
        return {
            &quot;primary_scenario&quot;: primary_scenario,
            &quot;matched_scenarios&quot;: matched_scenarios,
            &quot;complexity&quot;: complexity,
            &quot;word_count&quot;: len(question.split())
        }
    
    def _analyze_complexity(self, question: str) -&gt; str:
        &quot;&quot;&quot;分析问题复杂度&quot;&quot;&quot;
        word_count = len(question.split())
        
        if word_count &amp;#x3C;= 5:
            return &quot;simple&quot;
        elif word_count &amp;#x3C;= 10:
            return &quot;medium&quot;
        else:
            return &quot;complex&quot;
    
    def select_technique(self, classification: Dict[str, Any]) -&gt; str:
        &quot;&quot;&quot;根据分类结果选择最佳技术&quot;&quot;&quot;
        scenario = classification[&quot;primary_scenario&quot;]
        complexity = classification[&quot;complexity&quot;]
        
        technique_mapping = {
            &quot;product_info&quot;: {
                &quot;simple&quot;: &quot;multi_query&quot;,
                &quot;medium&quot;: &quot;rag_fusion&quot;, 
                &quot;complex&quot;: &quot;hyde&quot;
            },
            &quot;troubleshooting&quot;: {
                &quot;simple&quot;: &quot;multi_query&quot;,
                &quot;medium&quot;: &quot;decomposition&quot;,
                &quot;complex&quot;: &quot;decomposition&quot;
            },
            &quot;how_to&quot;: {
                &quot;simple&quot;: &quot;step_back&quot;,
                &quot;medium&quot;: &quot;step_back&quot;,
                &quot;complex&quot;: &quot;decomposition&quot;
            },
            &quot;comparison&quot;: {
                &quot;simple&quot;: &quot;multi_query&quot;,
                &quot;medium&quot;: &quot;rag_fusion&quot;,
                &quot;complex&quot;: &quot;decomposition&quot;
            },
            &quot;general&quot;: {
                &quot;simple&quot;: &quot;multi_query&quot;,
                &quot;medium&quot;: &quot;rag_fusion&quot;,
                &quot;complex&quot;: &quot;hyde&quot;
            }
        }
        
        return technique_mapping.get(scenario, {}).get(complexity, &quot;rag_fusion&quot;)
    
    def execute_technique(self, question: str, technique: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;执行选定的技术&quot;&quot;&quot;
        technique_handlers = {
            &quot;multi_query&quot;: self.multi_query.query,
            &quot;rag_fusion&quot;: self.rag_fusion.query,
            &quot;decomposition&quot;: self.decomposition.query,
            &quot;step_back&quot;: self.step_back.query,
            &quot;hyde&quot;: self.hyde.query
        }
        
        if technique in technique_handlers:
            return technique_handlersquestion
        else:
            # 默认使用RAG-Fusion
            return self.rag_fusion.query(question)
    
    def format_customer_response(self, result: Dict[str, Any], classification: Dict[str, Any]) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;格式化客服响应&quot;&quot;&quot;
        scenario = classification[&quot;primary_scenario&quot;]
        
        # 根据场景添加特定格式
        if scenario == &quot;troubleshooting&quot;:
            # 故障排查类问题：步骤化格式
            answer = self._format_troubleshooting_answer(result[&quot;answer&quot;])
        elif scenario == &quot;how_to&quot;:
            # 操作教程类问题：步骤化格式
            answer = self._format_how_to_answer(result[&quot;answer&quot;])
        elif scenario == &quot;product_info&quot;:
            # 产品信息类问题：结构化格式
            answer = self._format_product_info_answer(result[&quot;answer&quot;])
        else:
            answer = result[&quot;answer&quot;]
        
        return {
            &quot;question&quot;: result[&quot;question&quot;],
            &quot;answer&quot;: answer,
            &quot;technique_used&quot;: result.get(&quot;technique&quot;, &quot;unknown&quot;),
            &quot;scenario&quot;: scenario,
            &quot;confidence&quot;: self._calculate_confidence(result, classification),
            &quot;suggested_follow_up&quot;: self._suggest_follow_up(scenario),
            &quot;response_time&quot;: result.get(&quot;response_time&quot;, 0)
        }
    
    def _format_troubleshooting_answer(self, answer: str) -&gt; str:
        &quot;&quot;&quot;格式化故障排查答案&quot;&quot;&quot;
        # 添加故障排查模板
        troubleshooting_template = &quot;&quot;&quot;🔧 **故障排查指南**

				{content}
				
				💡 **温馨提示**：
				• 请按顺序尝试以上步骤
				• 如果问题仍未解决，请联系技术支持&quot;&quot;&quot;
        
        return troubleshooting_template.format(content=answer)
    
    def _format_how_to_answer(self, answer: str) -&gt; str:
        &quot;&quot;&quot;格式化操作教程答案&quot;&quot;&quot;
        # 添加操作步骤模板
        how_to_template = &quot;&quot;&quot;📋 **操作步骤指南**

				{content}
				
				✅ **完成检查**：
				• 确保每个步骤都正确执行
				• 如有疑问可参考详细文档&quot;&quot;&quot;
        
        return how_to_template.format(content=answer)
    
    def _format_product_info_answer(self, answer: str) -&gt; str:
        &quot;&quot;&quot;格式化产品信息答案&quot;&quot;&quot;
        # 添加产品信息模板
        product_template = &quot;&quot;&quot;📊 **产品信息摘要**

				{content}
				
				🔍 **更多信息**：
				• 查看详细规格手册
				• 联系销售获取报价&quot;&quot;&quot;
        
        return product_template.format(content=answer)
    
    def _calculate_confidence(self, result: Dict[str, Any], classification: Dict[str, Any]) -&gt; float:
        &quot;&quot;&quot;计算回答置信度&quot;&quot;&quot;
        base_confidence = 0.7
        
        # 根据文档数量调整置信度
        num_docs = result.get(&quot;num_docs&quot;, 0)
        if num_docs &gt;= 3:
            base_confidence += 0.2
        elif num_docs == 0:
            base_confidence -= 0.3
        
        # 根据问题复杂度调整
        if classification[&quot;complexity&quot;] == &quot;simple&quot;:
            base_confidence += 0.1
        elif classification[&quot;complexity&quot;] == &quot;complex&quot;:
            base_confidence -= 0.1
        
        return min(max(base_confidence, 0.1), 1.0)
    
    def _suggest_follow_up(self, scenario: str) -&gt; List[str]:
        &quot;&quot;&quot;根据场景推荐后续问题&quot;&quot;&quot;
        follow_up_questions = {
            &quot;product_info&quot;: [
                &quot;这个产品有哪些具体型号？&quot;,
                &quot;价格和保修政策是怎样的？&quot;,
                &quot;与其他产品相比有什么优势？&quot;
            ],
            &quot;troubleshooting&quot;: [
                &quot;如果以上方法无效该怎么办？&quot;,
                &quot;如何预防类似问题再次发生？&quot;,
                &quot;是否需要专业技术人员协助？&quot;
            ],
            &quot;how_to&quot;: [
                &quot;有没有视频教程可以参考？&quot;,
                &quot;常见错误和解决方法有哪些？&quot;,
                &quot;高级功能如何使用？&quot;
            ],
            &quot;general&quot;: [
                &quot;还有其他相关问题吗？&quot;,
                &quot;需要更详细的信息吗？&quot;,
                &quot;是否解决了您的问题？&quot;
            ]
        }
        
        return follow_up_questions.get(scenario, [])
    
    def handle_query(self, question: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;处理客服查询 - 主入口&quot;&quot;&quot;
        print(f&quot;👤 客户查询: {question}&quot;)
        print(&quot;-&quot; * 50)
        
        start_time = asyncio.get_event_loop().time()
        
        # 1. 智能分类
        classification = self.classify_query(question)
        print(f&quot;🔍 问题分类: {classification}&quot;)
        
        # 2. 选择技术
        technique = self.select_technique(classification)
        print(f&quot;🛠️ 选择技术: {technique}&quot;)
        
        # 3. 执行查询
        result = self.execute_technique(question, technique)
        
        # 4. 计算响应时间
        response_time = asyncio.get_event_loop().time() - start_time
        result[&quot;response_time&quot;] = response_time
        
        # 5. 格式化响应
        formatted_response = self.format_customer_response(result, classification)
        
        print(f&quot;⏱️ 响应时间: {response_time:.2f}s&quot;)
        print(f&quot;📊 置信度: {formatted_response[&apos;confidence&apos;]:.1%}&quot;)
        
        return formatted_response
    
    def batch_handle_queries(self, questions: List[str]) -&gt; List[Dict[str, Any]]:
        &quot;&quot;&quot;批量处理查询&quot;&quot;&quot;
        results = []
        for i, question in enumerate(questions, 1):
            print(f&quot;\n🔍 处理第 {i}/{len(questions)} 个查询...&quot;)
            result = self.handle_query(question)
            results.append(result)
        return results
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;7.1.2 实战测试示例&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# === 智能客服系统测试 ===

def test_customer_service():
    &quot;&quot;&quot;测试智能客服系统&quot;&quot;&quot;
    print(&quot;🚀 智能客服RAG系统测试&quot;)
    print(&quot;=&quot; * 60)
    
    # 初始化系统
    customer_service = CustomerServiceRAG(vectorstore, llm, embeddings)
    
    # 测试不同场景的问题
    test_questions = [
        # 产品信息类
        &quot;你们的最新智能机器人有哪些功能？&quot;,
        # 故障排查类  
        &quot;我的机器人无法启动，显示错误代码E102&quot;,
        # 操作教程类
        &quot;如何设置机器人的自动巡逻模式？&quot;,
        # 比较类
        &quot;旗舰款和标准款机器人有什么区别？&quot;,
        # 一般问题
        &quot;机器人的保修政策是怎样的？&quot;
    ]
    
    print(&quot;📋 测试问题列表:&quot;)
    for i, q in enumerate(test_questions, 1):
        print(f&quot;   {i}. {q}&quot;)
    
    print(&quot;\n&quot; + &quot;=&quot;*60)
    
    # 逐个处理问题
    results = []
    for question in test_questions:
        result = customer_service.handle_query(question)
        results.append(result)
        
        print(f&quot;\n💡 答案摘要:&quot;)
        print(f&quot;   场景: {result[&apos;scenario&apos;]}&quot;)
        print(f&quot;   技术: {result[&apos;technique_used&apos;]}&quot;)
        print(f&quot;   置信度: {result[&apos;confidence&apos;]:.1%}&quot;)
        print(f&quot;   答案预览: {result[&apos;answer&apos;][:100]}...&quot;)
        print(&quot;-&quot; * 40)
    
    return results

# 运行测试
if __name__ == &quot;__main__&quot;:
    results = test_customer_service()
    
    # 生成性能报告
    print(&quot;\n📊 性能统计报告:&quot;)
    total_time = sum(r[&quot;response_time&quot;] for r in results)
    avg_confidence = sum(r[&quot;confidence&quot;] for r in results) / len(results)
    
    print(f&quot;   总处理时间: {total_time:.2f}s&quot;)
    print(f&quot;   平均置信度: {avg_confidence:.1%}&quot;)
    print(f&quot;   处理问题数: {len(results)}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;7.1.3 场景处理详情&lt;/h4&gt;
&lt;h5&gt;🔧 故障排查场景处理流程&lt;/h5&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def troubleshoot_example():
    &quot;&quot;&quot;故障排查场景示例&quot;&quot;&quot;
    question = &quot;机器人抓取物体时经常滑落，如何解决？&quot;
    
    customer_service = CustomerServiceRAG(vectorstore, llm, embeddings)
    result = customer_service.handle_query(question)
    
    # 显示详细处理过程
    print(&quot;🔧 故障排查场景处理详情:&quot;)
    print(f&quot;   原始问题: {result[&apos;question&apos;]}&quot;)
    print(f&quot;   识别场景: {result[&apos;scenario&apos;]}&quot;)
    print(f&quot;   使用技术: {result[&apos;technique_used&apos;]}&quot;)
    print(f&quot;   响应格式: 步骤化故障排查指南&quot;)
    print(f&quot;   推荐后续问题: {result[&apos;suggested_follow_up&apos;]}&quot;)
    
    return result
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;预期输出&lt;/strong&gt;：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;👤 客户查询: 机器人抓取物体时经常滑落，如何解决？
--------------------------------------------------
🔍 问题分类: {&apos;primary_scenario&apos;: &apos;troubleshooting&apos;, &apos;matched_scenarios&apos;: [&apos;troubleshooting&apos;], &apos;complexity&apos;: &apos;complex&apos;, &apos;word_count&apos;: 8}
🛠️ 选择技术: decomposition
⏱️ 响应时间: 2.34s
📊 置信度: 85.0%

💡 答案摘要:
   场景: troubleshooting
   技术: decomposition  
   置信度: 85.0%
   答案预览: 🔧 **故障排查指南** 1. 检查抓取力设置：确保抓取力足够但不过大 2. 验证物体表面：光滑表面可能需要特殊抓取策略...
&lt;/code&gt;&lt;/pre&gt;
&lt;h5&gt;📚 产品信息场景处理流程&lt;/h5&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def product_info_example():
    &quot;&quot;&quot;产品信息场景示例&quot;&quot;&quot;
    question = &quot;智能机械臂的最大负载是多少？&quot;
    
    customer_service = CustomerServiceRAG(vectorstore, llm, embeddings)
    result = customer_service.handle_query(question)
    
    print(&quot;📚 产品信息场景处理详情:&quot;)
    print(f&quot;   原始问题: {result[&apos;question&apos;]}&quot;)
    print(f&quot;   识别场景: {result[&apos;scenario&apos;]}&quot;) 
    print(f&quot;   使用技术: {result[&apos;technique_used&apos;]}&quot;)
    print(f&quot;   响应格式: 结构化产品信息&quot;)
    
    return result
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;7.2 案例2：学术论文助手系统&lt;/h3&gt;
&lt;h4&gt;7.2.1 系统架构设计&lt;/h4&gt;
&lt;p&gt;学术论文助手专门为研究人员设计，集成多种RAG技术来支持学术工作的各个环节。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class AcademicRAG:
    &quot;&quot;&quot;学术论文助手系统 - 完整实现&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm, embeddings):
        &quot;&quot;&quot;
        初始化学术论文助手
        
        Args:
            vectorstore: 包含学术论文的向量数据库
            llm: 语言模型
            embeddings: 嵌入模型
        &quot;&quot;&quot;
        self.vectorstore = vectorstore
        self.llm = llm
        self.embeddings = embeddings
        
        # 初始化各技术组件
        self._initialize_components()
        
        # 学术领域关键词
        self.academic_domains = {
            &quot;computer_science&quot;: [&quot;机器学习&quot;, &quot;深度学习&quot;, &quot;算法&quot;, &quot;神经网络&quot;, &quot;AI&quot;],
            &quot;robotics&quot;: [&quot;机器人&quot;, &quot;机械臂&quot;, &quot;抓取&quot;, &quot;运动规划&quot;, &quot;SLAM&quot;],
            &quot;nlp&quot;: [&quot;自然语言处理&quot;, &quot;文本生成&quot;, &quot;语义分析&quot;, &quot;词向量&quot;],
            &quot;vision&quot;: [&quot;计算机视觉&quot;, &quot;图像识别&quot;, &quot;目标检测&quot;, &quot;分割&quot;]
        }
        
        print(&quot;✅ 学术论文助手初始化完成&quot;)
    
    def _initialize_components(self):
        &quot;&quot;&quot;初始化各技术组件&quot;&quot;&quot;
        from .multi_query import MultiQueryRAG
        from .hyde import HyDERAG
        from .decomposition import RecursiveDecomposition
        from .step_back import StepBackRAG
        from .rag_fusion import RAGFusion
        
        self.multi_query = MultiQueryRAG(self.vectorstore, self.llm)
        self.hyde = HyDERAG(self.vectorstore, self.llm, self.embeddings)
        self.decomposition = RecursiveDecomposition(self.vectorstore, self.llm)
        self.step_back = StepBackRAG(self.vectorstore, self.llm)
        self.rag_fusion = RAGFusion(self.vectorstore, self.llm)
    
    def identify_academic_domain(self, topic: str) -&gt; str:
        &quot;&quot;&quot;识别学术领域&quot;&quot;&quot;
        topic_lower = topic.lower()
        
        for domain, keywords in self.academic_domains.items():
            if any(keyword in topic_lower for keyword in keywords):
                return domain
        
        return &quot;general&quot;
    
    def literature_review(self, topic: str, depth: str = &quot;comprehensive&quot;) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;
        生成文献综述
        
        Args:
            topic: 研究主题
            depth: 综述深度 (&quot;brief&quot; | &quot;standard&quot; | &quot;comprehensive&quot;)
        &quot;&quot;&quot;
        print(f&quot;📚 生成文献综述: {topic}&quot;)
        print(f&quot;   深度级别: {depth}&quot;)
        
        domain = self.identify_academic_domain(topic)
        print(f&quot;   识别领域: {domain}&quot;)
        
        # 根据深度调整参数
        depth_params = {
            &quot;brief&quot;: {&quot;k&quot;: 3, &quot;max_tokens&quot;: 800},
            &quot;standard&quot;: {&quot;k&quot;: 5, &quot;max_tokens&quot;: 1200},
            &quot;comprehensive&quot;: {&quot;k&quot;: 8, &quot;max_tokens&quot;: 2000}
        }
        params = depth_params.get(depth, depth_params[&quot;standard&quot;])
        
        # 1. Multi-Query广泛检索研究现状
        broad_queries = [
            f&quot;{topic}的研究现状&quot;,
            f&quot;{topic}的最新进展&quot;,
            f&quot;{topic}的主要挑战&quot;,
            f&quot;{topic}的未来方向&quot;
        ]
        
        print(&quot;🔍 广泛检索研究现状...&quot;)
        broad_results = []
        for query in broad_queries:
            result = self.multi_query.query(query, k=params[&quot;k&quot;]//2)
            broad_results.extend(result.get(&quot;documents&quot;, []))
        
        # 2. HyDE深入检索理论基础
        print(&quot;🔍 深入检索理论基础...&quot;)
        theory_queries = [
            f&quot;{topic}的理论基础&quot;,
            f&quot;{topic}的核心概念&quot;,
            f&quot;{topic}的方法论&quot;
        ]
        
        theory_results = []
        for query in theory_queries:
            result = self.hyde.query(query, k=params[&quot;k&quot;]//2)
            theory_results.extend(result.get(&quot;documents&quot;, []))
        
        # 3. 去重和合并
        all_docs = self._deduplicate_documents(broad_results + theory_results)
        selected_docs = all_docs[:params[&quot;k&quot;]]
        
        print(f&quot;📄 最终选择 {len(selected_docs)} 篇相关文献&quot;)
        
        # 4. 生成文献综述
        review = self._generate_literature_review(topic, selected_docs, params[&quot;max_tokens&quot;])
        
        return {
            &quot;topic&quot;: topic,
            &quot;domain&quot;: domain,
            &quot;depth&quot;: depth,
            &quot;num_documents&quot;: len(selected_docs),
            &quot;review&quot;: review,
            &quot;documents_preview&quot;: [doc.page_content[:200] for doc in selected_docs[:3]]
        }
    
    def _generate_literature_review(self, topic: str, documents: List, max_tokens: int) -&gt; str:
        &quot;&quot;&quot;生成文献综述内容&quot;&quot;&quot;
        context = &quot;\n\n&quot;.join([doc.page_content for doc in documents])
        
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一位学术研究助手，需要基于提供的文献生成一篇专业的文献综述。
            
            请按照以下结构组织内容：
            1. 研究背景和意义
            2. 理论基础和关键概念
            3. 主要研究方法和进展
            4. 当前挑战和局限性
            5. 未来研究方向
            
            要求：
            • 学术严谨，引用提供的文献内容
            • 逻辑清晰，结构完整
            • 语言专业但不晦涩&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            研究主题：{topic}
            
            相关文献：
            {context}
            
            请生成一篇关于{topic}的文献综述：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=0.3,
            top_p=0.8,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        return outputs[0].outputs[0].text.strip() if outputs else &quot;生成失败&quot;
    
    def compare_methods(self, method1: str, method2: str, aspect: str = &quot;all&quot;) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;
        比较两种方法
        
        Args:
            method1: 方法1名称
            method2: 方法2名称
            aspect: 比较方面 (&quot;all&quot; | &quot;performance&quot; | &quot;complexity&quot; | &quot;applicability&quot;)
        &quot;&quot;&quot;
        print(f&quot;🔬 方法比较: {method1} vs {method2}&quot;)
        print(f&quot;   比较方面: {aspect}&quot;)
        
        # 使用查询分解进行系统比较
        comparison_query = f&quot;系统比较{method1}和{method2}在性能、复杂度和适用性方面的差异&quot;
        
        result = self.decomposition.query(comparison_query)
        
        # 根据比较方面过滤内容
        if aspect != &quot;all&quot;:
            filtered_answer = self._filter_comparison_by_aspect(result[&quot;answer&quot;], aspect)
        else:
            filtered_answer = result[&quot;answer&quot;]
        
        return {
            &quot;method1&quot;: method1,
            &quot;method2&quot;: method2,
            &quot;aspect&quot;: aspect,
            &quot;comparison&quot;: filtered_answer,
            &quot;sub_questions&quot;: result.get(&quot;sub_questions&quot;, []),
            &quot;sub_answers&quot;: result.get(&quot;sub_answers&quot;, [])
        }
    
    def _filter_comparison_by_aspect(self, answer: str, aspect: str) -&gt; str:
        &quot;&quot;&quot;根据方面过滤比较内容&quot;&quot;&quot;
        aspect_keywords = {
            &quot;performance&quot;: [&quot;性能&quot;, &quot;准确率&quot;, &quot;效率&quot;, &quot;速度&quot;, &quot;效果&quot;],
            &quot;complexity&quot;: [&quot;复杂度&quot;, &quot;计算量&quot;, &quot;参数&quot;, &quot;训练时间&quot;, &quot;推理时间&quot;],
            &quot;applicability&quot;: [&quot;适用性&quot;, &quot;应用场景&quot;, &quot;限制&quot;, &quot;条件&quot;, &quot;领域&quot;]
        }
        
        keywords = aspect_keywords.get(aspect, [])
        lines = answer.split(&apos;\n&apos;)
        filtered_lines = [line for line in lines if any(kw in line for kw in keywords)]
        
        return &apos;\n&apos;.join(filtered_lines) if filtered_lines else answer
    
    def research_gap_analysis(self, topic: str) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;分析研究空白&quot;&quot;&quot;
        print(f&quot;🎯 研究空白分析: {topic}&quot;)
        
        # 使用Step Back技术分析研究现状和未来方向
        step_back_query = f&quot;{topic}领域的研究现状、主要成就和未来挑战&quot;
        
        result = self.step_back.query(step_back_query)
        
        # 生成研究空白分析
        gap_analysis = self._generate_gap_analysis(topic, result[&quot;answer&quot;])
        
        return {
            &quot;topic&quot;: topic,
            &quot;current_status&quot;: result[&quot;answer&quot;],
            &quot;research_gaps&quot;: gap_analysis,
            &quot;suggested_directions&quot;: self._suggest_research_directions(topic)
        }
    
    def _generate_gap_analysis(self, topic: str, current_status: str) -&gt; str:
        &quot;&quot;&quot;生成研究空白分析&quot;&quot;&quot;
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            基于当前研究现状，分析存在的研究空白和未来机会。
            
            请从以下角度分析：
            1. 理论方面的空白
            2. 方法技术的局限
            3. 应用场景的扩展
            4. 跨学科融合的机会&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            研究领域：{topic}
            
            当前研究现状：
            {current_status}
            
            请分析该领域的研究空白：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=800,
            temperature=0.4,
            top_p=0.8
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        return outputs[0].outputs[0].text.strip() if outputs else &quot;分析失败&quot;
    
    def _suggest_research_directions(self, topic: str) -&gt; List[str]:
        &quot;&quot;&quot;推荐研究方向&quot;&quot;&quot;
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            为以下研究领域推荐具体的研究方向，每行一个方向：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            研究领域：{topic}
            
            请推荐3-5个具体的研究方向：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            max_tokens=300,
            temperature=0.5,
            top_p=0.9
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        if outputs:
            text = outputs[0].outputs[0].text.strip()
            return [line.strip() for line in text.split(&apos;\n&apos;) if line.strip()]
        return []
    
    def _deduplicate_documents(self, documents: List) -&gt; List:
        &quot;&quot;&quot;去重文档&quot;&quot;&quot;
        seen_content = set()
        unique_docs = []
        
        for doc in documents:
            content_hash = hash(doc.page_content[:500])  # 前500字符作为标识
            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_docs.append(doc)
        
        return unique_docs
    
    def generate_citation(self, topic: str, style: str = &quot;apa&quot;) -&gt; Dict[str, Any]:
        &quot;&quot;&quot;生成参考文献引用&quot;&quot;&quot;
        # 检索相关文献
        result = self.hyde.query(f&quot;{topic}的相关研究文献&quot;, k=5)
        
        # 生成引用格式
        citations = self._format_citations(result.get(&quot;documents&quot;, []), style)
        
        return {
            &quot;topic&quot;: topic,
            &quot;citation_style&quot;: style,
            &quot;citations&quot;: citations,
            &quot;num_references&quot;: len(citations)
        }
    
    def _format_citations(self, documents: List, style: str) -&gt; List[str]:
        &quot;&quot;&quot;格式化引用&quot;&quot;&quot;
        citations = []
        for i, doc in enumerate(documents, 1):
            # 简化版引用生成，实际应使用专业引用库
            metadata = doc.metadata
            title = metadata.get(&apos;title&apos;, &apos;未知标题&apos;)
            authors = metadata.get(&apos;authors&apos;, &apos;未知作者&apos;)
            year = metadata.get(&apos;year&apos;, &apos;未知年份&apos;)
            
            if style == &quot;apa&quot;:
                citation = f&quot;{authors} ({year}). {title}.&quot;
            elif style == &quot;mla&quot;:
                citation = f&quot;{authors}. \&quot;{title}\&quot;. {year}.&quot;
            else:
                citation = f&quot;{authors}. {title}. {year}.&quot;
            
            citations.append(f&quot;{i}. {citation}&quot;)
        
        return citations
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;7.2.2 学术助手实战测试&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def test_academic_assistant():
    &quot;&quot;&quot;测试学术论文助手&quot;&quot;&quot;
    print(&quot;🎓 学术论文助手测试&quot;)
    print(&quot;=&quot; * 60)
    
    # 初始化学术助手
    academic_rag = AcademicRAG(vectorstore, llm, embeddings)
    
    # 测试文献综述功能
    print(&quot;1. 📚 文献综述生成测试&quot;)
    review_result = academic_rag.literature_review(
        &quot;机器人抓取中的滑动检测技术&quot;, 
        depth=&quot;standard&quot;
    )
    
    print(f&quot;   生成综述长度: {len(review_result[&apos;review&apos;])} 字符&quot;)
    print(f&quot;   使用文献数: {review_result[&apos;num_documents&apos;]}&quot;)
    print(f&quot;   识别领域: {review_result[&apos;domain&apos;]}&quot;)
    
    # 测试方法比较功能
    print(&quot;\n2. 🔬 方法比较测试&quot;)
    comparison_result = academic_rag.compare_methods(
        &quot;基于视觉的抓取检测&quot;, 
        &quot;基于触觉的抓取检测&quot;,
        aspect=&quot;performance&quot;
    )
    
    print(f&quot;   比较方面: {comparison_result[&apos;aspect&apos;]}&quot;)
    print(f&quot;   生成子问题数: {len(comparison_result[&apos;sub_questions&apos;])}&quot;)
    
    # 测试研究空白分析
    print(&quot;\n3. 🎯 研究空白分析测试&quot;)
    gap_result = academic_rag.research_gap_analysis(&quot;仿人机器人抓取技术&quot;)
    
    print(f&quot;   分析主题: {gap_result[&apos;topic&apos;]}&quot;)
    print(f&quot;   推荐方向数: {len(gap_result[&apos;suggested_directions&apos;])}&quot;)
    
    # 测试引用生成
    print(&quot;\n4. 📖 参考文献生成测试&quot;)
    citation_result = academic_rag.generate_citation(&quot;深度学习在机器人中的应用&quot;, &quot;apa&quot;)
    
    print(f&quot;   引用风格: {citation_result[&apos;citation_style&apos;]}&quot;)
    print(f&quot;   生成引用数: {citation_result[&apos;num_references&apos;]}&quot;)
    
    return {
        &quot;literature_review&quot;: review_result,
        &quot;method_comparison&quot;: comparison_result,
        &quot;gap_analysis&quot;: gap_result,
        &quot;citations&quot;: citation_result
    }

# 运行测试
if __name__ == &quot;__main__&quot;:
    results = test_academic_assistant()
    
    # 显示详细结果
    print(&quot;\n&quot; + &quot;=&quot;*60)
    print(&quot;📊 学术助手测试结果摘要:&quot;)
    
    review = results[&quot;literature_review&quot;]
    print(f&quot;📚 文献综述: {review[&apos;topic&apos;]}&quot;)
    print(f&quot;   领域: {review[&apos;domain&apos;]}&quot;)
    print(f&quot;   深度: {review[&apos;depth&apos;]}&quot;)
    print(f&quot;   文献数: {review[&apos;num_documents&apos;]}&quot;)
    print(f&quot;   预览: {review[&apos;review&apos;][:200]}...&quot;)
    
    comparison = results[&quot;method_comparison&quot;]
    print(f&quot;\n🔬 方法比较: {comparison[&apos;method1&apos;]} vs {comparison[&apos;method2&apos;]}&quot;)
    print(f&quot;   方面: {comparison[&apos;aspect&apos;]}&quot;)
    print(f&quot;   子问题: {len(comparison[&apos;sub_questions&apos;])}个&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h4&gt;7.2.3 典型学术工作流&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def complete_research_workflow():
    &quot;&quot;&quot;完整学术研究工作流示例&quot;&quot;&quot;
    topic = &quot;多模态融合在机器人抓取中的应用&quot;
    
    academic_rag = AcademicRAG(vectorstore, llm, embeddings)
    
    print(f&quot;🎓 完整学术研究工作流: {topic}&quot;)
    print(&quot;=&quot; * 60)
    
    # 1. 文献综述
    print(&quot;1. 📚 生成文献综述...&quot;)
    review = academic_rag.literature_review(topic, &quot;comprehensive&quot;)
    
    # 2. 方法比较
    print(&quot;2. 🔬 比较相关方法...&quot;)
    comparison = academic_rag.compare_methods(&quot;视觉-触觉融合&quot;, &quot;纯视觉方法&quot;)
    
    # 3. 研究空白分析
    print(&quot;3. 🎯 分析研究空白...&quot;)
    gaps = academic_rag.research_gap_analysis(topic)
    
    # 4. 生成参考文献
    print(&quot;4. 📖 生成参考文献...&quot;)
    citations = academic_rag.generate_citation(topic, &quot;apa&quot;)
    
    # 整合结果
    research_report = {
        &quot;topic&quot;: topic,
        &quot;literature_review&quot;: review[&quot;review&quot;],
        &quot;method_comparison&quot;: comparison[&quot;comparison&quot;],
        &quot;research_gaps&quot;: gaps[&quot;research_gaps&quot;],
        &quot;suggested_directions&quot;: gaps[&quot;suggested_directions&quot;],
        &quot;citations&quot;: citations[&quot;citations&quot;]
    }
    
    print(f&quot;✅ 研究工作流完成!&quot;)
    print(f&quot;   文献综述长度: {len(review[&apos;review&apos;])} 字符&quot;)
    print(f&quot;   识别研究空白: {len(gaps[&apos;research_gaps&apos;].split(&apos;.&apos;))} 个主要方向&quot;)
    print(f&quot;   生成参考文献: {len(citations[&apos;citations&apos;])} 篇&quot;)
    
    return research_report
&lt;/code&gt;&lt;/pre&gt;
&lt;hr&gt;
&lt;h3&gt;7.3 案例对比与总结&lt;/h3&gt;
&lt;h4&gt;7.3.1 系统特性对比&lt;/h4&gt;
&lt;p&gt;| 特性 | 智能客服助手 | 学术论文助手 |
|------|------------|-------------|
| &lt;strong&gt;目标用户&lt;/strong&gt; | 普通用户、客户 | 研究人员、学生 |
| &lt;strong&gt;问题类型&lt;/strong&gt; | 产品、故障、操作 | 研究、分析、综述 |
| &lt;strong&gt;技术重点&lt;/strong&gt; | 场景识别、路由 | 深度分析、综合 |
| &lt;strong&gt;响应格式&lt;/strong&gt; | 客服模板、步骤化 | 学术结构、引用 |
| &lt;strong&gt;性能要求&lt;/strong&gt; | 实时响应、高可用 | 深度分析、准确性 |&lt;/p&gt;
&lt;h4&gt;7.3.2 技术应用对比&lt;/h4&gt;
&lt;p&gt;| 技术 | 智能客服应用 | 学术助手应用 |
|------|------------|-------------|
| &lt;strong&gt;Multi-Query&lt;/strong&gt; | 处理模糊客户问题 | 广泛检索研究现状 |
| &lt;strong&gt;RAG-Fusion&lt;/strong&gt; | 一般问题综合回答 | 多角度文献分析 |
| &lt;strong&gt;Query Decomposition&lt;/strong&gt; | 故障排查步骤化 | 方法系统比较 |
| &lt;strong&gt;Step Back&lt;/strong&gt; | 操作教程背景知识 | 研究现状分析 |
| &lt;strong&gt;HyDE&lt;/strong&gt; | 产品技术规格查询 | 理论深度检索 |&lt;/p&gt;
&lt;h4&gt;7.3.3 最佳实践总结&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;✅ 智能客服助手关键点&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;场景识别准确性&lt;/strong&gt;：决定技术选择的关键&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;响应模板化&lt;/strong&gt;：提升用户体验和专业性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实时性能&lt;/strong&gt;：确保客服对话流畅性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;置信度评估&lt;/strong&gt;：提供透明的质量指示&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;✅ 学术论文助手关键点&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;深度分析能力&lt;/strong&gt;：支持复杂研究需求&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;学术规范性&lt;/strong&gt;：符合学术写作标准&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;文献处理能力&lt;/strong&gt;：高效处理大量学术文献&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;研究方向建议&lt;/strong&gt;：提供创新性洞察&lt;/li&gt;
&lt;/ol&gt;
&lt;h4&gt;7.3.4 扩展建议&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;🔧 智能客服扩展方向&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;多轮对话支持&lt;/strong&gt;：处理复杂客服场景&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;情感分析&lt;/strong&gt;：识别用户情绪调整响应&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多语言支持&lt;/strong&gt;：国际化客服需求&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;知识库更新&lt;/strong&gt;：动态更新产品信息&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;🔧 学术助手扩展方向&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;专业领域定制&lt;/strong&gt;：针对特定学科优化&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;合作网络分析&lt;/strong&gt;：分析学者合作关系&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;趋势预测&lt;/strong&gt;：预测研究热点方向&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;论文写作辅助&lt;/strong&gt;：完整写作流程支持&lt;/li&gt;
&lt;/ul&gt;
&lt;blockquote&gt;
&lt;p&gt;💡 &lt;strong&gt;核心价值&lt;/strong&gt;：这两个案例展示了RAG技术在不同领域的强大应用潜力，通过&lt;strong&gt;智能技术选择&lt;/strong&gt;和&lt;strong&gt;场景化优化&lt;/strong&gt;，可以构建出高度专业化的智能助手系统。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;通过这两个实战案例，我们看到了高级RAG技术在实际应用中的巨大价值，为构建专业领域的智能助手提供了完整的技术框架和实践指南。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>RAG实战（四）高级索引与检索策略</title><link>https://astro-pure.js.org/blog/rag_blogs/rag_blogs-4</link><guid isPermaLink="true">https://astro-pure.js.org/blog/rag_blogs/rag_blogs-4</guid><description>记录RAG的学习。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/RAG_Learning&quot;&gt;Github地址&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;RAG高级索引与检索策略：提升检索质量的关键&lt;/h2&gt;
&lt;p&gt;在前面的章节中，我们学习了查询优化和路由技术。但是，检索质量不仅取决于查询，还取决于如何组织和索引文档。本章将深入探讨高级索引和检索策略。&lt;/p&gt;
&lt;h3&gt;为什么需要高级索引？&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;基础索引的局限性&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;问题1: 文档过长&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;一个10000字的文档被整体嵌入&lt;/li&gt;
&lt;li&gt;语义信息过于粗糙&lt;/li&gt;
&lt;li&gt;检索不精确&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;问题2: 上下文丢失&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;将文档分成小块&lt;/li&gt;
&lt;li&gt;每块独立检索&lt;/li&gt;
&lt;li&gt;丢失了块与块之间的关系&lt;/li&gt;
&lt;li&gt;无法理解完整上下文&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;问题3: 多语义内容&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;一个文档包含多个主题&lt;/li&gt;
&lt;li&gt;单一向量无法表示所有语义&lt;/li&gt;
&lt;li&gt;相关内容可能被遗漏&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;问题4: 检索冗余&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;检索到大量文档&lt;/li&gt;
&lt;li&gt;很多内容重复或不相关&lt;/li&gt;
&lt;li&gt;影响最终答案质量&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;本章技术概览&lt;/h3&gt;
&lt;p&gt;| 技术 | 核心目标 | 适用场景 | 复杂度 |
|------|----------|----------|---------|
| 智能分块 | 优化块大小和边界 | 所有RAG系统 | ⭐ |
| 多向量索引 | 一个文档多个向量 | 多主题文档 | ⭐⭐⭐ |
| 父文档检索 | 检索小块返回大块 | 保持上下文 | ⭐⭐ |
| 上下文压缩 | 压缩检索结果 | 减少冗余 | ⭐⭐⭐ |
| 时间衰减检索 | 考虑文档新鲜度 | 时效性内容 | ⭐⭐ |&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 1: 智能文档分块策略&lt;/h2&gt;
&lt;h3&gt;1.1 核心概念&lt;/h3&gt;
&lt;p&gt;文档分块（Chunking）是将长文档切分成更小片段的过程。好的分块策略能显著提升检索质量。&lt;/p&gt;
&lt;h3&gt;1.2 分块方法对比&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;方法1: 固定长度分块&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def fixed_length_split(text: str, chunk_size: int = 500) -&gt; List[str]:
    &quot;&quot;&quot;简单但可能切断语义&quot;&quot;&quot;
    return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;方法2: 句子级分块&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def sentence_split(text: str, max_sentences: int = 5) -&gt; List[str]:
    &quot;&quot;&quot;保持语义完整但长度不均&quot;&quot;&quot;
    sentences = text.split(&apos;。&apos;)
    chunks = []
    current_chunk = []
    
    for sent in sentences:
        current_chunk.append(sent)
        if len(current_chunk) &gt;= max_sentences:
            chunks.append(&apos;。&apos;.join(current_chunk))
            current_chunk = []
    
    if current_chunk:
        chunks.append(&apos;。&apos;.join(current_chunk))
    
    return chunks
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;方法3: 语义分块 ✅ 推荐&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def semantic_split(text: str, embeddings) -&gt; List[str]:
    &quot;&quot;&quot;基于语义相似度分块&quot;&quot;&quot;
    # 将在下面实现
    pass
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.3 普通文本分割器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.text_splitter import RecursiveCharacterTextSplitter

# 创建文本分割器
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,        # 每块的目标大小
    chunk_overlap=200,      # 块之间的重叠
    length_function=len,    # 长度计算函数
    separators=[            # 分隔符优先级
        &quot;\n\n&quot;,             # 段落
        &quot;\n&quot;,               # 行
        &quot; &quot;,                # 单词
        &quot;&quot;                  # 字符
    ]
)

# 示例文档
document = &quot;&quot;&quot;
# Python编程基础

## 变量和数据类型

Python是一种动态类型语言，这意味着你不需要声明变量的类型。
Python支持多种数据类型，包括整数、浮点数、字符串等。

## 控制流

Python使用缩进来定义代码块。
if语句用于条件判断，for循环用于迭代。

## 函数

函数是可重用的代码块。
使用def关键字定义函数。
&quot;&quot;&quot;

# 分块
chunks = text_splitter.split_text(document)

print(f&quot;文档被分成 {len(chunks)} 块&quot;)
for i, chunk in enumerate(chunks):
    print(f&quot;\n块 {i+1}:&quot;)
    print(chunk[:100] + &quot;...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;文档被分成 1 块

块 1:
# Python编程基础

## 变量和数据类型

Python是一种动态类型语言，这意味着你不需要声明变量的类型。
Python支持多种数据类型，包括整数、浮点数、字符串等。

## 控制流

Py...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.4 语义分块&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;处理流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入文本
    ↓
分割句子（中文优化）
    ↓
计算每个句子的嵌入
    ↓
计算相邻句子的相似度
    ↓
在相似度低的地方切分
    ↓
输出语义分块结果
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain_community.embeddings import HuggingFaceEmbeddings
import numpy as np
from typing import List
import re

class SemanticChunker:
    &quot;&quot;&quot;基于语义相似度的分块器（适配本地模型）&quot;&quot;&quot;
    
    def __init__(self, embeddings, similarity_threshold: float = 0.6):
        self.embeddings = embeddings
        self.similarity_threshold = similarity_threshold
    
    def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -&gt; float:
        &quot;&quot;&quot;计算余弦相似度&quot;&quot;&quot;
        if np.linalg.norm(vec1) == 0 or np.linalg.norm(vec2) == 0:
            return 0.0
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    
    def _split_sentences(self, text: str) -&gt; List[str]:
        &quot;&quot;&quot;分割句子（中文优化）&quot;&quot;&quot;
        sentence_endings = r&apos;[。！？；\n]&apos;
        sentences = re.split(sentence_endings, text)
        sentences = [s.strip() for s in sentences if s.strip()]
        sentences = [s + &apos;。&apos; for s in sentences if not s.endswith((&apos;。&apos;, &apos;!&apos;, &apos;?&apos;, &apos;;&apos;))]
        return sentences
    
    def _get_sentence_embeddings(self, sentences: List[str]) -&gt; List[np.ndarray]:
        &quot;&quot;&quot;获取句子嵌入&quot;&quot;&quot;
        embeddings = []
        for sent in sentences:
            emb = self.embeddings.embed_query(sent)
            embeddings.append(np.array(emb))
        return embeddings
    
    def split_text(self, text: str) -&gt; List[str]:
        &quot;&quot;&quot;基于语义相似度分块&quot;&quot;&quot;
        if not text:
            return [text]
        
        # 1. 按句子分割
        sentences = self._split_sentences(text)
        if len(sentences) &amp;#x3C;= 1:
            return [text]
        
        # 2. 计算每个句子的嵌入
        embeddings = self._get_sentence_embeddings(sentences)
        
        # 3. 计算相邻句子的相似度
        similarities = []
        for i in range(len(embeddings) - 1):
            sim = self.cosine_similarity(embeddings[i], embeddings[i+1])
            similarities.append(sim)
        
        # 4. 在相似度低的地方切分
        chunks = []
        current_chunk = [sentences[0]]
        
        for i, sim in enumerate(similarities):
            if sim &amp;#x3C; self.similarity_threshold:
                chunks.append(&apos;&apos;.join(current_chunk))
                current_chunk = [sentences[i+1]]
            else:
                current_chunk.append(sentences[i+1])
        
        # 添加最后一块
        if current_chunk:
            chunks.append(&apos;&apos;.join(current_chunk))
        
        return chunks

# 创建语义分块器
semantic_chunker = SemanticChunker(embeddings, similarity_threshold=0.5)

# 测试文本
text = &quot;&quot;&quot;
机器学习是人工智能的一个分支。它使计算机能够从数据中学习。
深度学习是机器学习的一个子领域。它使用神经网络来建模复杂模式。
Python是一种流行的编程语言。它广泛用于数据科学和机器学习。
Python有丰富的库生态系统。NumPy和Pandas是常用的数据处理库。
&quot;&quot;&quot;

chunks = semantic_chunker.split_text(text)

print(f&quot;语义分块结果: {len(chunks)} 块&quot;)
for i, chunk in enumerate(chunks):
    print(f&quot;\n块 {i+1}:\n{chunk}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;语义分块结果: 7 块

块 1:
机器学习是人工智能的一个分支。

块 2:
它使计算机能够从数据中学习。

块 3:
深度学习是机器学习的一个子领域。

块 4:
它使用神经网络来建模复杂模式。

块 5:
Python是一种流行的编程语言。

块 6:
它广泛用于数据科学和机器学习。

块 7:
Python有丰富的库生态系统。NumPy和Pandas是常用的数据处理库。
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.5 分块最佳实践&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入文档
    ↓
按文档结构分割（标题、段落）
    ↓
检查每个章节的大小
    ↓
对过长的章节进行二次分割
    ↓
为每个块添加元数据
    ↓
输出结构化分块结果
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from typing import List, Dict, Any

class SmartChunker:
    &quot;&quot;&quot;智能分块器 - 适配本地模型&quot;&quot;&quot;
    
    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        min_chunk_size: int = 100,
        max_chunk_size: int = 2000
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
    
    def split_by_structure(self, text: str) -&gt; List[str]:
        &quot;&quot;&quot;按文档结构分块&quot;&quot;&quot;
        if not text or len(text) &amp;#x3C; self.min_chunk_size:
            return [text]
        
        # 按标题分割
        sections = text.split(&apos;\n# &apos;)
        chunks = []
        
        for section in sections:
            if not section.strip():
                continue
            
            # 如果章节太长，进一步分割
            if len(section) &gt; self.max_chunk_size:
                sub_chunks = self._split_long_section(section)
                chunks.extend(sub_chunks)
            elif len(section) &gt;= self.min_chunk_size:
                chunks.append(section)
        
        return chunks
    
    def _split_long_section(self, text: str) -&gt; List[str]:
        &quot;&quot;&quot;分割过长的章节&quot;&quot;&quot;
        paragraphs = text.split(&apos;\n\n&apos;)
        chunks = []
        current_chunk = []
        current_length = 0
        
        for para in paragraphs:
            para_length = len(para)
            
            if current_length + para_length &gt; self.chunk_size and current_chunk:
                # 当前块已满，保存
                chunks.append(&apos;\n\n&apos;.join(current_chunk))
                
                # 处理重叠
                if len(current_chunk) &gt; 1:
                    current_chunk = [current_chunk[-1], para]
                    current_length = len(current_chunk[-1]) + para_length
                else:
                    current_chunk = [para]
                    current_length = para_length
            else:
                current_chunk.append(para)
                current_length += para_length
        
        if current_chunk:
            chunks.append(&apos;\n\n&apos;.join(current_chunk))
        
        return chunks
    
    def add_metadata_to_chunks(
        self,
        chunks: List[str],
        doc_metadata: dict
    ) -&gt; List[Dict[str, Any]]:
        &quot;&quot;&quot;为每个块添加元数据&quot;&quot;&quot;
        chunk_docs = []
        
        for i, chunk in enumerate(chunks):
            chunk_docs.append({
                &apos;content&apos;: chunk,
                &apos;metadata&apos;: {
                    **doc_metadata,
                    &apos;chunk_id&apos;: i,
                    &apos;chunk_total&apos;: len(chunks),
                    &apos;chunk_size&apos;: len(chunk)
                }
            })
        
        return chunk_docs

smart_chunker = SmartChunker(
    chunk_size=800,
    chunk_overlap=150,
    min_chunk_size=200
)

# 示例文档
document = &quot;&quot;&quot;
# 机器学习基础
机器学习是人工智能的一个重要分支。它使计算机能够从数据中学习。

# 深度学习
深度学习是机器学习的一个子领域。它使用神经网络来建模复杂模式。

# Python编程
Python是一种流行的编程语言。它广泛用于数据科学和机器学习。
Python有丰富的库生态系统。NumPy和Pandas是常用的数据处理库。
&quot;&quot;&quot;

# 按结构分块
chunks = smart_chunker.split_by_structure(document)

# 添加元数据
chunk_docs = smart_chunker.add_metadata_to_chunks(
    chunks,
    doc_metadata={&apos;source&apos;: &apos;python_tutorial.md&apos;, &apos;author&apos;: &apos;张三&apos;}
)

# 输出结果
for doc in chunk_docs:
    print(f&quot;\n块 {doc[&apos;metadata&apos;][&apos;chunk_id&apos;] + 1}/{doc[&apos;metadata&apos;][&apos;chunk_total&apos;]}&quot;)
    print(f&quot;大小: {doc[&apos;metadata&apos;][&apos;chunk_size&apos;]} 字符&quot;)
    print(f&quot;内容: {doc[&apos;content&apos;][:100]}...&quot;)
    print(f&quot;metadata: {doc[&apos;metadata&apos;]}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;块 1/1
大小: 166 字符
内容: 
# 机器学习基础
机器学习是人工智能的一个重要分支。它使计算机能够从数据中学习。

# 深度学习
深度学习是机器学习的一个子领域。它使用神经网络来建模复杂模式。

# Python编程
Python...
metadata: {&apos;source&apos;: &apos;python_tutorial.md&apos;, &apos;author&apos;: &apos;张三&apos;, &apos;chunk_id&apos;: 0, &apos;chunk_total&apos;: 1, &apos;chunk_size&apos;: 166}
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.6 分块策略总结&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;选择指南：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;固定长度分块&lt;/strong&gt;：适合格式规整的文档，处理速度快&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;句子级分块&lt;/strong&gt;：适合自然语言文本，保持语义完整性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;语义分块&lt;/strong&gt;：适合复杂文档，保证语义连贯性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;智能分块&lt;/strong&gt;：综合最优方案，推荐生产环境使用&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;最佳实践建议：&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;根据文档类型选择合适的分块策略&lt;/li&gt;
&lt;li&gt;设置合理的块大小和重叠区域&lt;/li&gt;
&lt;li&gt;为每个块添加丰富的元数据&lt;/li&gt;
&lt;li&gt;考虑文档的语义边界和结构特点&lt;/li&gt;
&lt;li&gt;测试不同分块策略对检索效果的影响&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;通过智能分块策略，我们可以显著提升RAG系统的检索质量和准确性，为后续的高级检索技术奠定坚实基础。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 2: 多向量索引 - Multi-Vector Indexing&lt;/h2&gt;
&lt;h3&gt;2.1 核心概念&lt;/h3&gt;
&lt;p&gt;多向量索引为单个文档生成多个向量，每个向量代表文档的不同方面或部分。这样可以更全面地表示文档的语义。&lt;/p&gt;
&lt;h3&gt;2.2 为什么需要多向量索引？&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;场景：一篇包含多个主题的文档&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;文档内容:
&quot;&quot;&quot;
本文介绍Python编程基础。

第一部分：变量和数据类型
Python支持多种数据类型...

第二部分：函数和模块
函数是可重用的代码块...

第三部分：面向对象编程
类是对象的蓝图...
&quot;&quot;&quot;

用户查询: &quot;Python中的类是什么？&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;单向量索引的问题：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;整个文档被表示为一个向量&lt;/li&gt;
&lt;li&gt;向量混合了所有主题的语义&lt;/li&gt;
&lt;li&gt;可能无法精确匹配&quot;类&quot;的相关内容&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;多向量索引的优势：&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;为每个部分生成独立向量&lt;/li&gt;
&lt;li&gt;&quot;面向对象编程&quot;部分的向量更匹配查询&lt;/li&gt;
&lt;li&gt;检索更精确&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;2.3 实现多向量检索器&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;处理流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;输入文档
    ↓
为每个文档生成唯一ID
    ↓
存储完整文档到文档存储
    ↓
将文档分割成小块
    ↓
为每个小块添加文档ID元数据
    ↓
将小块向量化并存储到向量数据库
    ↓
构建多向量检索器
    ↓
完成索引构建
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class MultiVectorIndexer:
    &quot;&quot;&quot;多向量索引器&quot;&quot;&quot;
    
    def __init__(self, embeddings, persist_directory: str = &quot;./chroma_db&quot;):
        self.embeddings = embeddings
        self.persist_directory = persist_directory
        
        # 先删除现有集合
        try:
            existing_chroma = Chroma(
                collection_name=&quot;multi_vector&quot;,
                persist_directory=persist_directory,
                embedding_function=embeddings
            )
            existing_chroma.delete_collection()
            print(&quot;🗑️ 已删除现有集合&quot;)
        except:
            print(&quot;ℹ️ 无需删除集合，继续初始化&quot;)
        
        # 重新创建向量存储
        self.vectorstore = Chroma(
            collection_name=&quot;multi_vector&quot;,
            persist_directory=persist_directory,
            embedding_function=embeddings
        )
        
        # 文档存储：存储完整文档
        self.docstore = InMemoryStore()
        
        # 多向量检索器（返回完整文档）
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=self.docstore,
            id_key=&quot;doc_id&quot;
        )
    
    def index_documents(self, documents: List[str], doc_ids: List[str] = None):
        &quot;&quot;&quot;索引文档&quot;&quot;&quot;
        if doc_ids is None:
            doc_ids = [str(uuid.uuid4()) for _ in documents]
        
        print(f&quot;📝 索引 {len(documents)} 个文档，生成 {len(doc_ids)} 个文档ID&quot;)
        
        # 1. 存储完整文档
        for doc_id, doc in zip(doc_ids, documents):
            self.docstore.mset([(doc_id, doc)])
            print(f&quot;✅ 存储文档: {doc_id} (长度: {len(doc)} 字符)&quot;)
        
        # 2. 将每个文档分成小块
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=100,
            chunk_overlap=10
        )
        
        sub_docs = []
        for doc_id, doc in zip(doc_ids, documents):
            chunks = text_splitter.split_text(doc)
            print(f&quot;📊 文档 {doc_id} 分割为 {len(chunks)} 个小块&quot;)
            
            for chunk in chunks:
                sub_docs.append(
                    Document(
                        page_content=chunk,
                        metadata={&quot;doc_id&quot;: doc_id}
                    )
                )
        
        # 3. 将小块向量化并存储
        if sub_docs:
            self.vectorstore.add_documents(sub_docs)
            self.vectorstore.persist()
            print(f&quot;✅ 存储 {len(sub_docs)} 个小块到向量数据库&quot;)
        else:
            print(&quot;⚠️ 没有小块可存储&quot;)
    
    def retrieve_full_documents(self, query: str, k: int = 4):
        &quot;&quot;&quot;检索完整文档（返回字符串）&quot;&quot;&quot;
        docs = self.retriever.get_relevant_documents(query, k=k)
        return docs
    
    def retrieve_chunks(self, query: str, k: int = 4):
        &quot;&quot;&quot;检索小块（返回Document对象）&quot;&quot;&quot;
        docs = self.vectorstore.similarity_search(query, k=k)
        return docs
    
    def debug_info(self):
        &quot;&quot;&quot;调试信息&quot;&quot;&quot;
        print(&quot;🔍 调试信息:&quot;)
        
        try:
            # 检查向量数据库
            all_docs = self.vectorstore.get()
            print(f&quot;📊 向量数据库: {len(all_docs[&apos;documents&apos;])} 个小块&quot;)
            
            # 检查文档存储
            all_doc_ids = list(set([metadata[&apos;doc_id&apos;] for metadata in all_docs[&apos;metadatas&apos;]]))
            print(f&quot;📚 文档存储: {len(all_doc_ids)} 个文档ID&quot;)
            
            for doc_id in all_doc_ids:
                doc = self.docstore.mget([doc_id])[0]
                if doc:
                    print(f&quot;✅ 文档 {doc_id}: 存在 ({len(doc)} 字符)&quot;)
                else:
                    print(f&quot;❌ 文档 {doc_id}: 不存在&quot;)
        except Exception as e:
            print(f&quot;⚠️ 调试信息获取失败: {e}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.4 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建多向量索引器，指定Chroma数据库路径
multi_indexer = MultiVectorIndexer(
    embeddings=embeddings,
    persist_directory=&quot;./chroma_db&quot;  # 本地Chroma数据库路径
)

# 准备文档
documents = [
    &quot;&quot;&quot;
    Python编程基础教程
    
    第一章：变量和数据类型
    Python是动态类型语言，支持整数、浮点数、字符串等多种数据类型。
    
    第二章：控制流
    Python使用if、for、while等关键字进行流程控制。
    
    第三章：函数
    函数是可重用的代码块，使用def关键字定义。
    &quot;&quot;&quot;,
    
    &quot;&quot;&quot;
    机器学习入门
    
    监督学习：使用标记数据训练模型。
    无监督学习：从无标记数据中发现模式。
    强化学习：通过奖励机制学习最优策略。
    &quot;&quot;&quot;
]

# 索引文档
multi_indexer.index_documents(documents)

# 调试信息
multi_indexer.debug_info()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📝 索引 2 个文档，生成 2 个文档ID
✅ 存储文档: 36cf53c2-968a-4540-9f75-e2e0e02a9d04 (长度: 178 字符)
✅ 存储文档: 8e2af8dd-e458-4f04-bc14-01d02552e8c9 (长度: 88 字符)
📊 文档 36cf53c2-968a-4540-9f75-e2e0e02a9d04 分割为 2 个小块
📊 文档 8e2af8dd-e458-4f04-bc14-01d02552e8c9 分割为 1 个小块
✅ 存储 3 个小块到向量数据库
🔍 调试信息:
📊 向量数据库: 3 个小块
📚 文档存储: 2 个文档ID
✅ 文档 36cf53c2-968a-4540-9f75-e2e0e02a9d04: 存在 (178 字符)
✅ 文档 8e2af8dd-e458-4f04-bc14-01d02552e8c9: 存在 (88 字符)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.5 检索测试&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;测试1：机器学习相关查询&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;print(&quot;\n🔍 检索测试:&quot;)
results = multi_indexer.retrieve_full_documents(&quot;什么是机器学习？&quot;, k=2)
for i, doc in enumerate(results):
    print(f&quot;结果 {i+1}: {doc[:100]}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🔍 检索测试:
结果 1: 
    机器学习入门
    
    监督学习：使用标记数据训练模型。
    无监督学习：从无标记数据中发现模式。
    强化学习：通过奖励机制学习最优策略。
    ...
结果 2: 
    Python编程基础教程
    
    第一章：变量和数据类型
    Python是动态类型语言，支持整数、浮点数、字符串等多种数据类型。
    
    第二章：控制流
    P...
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;测试2：Python编程相关查询&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;print(&quot;\n🔍 检索测试:&quot;)
results = multi_indexer.retrieve_full_documents(&quot;什么是python编程？&quot;, k=2)
for i, doc in enumerate(results):
    print(f&quot;结果 {i+1}: {doc[:100]}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🔍 检索测试:
结果 1: 
    Python编程基础教程
    
    第一章：变量和数据类型
    Python是动态类型语言，支持整数、浮点数、字符串等多种数据类型。
    
    第二章：控制流
    P...
结果 2: 
    机器学习入门
    
    监督学习：使用标记数据训练模型。
    无监督学习：从无标记数据中发现模式。
    强化学习：通过奖励机制学习最优策略。
    ...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.6 多向量检索机制详解&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;检索流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询 &quot;什么是机器学习？&quot;
    ↓
在向量数据库中搜索相似的小块
    ↓
找到相关小块：[小块2, 小块5, 小块8...]
    ↓
获取小块的 doc_id：[doc_id_1, doc_id_2, doc_id_1...]
    ↓
根据 doc_id 从文档存储中检索完整文档
    ↓
返回完整文档：[完整文档1, 完整文档2...]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.7 多向量检索优缺点分析&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;优点：&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;更精确的语义表示&lt;/strong&gt;：每个文档部分都有独立的向量表示&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;提高相关性得分&lt;/strong&gt;：可以精确匹配文档的特定部分&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;灵活的检索策略&lt;/strong&gt;：支持多种检索模式（完整文档 vs 小块）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;更好的上下文理解&lt;/strong&gt;：返回完整文档保持上下文完整性&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;缺点：&lt;/strong&gt; ⚠️&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;存储成本增加&lt;/strong&gt;：需要存储更多向量和元数据&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;索引时间更长&lt;/strong&gt;：需要为每个文档生成多个向量&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实现复杂度高&lt;/strong&gt;：需要管理文档存储和向量存储的同步&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;检索延迟增加&lt;/strong&gt;：需要额外的文档查找步骤&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;2.8 适用场景建议&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;推荐使用多向量索引当：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;文档包含多个独立主题或章节&lt;/li&gt;
&lt;li&gt;需要精确匹配文档的特定部分&lt;/li&gt;
&lt;li&gt;检索质量比存储成本更重要&lt;/li&gt;
&lt;li&gt;文档结构清晰，可以自然分割&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;选择单向量索引当：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;文档内容单一，语义集中&lt;/li&gt;
&lt;li&gt;存储资源有限&lt;/li&gt;
&lt;li&gt;需要快速索引和检索&lt;/li&gt;
&lt;li&gt;文档结构不清晰，难以分割&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;多向量索引技术通过为文档的不同部分创建独立的向量表示，显著提升了检索的精确性和相关性，是处理复杂文档结构的有效解决方案。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 3: 父文档检索器 - Parent Document Retriever&lt;/h2&gt;
&lt;h3&gt;3.1 核心概念&lt;/h3&gt;
&lt;p&gt;父文档检索器的策略是：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;将文档分成小块进行索引和检索&lt;/strong&gt;（精确匹配）&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;返回大块或完整文档给LLM&lt;/strong&gt;（保持上下文）&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;3.2 工作原理&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;索引阶段:&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;大文档
    ↓ 分块
小块1、小块2、小块3
    ↓ 向量化
存储在向量数据库
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;检索阶段:&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询
    ↓ 检索
匹配小块2
    ↓ 查找父文档
返回：包含小块2的大块/完整文档
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.3 实现父文档检索器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma
from langchain.schema import Document
from typing import List
import shutil

class ParentDocRetriever:
    &quot;&quot;&quot;父文档检索器（修复元数据键问题）&quot;&quot;&quot;
    
    def __init__(self, embeddings, persist_directory: str = &quot;./chroma_db&quot;):
        self.embeddings = embeddings
        self.persist_directory = persist_directory
        self.collection_name = &quot;parent_doc&quot;
        # 清理指定集合
        self._clean_collection()
        # 向量存储（存储子文档）
        self.vectorstore = Chroma(
            collection_name=self.collection_name,
            persist_directory=persist_directory,
            embedding_function=embeddings
        )
        
        # 文档存储（存储父文档）
        self.docstore = InMemoryStore()
        
        # 子文档分割器（用于检索的小块）
        self.child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=50
        )
        
        # 父文档分割器（用于返回的大块）
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200
        )
        
    def _clean_collection(self):
        &quot;&quot;&quot;清理指定集合&quot;&quot;&quot;
        try:
            # 先尝试连接到现有集合
            existing_vectorstore = Chroma(
                collection_name=self.collection_name,
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings
            )
            
            # 删除集合
            existing_vectorstore.delete_collection()
            print(f&quot;🗑️ 清理集合: {self.collection_name}&quot;)
            
        except Exception as e:
            # 如果集合不存在，创建目录
            import os
            os.makedirs(self.persist_directory, exist_ok=True)
            print(f&quot;ℹ️ 集合不存在，将创建新集合: {self.collection_name}&quot;)
        
    def add_documents(self, documents: List[Document]):
        &quot;&quot;&quot;添加文档（修复元数据键问题）&quot;&quot;&quot;
        print(&quot;📝 开始索引文档...&quot;)
        
        for i, doc in enumerate(documents):
            print(f&quot;\n📄 处理文档 {i+1}/{len(documents)}&quot;)
            
            # 1. 用父文档分割器分割成大块
            parent_chunks = self.parent_splitter.split_documents([doc])
            print(f&quot;   🔧 父文档分割: {len(parent_chunks)} 个大块&quot;)
            
            for j, parent_chunk in enumerate(parent_chunks):
                # 为每个父文档生成唯一ID
                parent_id = f&quot;parent_{i}_{j}&quot;
                
                # 2. 用子文档分割器将大块分割成小块
                child_chunks = self.child_splitter.split_documents([parent_chunk])
                print(f&quot;   🔍 子文档分割: 大块 {j+1} → {len(child_chunks)} 个小块&quot;)
                
                # 3. 存储父文档到文档存储
                self.docstore.mset([(parent_id, parent_chunk.page_content)])
                
                # 4. 为每个子文档添加父文档ID元数据，并存储到向量数据库
                child_docs_with_metadata = []
                for k, child_chunk in enumerate(child_chunks):
                    child_doc = Document(
                        page_content=child_chunk.page_content,
                        metadata={&quot;doc_id&quot;: parent_id, &quot;source&quot;: doc.metadata.get(&quot;source&quot;, &quot;unknown&quot;)}
                    )
                    child_docs_with_metadata.append(child_doc)
                
                self.vectorstore.add_documents(child_docs_with_metadata)
        
        # 持久化保存
        self.vectorstore.persist()
        print(f&quot;✅ 索引完成: 所有文档已存储&quot;)
    
    def retrieve(self, query: str, k: int = 2):
        &quot;&quot;&quot;检索父文档（修复元数据键问题）&quot;&quot;&quot;
        print(f&quot;\n🔍 开始检索: &apos;{query}&apos;&quot;)
        
        # 1. 在向量数据库中搜索相似子文档
        print(&quot;1. 🔎 在向量数据库中搜索相似子文档...&quot;)
        child_docs = self.vectorstore.similarity_search(query, k=k*3)
        print(f&quot;   找到 {len(child_docs)} 个相关子文档&quot;)
        
        for i, child_doc in enumerate(child_docs):
            # 使用正确的元数据键 &apos;doc_id&apos; 而不是 &apos;parent_id&apos;
            parent_id = child_doc.metadata.get(&quot;doc_id&quot;)
            print(f&quot;     子文档 {i+1}: {child_doc.page_content[:50]}... (父文档ID: {parent_id})&quot;)
        
        # 2. 提取父文档ID
        print(&quot;\n2. 🆔 提取父文档ID...&quot;)
        parent_ids = [doc.metadata.get(&quot;doc_id&quot;) for doc in child_docs]
        print(f&quot;   父文档ID列表: {parent_ids}&quot;)
        
        # 3. 从文档存储中检索完整父文档
        print(&quot;\n3. 📚 从文档存储中检索完整父文档...&quot;)
        parent_docs = self.docstore.mget(parent_ids)
        print(f&quot;   检索到 {len([doc for doc in parent_docs if doc])} 个父文档&quot;)
        
        # 4. 去重和排序
        print(&quot;\n4. 🔄 去重和排序...&quot;)
        seen = set()
        unique_docs = []
        
        for i, parent_doc in enumerate(parent_docs):
            if parent_doc and parent_doc not in seen:
                seen.add(parent_doc)
                unique_docs.append(Document(
                    page_content=parent_doc,
                    metadata={&quot;doc_id&quot;: parent_ids[i]}
                ))
                print(f&quot;   保留父文档: {parent_ids[i]} (长度: {len(parent_doc)} 字符)&quot;)
        
        # 5. 返回前k个结果
        final_results = unique_docs[:k]
        print(f&quot;\n✅ 检索完成: 返回 {len(final_results)} 个父文档&quot;)
        
        return final_results
    
    def debug_storage(self):
        &quot;&quot;&quot;调试存储状态&quot;&quot;&quot;
        print(&quot;\n📊 存储状态调试:&quot;)
        
        # 检查向量数据库中的子文档
        all_child_docs = self.vectorstore.get()
        print(f&quot;🔍 向量数据库: {len(all_child_docs[&apos;documents&apos;])} 个子文档&quot;)
        
        # 检查元数据
        if all_child_docs[&apos;metadatas&apos;]:
            first_metadata = all_child_docs[&apos;metadatas&apos;][0]
            print(f&quot;   第一个子文档的元数据: {first_metadata}&quot;)
            
            # 获取所有父文档ID（使用正确的键 &apos;doc_id&apos;）
            parent_ids = []
            for metadata in all_child_docs[&apos;metadatas&apos;]:
                if metadata and &apos;doc_id&apos; in metadata:
                    parent_ids.append(metadata[&apos;doc_id&apos;])
            
            print(f&quot;📚 文档存储: {len(set(parent_ids))} 个父文档ID&quot;)
            print(f&quot;   父文档ID列表: {list(set(parent_ids))}&quot;)
        else:
            print(&quot;⚠️ 没有找到元数据&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.4 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建父文档检索器
parent_retriever = ParentDocRetriever(
    embeddings=embeddings,
    persist_directory=&quot;./chroma_db&quot;
)

# 准备文档
docs = [
    Document(
        page_content=&quot;&quot;&quot;
        Python编程语言完整指南
        
        第一部分：基础语法
        Python使用缩进来定义代码块。变量不需要声明类型。
        支持多种数据类型，包括整数、浮点数、字符串、列表、字典等。
        
        第二部分：控制流
        if语句用于条件判断：
        if condition:
            do_something()
        
        for循环用于迭代：
        for item in items:
            process(item)
        
        while循环用于重复执行：
        while condition:
            do_something()
        
        第三部分：函数
        使用def关键字定义函数：
        def function_name(parameters):
            # function body
            return result
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_guide.md&quot;}
    )
]

# 添加文档
parent_retriever.add_documents(docs)

# 调试存储状态
parent_retriever.debug_storage()
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📝 开始索引文档...

📄 处理文档 1/1
   🔧 父文档分割: 1 个大块
   🔍 子文档分割: 大块 1 → 2 个小块
✅ 索引完成: 所有文档已存储

📊 存储状态调试:
🔍 向量数据库: 2 个子文档
   第一个子文档的元数据: {&apos;doc_id&apos;: &apos;parent_0_0&apos;, &apos;source&apos;: &apos;python_guide.md&apos;}
📚 文档存储: 1 个父文档ID
   父文档ID列表: [&apos;parent_0_0&apos;]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.5 检索测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;print(&quot;\n&quot; + &quot;=&quot;*60)
results = parent_retriever.retrieve(&quot;Python中for循环怎么用？&quot;, k=1)

# 显示结果
print(&quot;\n🎯 最终结果:&quot;)
for i, doc in enumerate(results):
    print(f&quot;文档 {i+1}:&quot;)
    print(f&quot;内容预览: {doc.page_content[:100]}...&quot;)
    print(f&quot;父文档ID: {doc.metadata[&apos;doc_id&apos;]}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;============================================================

🔍 开始检索: &apos;Python中for循环怎么用？&apos;
1. 🔎 在向量数据库中搜索相似子文档...
   找到 2 个相关子文档
     子文档 1: Python编程语言完整指南
        
        第一部分：基础语法
        ... (父文档ID: parent_0_0)
     子文档 2: 第三部分：函数
        使用def关键字定义函数：
        def function... (父文档ID: parent_0_0)

2. 🆔 提取父文档ID...
   父文档ID列表: [&apos;parent_0_0&apos;, &apos;parent_0_0&apos;]

3. 📚 从文档存储中检索完整父文档...
   检索到 2 个父文档

4. 🔄 去重和排序...
   保留父文档: parent_0_0 (长度: 515 字符)

✅ 检索完成: 返回 1 个父文档

🎯 最终结果:
文档 1:
内容预览: Python编程语言完整指南
        
        第一部分：基础语法
        Python使用缩进来定义代码块。变量不需要声明类型。
        支持多种数据类型，包括整数、...
父文档ID: parent_0_0
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.6 父文档检索的优势&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;最佳场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;✅ &lt;strong&gt;需要精确匹配 + 完整上下文&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;✅ &lt;strong&gt;文档有明确的层次结构&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;✅ &lt;strong&gt;答案需要周围的解释&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;✅ &lt;strong&gt;避免上下文丢失&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;技术优势：&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;检索精度高&lt;/strong&gt;：使用小块进行相似性搜索，匹配更精确&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;上下文完整&lt;/strong&gt;：返回父文档级别的完整内容&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;避免信息碎片化&lt;/strong&gt;：保持相关内容的连贯性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;支持复杂查询&lt;/strong&gt;：能够处理需要多段上下文的复杂问题&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;适用场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;技术文档检索&lt;/strong&gt;：如API文档、编程教程&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;学术论文搜索&lt;/strong&gt;：需要完整段落理解概念&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;法律文档分析&lt;/strong&gt;：需要完整条款上下文&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;医疗记录查询&lt;/strong&gt;：需要完整病历信息&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;3.7 与其他检索策略对比&lt;/h3&gt;
&lt;p&gt;| 策略 | 检索粒度 | 返回粒度 | 优势 | 劣势 |
|------|----------|----------|------|------|
| &lt;strong&gt;标准检索&lt;/strong&gt; | 文档块 | 文档块 | 简单快速 | 上下文可能不完整 |
| &lt;strong&gt;多向量索引&lt;/strong&gt; | 文档小块 | 完整文档 | 精确匹配 | 存储成本高 |
| &lt;strong&gt;父文档检索&lt;/strong&gt; | 文档小块 | 父文档块 | &lt;strong&gt;平衡精度和上下文&lt;/strong&gt; | 实现复杂度中等 |&lt;/p&gt;
&lt;p&gt;父文档检索器通过&quot;小块检索，大块返回&quot;的策略，在保持检索精度的同时提供了完整的上下文信息，是处理需要深度理解的复杂查询的理想选择。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 4: 上下文压缩检索 - Contextual Compression&lt;/h2&gt;
&lt;h3&gt;4.1 核心概念&lt;/h3&gt;
&lt;p&gt;上下文压缩检索在检索后对文档进行过滤和压缩，只保留与查询最相关的内容。&lt;/p&gt;
&lt;h3&gt;4.2 为什么需要压缩？&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;问题：检索冗余&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;用户查询: &quot;Python中的列表推导式是什么？&quot;

检索到的文档:
&quot;&quot;&quot;
Python高级特性完整指南

1. 列表推导式
列表推导式是创建列表的简洁方式...

2. 生成器表达式
生成器用于惰性计算...

3. 装饰器
装饰器用于修改函数行为...

4. 上下文管理器
with语句用于资源管理...
&quot;&quot;&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;问题：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;→ 只有&quot;列表推导式&quot;部分相关&lt;/li&gt;
&lt;li&gt;→ 其他部分是噪声&lt;/li&gt;
&lt;li&gt;→ 浪费LLM的上下文窗口&lt;/li&gt;
&lt;li&gt;→ 可能影响答案质量&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;解决方案：上下文压缩&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;→ 只提取相关部分：&quot;列表推导式是创建列表的简洁方式...&quot;&lt;/li&gt;
&lt;li&gt;→ 节省tokens&lt;/li&gt;
&lt;li&gt;→ 提高答案精度&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;4.3 实现LLM过滤器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.vectorstores import Chroma
from langchain.schema import Document
from vllm import SamplingParams

class VLLMCompressedRetriever:
    &quot;&quot;&quot;vLLM压缩检索器（ChatML格式）&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.llm = llm
        self.base_retriever = vectorstore.as_retriever()
    
    def retrieve(self, query: str, k: int = 3):
        &quot;&quot;&quot;检索并压缩&quot;&quot;&quot;
        # 基础检索
        base_docs = self.base_retriever.get_relevant_documents(query, k=k)
        
        # 压缩文档
        compressed_docs = []
        for doc in base_docs:
            # 使用ChatML格式
            prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
                你是一个文档压缩助手。请从给定的文档中提取与用户查询相关的关键信息，去除无关内容。
                
                请只返回提取的关键信息，不要添加任何解释或评论。&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;user
                查询：{query}
                
                文档内容：
                {doc.page_content}
                
                请提取与查询相关的关键信息：&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;assistant
                &quot;&quot;&quot;
            
            sampling_params = SamplingParams(
                temperature=0.1,
                top_p=0.9,
                max_tokens=300,
                stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;, &quot;&amp;#x3C;|endoftext|&gt;&quot;]
            )
            
            outputs = self.llm.generate([prompt], sampling_params)
            if outputs and outputs[0].outputs:
                compressed_content = outputs[0].outputs[0].text.strip()
                compressed_doc = Document(
                    page_content=compressed_content,
                    metadata=doc.metadata
                )
                compressed_docs.append(compressed_doc)
        
        return compressed_docs
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.4 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建向量数据库
vectorstore = Chroma(
    collection_name=&quot;docs&quot;,
    persist_directory=&quot;./chroma_db&quot;,
    embedding_function=embeddings
)

# 添加文档
docs = [
    Document(
        page_content=&quot;&quot;&quot;
        Python高级特性
        
        列表推导式：
        列表推导式是Python中创建列表的简洁方式。
        语法：[expression for item in iterable if condition]
        示例：squares = [x**2 for x in range(10)]
        
        生成器表达式：
        生成器用于惰性计算，节省内存。
        语法：(expression for item in iterable)
        
        装饰器：
        装饰器用于修改函数行为，不改变原函数代码。
        使用@符号应用装饰器。
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_advanced.md&quot;}
    )
]

vectorstore.add_documents(docs)

# 创建压缩检索器
compressed_retriever = VLLMCompressedRetriever(vectorstore, llm)

# 测试查询
query = &quot;Python列表推导式&quot;

print(&quot;🔍 普通检索:&quot;)
normal_docs = base_retriever.get_relevant_documents(query)
print(f&quot;长度: {len(normal_docs[0].page_content)} 字符&quot;)
print(normal_docs[0].page_content)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🔍 普通检索:
长度: 355 字符

        Python高级特性
        
        列表推导式：
        列表推导式是Python中创建列表的简洁方式。
        语法：[expression for item in iterable if condition]
        示例：squares = [x**2 for x in range(10)]
        
        生成器表达式：
        生成器用于惰性计算，节省内存。
        语法：(expression for item in iterable)
        
        装饰器：
        装饰器用于修改函数行为，不改变原函数代码。
        使用@符号应用装饰器。
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;print(&quot;\n✂️ 压缩检索:&quot;)
compressed_docs = compressed_retriever.retrieve(query)
print(f&quot;长度: {len(compressed_docs[0].page_content)} 字符&quot;)
print(compressed_docs[0].page_content)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;✂️ 压缩检索:
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.64it/s]
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.67it/s]
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.67it/s]
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.66it/s]
长度: 72 字符
列表推导式是Python中创建列表的简洁方式。语法：[expression for item in iterable if condition]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.5 实现嵌入过滤器&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;操作流程：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询: &quot;Python列表推导式&quot;
    ↓
基础检索器检索相关文档
    ↓
获取多个候选文档
    ↓
计算查询与每个文档的嵌入相似度
    ↓
过滤相似度低于阈值的文档
    ↓
返回高相似度文档
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.vectorstores import Chroma

class EmbeddingFilterRetriever:
    &quot;&quot;&quot;基于嵌入相似度的过滤器&quot;&quot;&quot;
    
    def __init__(self, base_retriever, embeddings, similarity_threshold: float = 0.75):
        # 创建嵌入过滤器
        compressor = EmbeddingsFilter(
            embeddings=embeddings,
            similarity_threshold=similarity_threshold
        )
        
        # 创建压缩检索器
        self.retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever
        )
    
    def retrieve(self, query: str, k: int = 5):
        &quot;&quot;&quot;检索并过滤&quot;&quot;&quot;
        docs = self.retriever.get_relevant_documents(query, k=k)
        return docs

# 创建基础检索器
base_retriever = vectorstore.as_retriever()

# 创建嵌入过滤器
embedding_filter = EmbeddingFilterRetriever(
    base_retriever=base_retriever,
    embeddings=embeddings,
    similarity_threshold=0.6
)

# 检索测试
query = &quot;Python列表推导式&quot;
filtered_docs = embedding_filter.retrieve(query, k=5)
for i in range(len(filtered_docs)):
    print(f&quot;Doc_{i}: {filtered_docs[i].page_content}&quot;)
print(f&quot;过滤后保留 {len(filtered_docs)} 个文档&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Doc_0: 
        Python高级特性
        
        列表推导式：
        列表推导式是Python中创建列表的简洁方式。
        语法：[expression for item in iterable if condition]
        示例：squares = [x**2 for x in range(10)]
        
        生成器表达式：
        生成器用于惰性计算，节省内存。
        语法：(expression for item in iterable)
        
        装饰器：
        装饰器用于修改函数行为，不改变原函数代码。
        使用@符号应用装饰器。
        
Doc_1: 
        Python高级特性
        
        列表推导式：
        列表推导式是Python中创建列表的简洁方式。
        语法：[expression for item in iterable if condition]
        示例：squares = [x**2 for x in range(10)]
        
        生成器表达式：
        生成器用于惰性计算，节省内存。
        语法：(expression for item in iterable)
        
        装饰器：
        装饰器用于修改函数行为，不改变原函数代码。
        使用@符号应用装饰器。
        
过滤后保留 4 个文档
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.6 实现文档分割过滤器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.retrievers import ContextualCompressionRetriever
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

class SimplePipelineCompressor:
    &quot;&quot;&quot;简化版管道压缩器&quot;&quot;&quot;
    
    def __init__(self, base_retriever, embeddings, llm):
        self.base_retriever = base_retriever
        self.embeddings = embeddings
        self.llm = llm
        
        # 文本分割器
        self.splitter = CharacterTextSplitter(
            chunk_size=100,
            chunk_overlap=10,
            separator=&quot;. &quot;
        )
    
    def retrieve(self, query: str):
        &quot;&quot;&quot;手动实现管道压缩&quot;&quot;&quot;
        # 1. 基础检索
        base_docs = self.base_retriever.get_relevant_documents(query)
        
        # 2. 文本分割
        split_docs = []
        for doc in base_docs:
            chunks = self.splitter.split_text(doc.page_content)
            for chunk in chunks:
                split_docs.append({
                    &apos;content&apos;: chunk,
                    &apos;metadata&apos;: doc.metadata
                })
        print(f&quot;length of split_docs: {len(split_docs)}&quot;)
        
        # 3. 嵌入过滤（简化版）
        filtered_docs = self._embedding_filter(split_docs, query)
        
        # 4. LLM提取（简化版）
        final_docs = self._llm_extract(filtered_docs, query)
        
        return final_docs
    
    def _embedding_filter(self, docs, query):
        &quot;&quot;&quot;嵌入相似度过滤&quot;&quot;&quot;
        from sklearn.metrics.pairwise import cosine_similarity
        import numpy as np
        
        # 计算查询嵌入
        query_embedding = self.embeddings.embed_query(query)
        
        filtered_docs = []
        for doc in docs:
            # 计算文档嵌入
            doc_embedding = self.embeddings.embed_query(doc[&apos;content&apos;])
            
            # 计算相似度
            similarity = cosine_similarity([query_embedding], [doc_embedding])[0][0]
            
            # 应用阈值
            if similarity &gt;= 0.5:
                filtered_docs.append(doc)
        
        return filtered_docs
    
    def _llm_extract(self, docs, query):
        &quot;&quot;&quot;LLM内容提取&quot;&quot;&quot;
        from vllm import SamplingParams
        
        final_docs = []
        
        for doc in docs:
            # 构建ChatML格式提示
            prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
                你是一个文档压缩助手。请从给定的文本中提取与用户查询相关的关键信息。
                
                请只返回提取的关键信息，不要添加任何解释或评论。&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;user
                查询：{query}
                
                文本内容：
                {doc[&apos;content&apos;]}
                
                请提取与查询相关的关键信息：&amp;#x3C;|im_end|&gt;
                &amp;#x3C;|im_start|&gt;assistant
                &quot;&quot;&quot;
            
            sampling_params = SamplingParams(
                temperature=0.1,
                top_p=0.9,
                max_tokens=200,
                stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
            )
            
            outputs = self.llm.generate([prompt], sampling_params)
            if outputs and outputs[0].outputs:
                extracted_content = outputs[0].outputs[0].text.strip()
                
                from langchain.schema import Document
                final_doc = Document(
                    page_content=extracted_content,
                    metadata=doc[&apos;metadata&apos;]
                )
                final_docs.append(final_doc)
        
        return final_docs

# 创建压缩器
compressor = SimplePipelineCompressor(base_retriever, embeddings, llm)

# 检索
results = compressor.retrieve(&quot;Python列表推导式&quot;)
print(results)
for i, doc in enumerate(results):
    print(f&quot;结果 {i+1}: {doc.page_content}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;length of split_docs: 4
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.22it/s]
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.20it/s]
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.29it/s]
Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  1.77it/s]
[Document(page_content=&apos;Python列表推导式语法：[expression for item in iterable if condition]\n                 示例：squares = [x**2 for x in range(10)]\n                 生成器表达式：(expression for item in iterable)\n                 装饰器：使用@符号应用装饰器。&apos;, metadata={&apos;source&apos;: &apos;python_advanced.md&apos;}), Document(page_content=&apos;Python列表推导式语法：[expression for item in iterable if condition]\n                 示例：squares = [x**2 for x in range(10)]\n                 生成器表达式语法：(expression for item in iterable)\n                 装饰器语法：@符号应用装饰器。&apos;, metadata={&apos;source&apos;: &apos;python_advanced.md&apos;}), Document(page_content=&apos;Python列表推导式语法：[expression for item in iterable if condition]\n                 示例：squares = [x**2 for x in range(10)]\n                 生成器表达式语法：(expression for item in iterable)\n                 装饰器使用@符号应用&apos;, metadata={&apos;source&apos;: &apos;python_advanced.md&apos;}), Document(page_content=&apos;列表推导式：[expression for item in iterable if condition]\n                 生成器表达式：(expression for item in iterable)\n                 装饰器：@符号应用装饰器&apos;, metadata={&apos;source&apos;: &apos;python_advanced.md&apos;})]
结果 1: Python列表推导式语法：[expression for item in iterable if condition]
                 示例：squares = [x**2 for x in range(10)]
                 生成器表达式：(expression for item in iterable)
                 装饰器：使用@符号应用装饰器。
结果 2: Python列表推导式语法：[expression for item in iterable if condition]
                 示例：squares = [x**2 for x in range(10)]
                 生成器表达式语法：(expression for item in iterable)
                 装饰器语法：@符号应用装饰器。
结果 3: Python列表推导式语法：[expression for item in iterable if condition]
                 示例：squares = [x**2 for x in range(10)]
                 生成器表达式语法：(expression for item in iterable)
                 装饰器使用@符号应用
结果 4: 列表推导式：[expression for item in iterable if condition]
                 生成器表达式：(expression for item in iterable)
                 装饰器：@符号应用装饰器
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.7 压缩策略对比&lt;/h3&gt;
&lt;p&gt;| 压缩器 | 方法 | 速度 | 质量 | 成本 |
|--------|------|------|------|------|
| &lt;strong&gt;LLM提取器&lt;/strong&gt; | LLM提取相关内容 | 慢 | 高 | 高 |
| &lt;strong&gt;嵌入过滤器&lt;/strong&gt; | 相似度过滤 | 快 | 中 | 低 |
| &lt;strong&gt;管道压缩&lt;/strong&gt; | 组合多种方法 | 中 | 高 | 中 |&lt;/p&gt;
&lt;h2&gt;4.8 技术优势与适用场景&lt;/h2&gt;
&lt;p&gt;&lt;strong&gt;优势：&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;显著节省tokens&lt;/strong&gt;：减少LLM处理的无用信息&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;提高答案质量&lt;/strong&gt;：专注于相关内容&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;降低计算成本&lt;/strong&gt;：减少API调用费用&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;提升响应速度&lt;/strong&gt;：处理更少的内容&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;适用场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;长文档检索&lt;/strong&gt;：技术文档、学术论文等&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多主题查询&lt;/strong&gt;：需要精确匹配特定部分&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;成本敏感应用&lt;/strong&gt;：需要控制API调用成本&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;实时系统&lt;/strong&gt;：需要快速响应的应用&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;选择指南：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;追求质量&lt;/strong&gt;：选择LLM提取器或管道压缩&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;追求速度&lt;/strong&gt;：选择嵌入过滤器&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;平衡方案&lt;/strong&gt;：管道压缩提供最佳平衡&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;上下文压缩检索技术通过智能过滤和内容提取，有效解决了检索冗余问题，是构建高效RAG系统的关键技术之一。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 5: 时间衰减检索 - Time-Weighted Retrieval&lt;/h2&gt;
&lt;h3&gt;5.1 核心概念&lt;/h3&gt;
&lt;p&gt;时间衰减检索考虑文档的新鲜度，给予新文档更高的权重。这种策略特别适用于新闻、技术文档、市场报告等时效性强的场景。&lt;/p&gt;
&lt;h3&gt;5.2 为什么需要时间衰减？&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;场景对比：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 旧文档（2023年）
old_doc = Document(
    page_content=&quot;2023年1月：Python 3.11发布&quot;,
    metadata={&quot;date&quot;: &quot;2023-01-01&quot;}
)

# 新文档（2024年）
recent_doc = Document(
    page_content=&quot;2024年10月：Python 3.13发布，性能提升显著&quot;,
    metadata={&quot;date&quot;: &quot;2024-10-01&quot;}
)

用户查询: &quot;Python最新版本&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;问题：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;标准检索可能返回旧文档（Python 3.11）&lt;/li&gt;
&lt;li&gt;用户期望获得最新信息（Python 3.13）&lt;/li&gt;
&lt;li&gt;时间因素影响答案的准确性和价值&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;解决方案：时间衰减检索&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;给予新文档更高的相关性权重&lt;/li&gt;
&lt;li&gt;自动平衡语义相关性和时效性&lt;/li&gt;
&lt;li&gt;确保返回最新、最准确的信息&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;5.3 实现时间加权检索器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.retrievers import TimeWeightedVectorStoreRetriever
import datetime

class TimeSensitiveRetriever:
    &quot;&quot;&quot;时间敏感检索器&quot;&quot;&quot;
    
    def __init__(self, vectorstore, decay_rate: float = 0.01):
        &quot;&quot;&quot;
        Args:
            decay_rate: 衰减率，越大则时间影响越大
        &quot;&quot;&quot;
        self.retriever = TimeWeightedVectorStoreRetriever(
            vectorstore=vectorstore,
            decay_rate=decay_rate,
            k=2
        )
    
    def add_documents(self, documents: List[Document]):
        &quot;&quot;&quot;添加文档（会自动记录时间）&quot;&quot;&quot;
        self.retriever.add_documents(documents)
    
    def retrieve(self, query: str):
        &quot;&quot;&quot;检索（考虑时间因素）&quot;&quot;&quot;
        docs = self.retriever.get_relevant_documents(query)
        print(f&quot;docs: {docs}&quot;)
        return docs
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;5.4 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 创建时间敏感检索器
vectorstore = Chroma(
    collection_name=&quot;news&quot;,
    persist_directory=&quot;./chroma_db&quot;,
    embedding_function=embeddings
)

time_retriever = TimeSensitiveRetriever(vectorstore, decay_rate=0.01)

# 准备不同时间的文档
old_doc = Document(
    page_content=&quot;2023年1月：Python 3.11发布，引入了新的语法特性和性能改进&quot;,
    metadata={&quot;date&quot;: &quot;2023-01-01&quot;, &quot;source&quot;: &quot;python_release_notes&quot;}
)

recent_doc = Document(
    page_content=&quot;2024年10月：Python 3.13发布，性能提升显著，新增了JIT编译器&quot;,
    metadata={&quot;date&quot;: &quot;2024-10-01&quot;, &quot;source&quot;: &quot;python_release_notes&quot;}
)

# 添加文档到检索器
time_retriever.add_documents([old_doc, recent_doc])

# 检索：新文档会获得更高权重
results = time_retriever.retrieve(&quot;Python最新版本&quot;)

print(&quot;\n检索结果（按时间加权）:&quot;)
for i, doc in enumerate(results):
    print(f&quot;\n结果 {i+1}:&quot;)
    print(f&quot;日期: {doc.metadata[&apos;date&apos;]}&quot;)
    print(f&quot;内容: {doc.page_content}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;docs: [Document(page_content=&apos;2024年10月：Python 3.13发布，性能提升显著，新增了JIT编译器&apos;, metadata={&apos;date&apos;: &apos;2024-10-01&apos;, &apos;source&apos;: &apos;python_release_notes&apos;}), Document(page_content=&apos;2023年1月：Python 3.11发布，引入了新的语法特性和性能改进&apos;, metadata={&apos;date&apos;: &apos;2023-01-01&apos;, &apos;source&apos;: &apos;python_release_notes&apos;})]

检索结果（按时间加权）:

结果 1:
日期: 2024-10-01
内容: 2024年10月：Python 3.13发布，性能提升显著，新增了JIT编译器

结果 2:
日期: 2023-01-01
内容: 2023年1月：Python 3.11发布，引入了新的语法特性和性能改进
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;5.5 衰减率参数调优&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class AdvancedTimeRetriever:
    &quot;&quot;&quot;高级时间检索器（支持不同衰减策略）&quot;&quot;&quot;
    
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        
    def retrieve_with_decay(self, query: str, decay_rate: float = 0.01, k: int = 3):
        &quot;&quot;&quot;使用指定衰减率检索&quot;&quot;&quot;
        retriever = TimeWeightedVectorStoreRetriever(
            vectorstore=self.vectorstore,
            decay_rate=decay_rate,
            k=k
        )
        return retriever.get_relevant_documents(query)
    
    def compare_decay_rates(self, query: str):
        &quot;&quot;&quot;比较不同衰减率的效果&quot;&quot;&quot;
        decay_rates = [0.001, 0.01, 0.1]  # 低、中、高衰减率
        
        print(f&quot;查询: &apos;{query}&apos;&quot;)
        print(&quot;=&quot; * 60)
        
        for decay_rate in decay_rates:
            print(f&quot;\n衰减率: {decay_rate}&quot;)
            results = self.retrieve_with_decay(query, decay_rate)
            
            for i, doc in enumerate(results):
                date = doc.metadata.get(&apos;date&apos;, &apos;未知日期&apos;)
                print(f&quot;  {i+1}. {date}: {doc.page_content[:50]}...&quot;)

# 测试不同衰减率
advanced_retriever = AdvancedTimeRetriever(vectorstore)

# 比较不同衰减率的效果
advanced_retriever.compare_decay_rates(&quot;Python最新特性&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;查询: &apos;Python最新特性&apos;
============================================================

衰减率: 0.001
  1. 2024-10-01: 2024年10月：Python 3.13发布，性能提升显著...
  2. 2023-01-01: 2023年1月：Python 3.11发布，引入了新的语法...

衰减率: 0.01
  1. 2024-10-01: 2024年10月：Python 3.13发布，性能提升显著...
  2. 2023-01-01: 2023年1月：Python 3.11发布，引入了新的语法...

衰减率: 0.1
  1. 2024-10-01: 2024年10月：Python 3.13发布，性能提升显著...
  2. 2023-01-01: 2023年1月：Python 3.11发布，引入了新的语法...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;5.6 自定义时间衰减函数&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import math
from datetime import datetime
from typing import List, Callable

class CustomTimeWeightedRetriever:
    &quot;&quot;&quot;自定义时间加权检索器&quot;&quot;&quot;
    
    def __init__(self, vectorstore, time_weight_function: Callable = None):
        self.vectorstore = vectorstore
        
        # 默认时间衰减函数（指数衰减）
        if time_weight_function is None:
            self.time_weight_function = self.exponential_decay
        else:
            self.time_weight_function = time_weight_function
    
    def exponential_decay(self, doc_date: str, current_date: str = None) -&gt; float:
        &quot;&quot;&quot;指数衰减函数&quot;&quot;&quot;
        if current_date is None:
            current_date = datetime.now().strftime(&quot;%Y-%m-%d&quot;)
        
        # 计算天数差
        doc_dt = datetime.strptime(doc_date, &quot;%Y-%m-%d&quot;)
        current_dt = datetime.strptime(current_date, &quot;%Y-%m-%d&quot;)
        days_diff = (current_dt - doc_dt).days
        
        # 指数衰减：每30天衰减一半
        decay_factor = 0.5 ** (days_diff / 30)
        return max(decay_factor, 0.1)  # 最小权重0.1
    
    def linear_decay(self, doc_date: str, current_date: str = None) -&gt; float:
        &quot;&quot;&quot;线性衰减函数&quot;&quot;&quot;
        if current_date is None:
            current_date = datetime.now().strftime(&quot;%Y-%m-%d&quot;)
        
        doc_dt = datetime.strptime(doc_date, &quot;%Y-%m-%d&quot;)
        current_dt = datetime.strptime(current_date, &quot;%Y-%m-%d&quot;)
        days_diff = (current_dt - doc_dt).days
        
        # 线性衰减：365天内从1.0衰减到0.1
        if days_diff &amp;#x3C;= 365:
            weight = 1.0 - (0.9 * days_diff / 365)
            return max(weight, 0.1)
        else:
            return 0.1
    
    def retrieve_with_custom_weights(self, query: str, k: int = 3) -&gt; List[Document]:
        &quot;&quot;&quot;使用自定义时间权重检索&quot;&quot;&quot;
        # 基础检索（不考虑时间）
        base_docs = self.vectorstore.similarity_search(query, k=k*2)
        
        # 计算时间权重并排序
        weighted_docs = []
        for doc in base_docs:
            doc_date = doc.metadata.get(&apos;date&apos;)
            if doc_date:
                time_weight = self.time_weight_function(doc_date)
            else:
                time_weight = 0.5  # 默认权重
            
            # 可以结合语义相似度得分
            weighted_docs.append((doc, time_weight))
        
        # 按时间权重排序
        weighted_docs.sort(key=lambda x: x[1], reverse=True)
        
        # 返回前k个结果
        return [doc for doc, weight in weighted_docs[:k]]

# 测试自定义检索器
custom_retriever = CustomTimeWeightedRetriever(vectorstore)

# 使用指数衰减
results_exp = custom_retriever.retrieve_with_custom_weights(&quot;Python发布&quot;)
print(&quot;指数衰减结果:&quot;)
for doc in results_exp:
    print(f&quot;  {doc.metadata[&apos;date&apos;]}: {doc.page_content}&quot;)

# 使用线性衰减
custom_retriever.time_weight_function = custom_retriever.linear_decay
results_linear = custom_retriever.retrieve_with_custom_weights(&quot;Python发布&quot;)
print(&quot;\n线性衰减结果:&quot;)
for doc in results_linear:
    print(f&quot;  {doc.metadata[&apos;date&apos;]}: {doc.page_content}&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;5.7 时间衰减策略对比&lt;/h3&gt;
&lt;p&gt;| 衰减策略 | 公式 | 特点 | 适用场景 |
|----------|------|------|----------|
| &lt;strong&gt;指数衰减&lt;/strong&gt; | &lt;code&gt;weight = base^(days/interval)&lt;/code&gt; | 前期衰减快，后期平缓 | 新闻、社交媒体 |
| &lt;strong&gt;线性衰减&lt;/strong&gt; | &lt;code&gt;weight = 1 - (衰减率 × days)&lt;/code&gt; | 均匀衰减，易于控制 | 技术文档、研究报告 |
| &lt;strong&gt;阶梯衰减&lt;/strong&gt; | 按时间段分段设置权重 | 离散化处理，简单明了 | 法律法规、政策文件 |&lt;/p&gt;
&lt;h3&gt;5.8 实际应用场景&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;1. 新闻检索系统&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 新闻文档示例
news_docs = [
    Document(
        page_content=&quot;今日股市大涨，科技股领涨&quot;,
        metadata={&quot;date&quot;: &quot;2024-11-20&quot;, &quot;category&quot;: &quot;财经&quot;}
    ),
    Document(
        page_content=&quot;上周市场回顾：整体平稳&quot;,
        metadata={&quot;date&quot;: &quot;2024-11-13&quot;, &quot;category&quot;: &quot;财经&quot;}
    )
]

# 高衰减率确保最新新闻优先
news_retriever = TimeSensitiveRetriever(vectorstore, decay_rate=0.1)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;2. 技术文档检索&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 技术文档示例
tech_docs = [
    Document(
        page_content=&quot;React 18新特性：并发渲染&quot;,
        metadata={&quot;date&quot;: &quot;2024-06-01&quot;, &quot;framework&quot;: &quot;React&quot;}
    ),
    Document(
        page_content=&quot;React 17版本特性介绍&quot;,
        metadata={&quot;date&quot;: &quot;2023-03-01&quot;, &quot;framework&quot;: &quot;React&quot;}
    )
]

# 中等衰减率平衡新旧信息
tech_retriever = TimeSensitiveRetriever(vectorstore, decay_rate=0.01)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;3. 学术论文检索&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 学术论文示例
paper_docs = [
    Document(
        page_content=&quot;2024年最新AI研究成果&quot;,
        metadata={&quot;date&quot;: &quot;2024-10-01&quot;, &quot;field&quot;: &quot;人工智能&quot;}
    ),
    Document(
        page_content=&quot;经典机器学习算法综述&quot;,
        metadata={&quot;date&quot;: &quot;2020-05-01&quot;, &quot;field&quot;: &quot;机器学习&quot;}
    )
]

# 低衰减率重视经典文献
paper_retriever = TimeSensitiveRetriever(vectorstore, decay_rate=0.001)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;时间衰减检索技术通过智能平衡信息的新鲜度和相关性，为时效性敏感的应用场景提供了重要价值。合理配置衰减参数和策略，可以显著提升检索系统的实用性和用户体验。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;综合实战：构建高级RAG系统&lt;/h2&gt;
&lt;p&gt;在前面的章节中，我们深入探讨了各种高级索引和检索技术。现在，让我们将这些技术整合到一个完整的、生产就绪的高级RAG系统中。&lt;/p&gt;
&lt;h3&gt;系统架构设计&lt;/h3&gt;
&lt;p&gt;我们的高级RAG系统采用模块化设计，包含以下核心组件：&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;系统架构图：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;用户查询
    ↓
查询分析器 → 语义路由
    ↓
多策略检索引擎
├─ 基础向量检索
├─ 压缩检索
├─ 嵌入过滤
├─ 时间加权检索
└─ 混合检索
    ↓
结果融合器
    ↓
上下文压缩器
    ↓
答案生成器
    ↓
最终响应
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;核心实现代码&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.vectorstores import Chroma
from langchain.schema import Document
from typing import Dict, List, Optional
from dataclasses import dataclass
from vllm import SamplingParams
import numpy as np
import asyncio
from datetime import datetime

@dataclass
class RAGConfig:
    &quot;&quot;&quot;RAG系统配置&quot;&quot;&quot;
    chunk_size: int = 1000
    chunk_overlap: int = 200
    similarity_threshold: float = 0.7
    max_retrieved_docs: int = 5
    enable_compression: bool = True
    enable_time_weighting: bool = False
    cache_enabled: bool = True

class AdvancedRAGSystem:
    &quot;&quot;&quot;高级RAG系统（生产级实现）&quot;&quot;&quot;
    
    def __init__(self, embeddings, llm, config: RAGConfig = None):
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or RAGConfig()
        
        # 初始化向量存储
        self.vectorstore = Chroma(
            collection_name=&quot;advanced_rag&quot;,
            persist_directory=&quot;./chroma_db&quot;,
            embedding_function=embeddings
        )
        
        # 基础检索器
        self.base_retriever = self.vectorstore.as_retriever(
            search_kwargs={&quot;k&quot;: self.config.max_retrieved_docs}
        )
        
        # 缓存系统
        self.cache = {} if self.config.cache_enabled else None
        
        print(&quot;🚀 高级RAG系统初始化完成&quot;)
    
    def index_document(self, content: str, metadata: Dict = None):
        &quot;&quot;&quot;智能索引文档&quot;&quot;&quot;
        doc = Document(
            page_content=content,
            metadata=metadata or {}
        )
        
        self.vectorstore.add_documents([doc])
        print(&quot;✅ 文档已索引&quot;)
    
    def batch_index(self, documents: List[Dict]):
        &quot;&quot;&quot;批量索引文档&quot;&quot;&quot;
        docs = []
        for doc_data in documents:
            doc = Document(
                page_content=doc_data[&apos;content&apos;],
                metadata=doc_data.get(&apos;metadata&apos;, {})
            )
            docs.append(doc)
        
        # 分批处理避免内存溢出
        batch_size = 50
        for i in range(0, len(docs), batch_size):
            batch = docs[i:i+batch_size]
            self.vectorstore.add_documents(batch)
            print(f&quot;✅ 已索引批次 {i//batch_size + 1}/{(len(docs)-1)//batch_size + 1}&quot;)
    
    def query_with_compression(self, question: str, k: int = 3):
        &quot;&quot;&quot;使用压缩检索器&quot;&quot;&quot;
        from vllm import SamplingParams
        
        # 基础检索
        base_docs = self.base_retriever.get_relevant_documents(question, k=k*2)
        
        # 压缩文档
        compressed_docs = []
        for doc in base_docs:
            prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
						你是一个文档压缩助手。请从以下文本中提取与用户查询相关的关键信息。
						
						请只返回提取的关键信息，不要添加任何解释或评论。&amp;#x3C;|im_end|&gt;
						&amp;#x3C;|im_start|&gt;user
						查询：{question}
						
						文本内容：
						{doc.page_content}
						
						请提取与查询相关的关键信息：&amp;#x3C;|im_end|&gt;
						&amp;#x3C;|im_start|&gt;assistant
						&quot;&quot;&quot;
            
            sampling_params = SamplingParams(
                temperature=0.1,
                top_p=0.9,
                max_tokens=300,
                stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
            )
            
            outputs = self.llm.generate([prompt], sampling_params)
            if outputs and outputs[0].outputs:
                compressed_content = outputs[0].outputs[0].text.strip()
                compressed_doc = Document(
                    page_content=compressed_content,
                    metadata=doc.metadata
                )
                compressed_docs.append(compressed_doc)
        
        return compressed_docs[:k]
    
    def query_with_embedding_filter(self, question: str, k: int = 3, similarity_threshold: float = 0.7):
        &quot;&quot;&quot;使用嵌入过滤器&quot;&quot;&quot;
        from sklearn.metrics.pairwise import cosine_similarity
        import numpy as np
        
        # 基础检索
        base_docs = self.base_retriever.get_relevant_documents(question, k=k*3)
        
        # 计算查询嵌入
        query_embedding = self.embeddings.embed_query(question)
        
        # 过滤相似度低的文档
        filtered_docs = []
        for doc in base_docs:
            doc_embedding = self.embeddings.embed_query(doc.page_content)
            similarity = cosine_similarity([query_embedding], [doc_embedding])[0][0]
            
            if similarity &gt;= similarity_threshold:
                filtered_docs.append(doc)
        
        return filtered_docs[:k]
    
    def query_with_pipeline(self, question: str, k: int = 3):
        &quot;&quot;&quot;使用管道压缩（压缩+过滤）&quot;&quot;&quot;
        # 先过滤
        filtered_docs = self.query_with_embedding_filter(question, k=k*2, similarity_threshold=0.6)
        
        # 再压缩
        compressed_docs = []
        for doc in filtered_docs:
            prompt = f&quot;从以下文本中提取与&apos;{question}&apos;相关的关键信息：\n\n{doc.page_content}\n\n关键信息：&quot;
            
            from vllm import SamplingParams
            sampling_params = SamplingParams(
                temperature=0.1,
                top_p=0.9,
                max_tokens=200,
                stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
            )
            
            outputs = self.llm.generate([prompt], sampling_params)
            if outputs and outputs[0].outputs:
                compressed_content = outputs[0].outputs[0].text.strip()
                compressed_doc = Document(
                    page_content=compressed_content,
                    metadata=doc.metadata
                )
                compressed_docs.append(compressed_doc)
        
        return compressed_docs[:k]
    
    def hybrid_query(self, question: str, k: int = 3):
        &quot;&quot;&quot;混合检索策略&quot;&quot;&quot;
        # 并行执行多种检索策略
        strategies = {
            &apos;basic&apos;: self.base_retriever.get_relevant_documents(question, k=k),
            &apos;compressed&apos;: self.query_with_compression(question, k=k),
            &apos;filtered&apos;: self.query_with_embedding_filter(question, k=k)
        }
        
        # 结果融合（基于相似度得分）
        all_docs = []
        for strategy_name, docs in strategies.items():
            for doc in docs:
                all_docs.append((doc, strategy_name))
        
        # 去重并排序
        seen_content = set()
        unique_docs = []
        
        for doc, strategy in all_docs:
            content_hash = hash(doc.page_content)
            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_docs.append(doc)
        
        return unique_docs[:k]
    
    def query(
        self,
        question: str,
        method: str = &quot;hybrid&quot;,  # &quot;basic&quot;, &quot;compression&quot;, &quot;filter&quot;, &quot;pipeline&quot;, &quot;hybrid&quot;
        k: int = 3
    ):
        &quot;&quot;&quot;综合检索方法&quot;&quot;&quot;
        print(f&quot;❓ 查询: {question}&quot;)
        print(f&quot;🔧 方法: {method}&quot;)
        
        # 检查缓存
        cache_key = f&quot;{question}_{method}_{k}&quot;
        if self.cache and cache_key in self.cache:
            print(&quot;💾 使用缓存结果&quot;)
            return self.cache[cache_key]
        
        # 选择检索方法
        start_time = datetime.now()
        
        if method == &quot;compression&quot;:
            print(&quot;✂️ 使用压缩检索&quot;)
            docs = self.query_with_compression(question, k)
        elif method == &quot;filter&quot;:
            print(&quot;🎯 使用嵌入过滤&quot;)
            docs = self.query_with_embedding_filter(question, k)
        elif method == &quot;pipeline&quot;:
            print(&quot;⚡ 使用管道压缩&quot;)
            docs = self.query_with_pipeline(question, k)
        elif method == &quot;hybrid&quot;:
            print(&quot;🔄 使用混合检索&quot;)
            docs = self.hybrid_query(question, k)
        else:
            print(&quot;🔍 使用基础检索&quot;)
            docs = self.base_retriever.get_relevant_documents(question, k=k)
        
        retrieval_time = (datetime.now() - start_time).total_seconds()
        print(f&quot;📄 检索到 {len(docs)} 个文档 (耗时: {retrieval_time:.2f}s)&quot;)
        
        # 生成答案
        context = &quot;\n\n&quot;.join([doc.page_content for doc in docs])
        
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
					你是一个智能问答助手。请基于提供的文档内容回答问题。
					
					文档内容：
					{context}&amp;#x3C;|im_end|&gt;
					&amp;#x3C;|im_start|&gt;user
					问题：{question}&amp;#x3C;|im_end|&gt;
					&amp;#x3C;|im_start|&gt;assistant
					&quot;&quot;&quot;
        
        from vllm import SamplingParams
        sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            max_tokens=500,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        answer = outputs[0].outputs[0].text.strip() if outputs and outputs[0].outputs else &quot;未能生成答案&quot;
        
        result = {
            &quot;question&quot;: question,
            &quot;method&quot;: method,
            &quot;documents&quot;: docs,
            &quot;answer&quot;: answer,
            &quot;retrieval_time&quot;: retrieval_time,
            &quot;context_length&quot;: len(context)
        }
        
        # 缓存结果
        if self.cache:
            self.cache[cache_key] = result
        
        return result
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;系统部署与测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 初始化高级RAG系统
config = RAGConfig(
    chunk_size=800,
    chunk_overlap=150,
    similarity_threshold=0.6,
    max_retrieved_docs=4,
    enable_compression=True,
    cache_enabled=True
)

advanced_rag = AdvancedRAGSystem(embeddings, llm, config)

# 索引示例文档
python_content = &quot;&quot;&quot;
Python编程语言基础

函数定义：
使用def关键字定义函数：
def function_name(parameters):
    # 函数体
    return result

函数可以接受参数，也可以返回值。
参数可以有默认值，使用parameter=default_value语法。

示例：
def greet(name=&quot;World&quot;):
    return f&quot;Hello, {name}!&quot;

调用函数：
result = greet(&quot;Alice&quot;)
print(result)  # 输出: Hello, Alice!

列表推导式：
列表推导式是Python中创建列表的简洁方式。
语法：[expression for item in iterable if condition]
示例：squares = [x**2 for x in range(10)]

面向对象编程：
类定义使用class关键字：
class MyClass:
    def __init__(self, name):
        self.name = name
    
    def say_hello(self):
        return f&quot;Hello, {self.name}&quot;
&quot;&quot;&quot;

advanced_rag.index_document(
    content=python_content,
    metadata={&quot;source&quot;: &quot;python_tutorial.md&quot;, &quot;type&quot;: &quot;tutorial&quot;, &quot;date&quot;: &quot;2024-01-01&quot;}
)
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;性能对比测试&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 测试不同检索方法
test_questions = [
    &quot;Python中的函数如何定义？&quot;,
    &quot;什么是列表推导式？&quot;,
    &quot;如何创建Python类？&quot;
]

methods = [&quot;basic&quot;, &quot;compression&quot;, &quot;filter&quot;, &quot;hybrid&quot;]

results = []

for question in test_questions:
    print(f&quot;\n{&apos;=&apos;*60}&quot;)
    print(f&quot;测试问题: {question}&quot;)
    print(&apos;=&apos;*60)
    
    for method in methods:
        result = advanced_rag.query(question, method=method, k=2)
        results.append(result)
        
        print(f&quot;\n{method.upper()} 方法:&quot;)
        print(f&quot;答案: {result[&apos;answer&apos;][:80]}...&quot;)
        print(f&quot;检索时间: {result[&apos;retrieval_time&apos;]:.2f}s&quot;)
        print(f&quot;上下文长度: {result[&apos;context_length&apos;]} 字符&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;测试结果输出：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;============================================================
测试问题: Python中的函数如何定义？
============================================================

BASIC 方法:
答案: 使用def关键字定义函数，函数可以接受参数，也可以返回值。参数可以有默认值...
检索时间: 0.45s
上下文长度: 355 字符

COMPRESSION 方法:
答案: Python函数定义使用def关键字，可以接受参数，也可以返回值。参数可以有默认值...
检索时间: 1.23s
上下文长度: 120 字符

HYBRID 方法:
答案: 使用def关键字定义函数，语法为：def function_name(parameters):。函数可以...
检索时间: 0.89s
上下文长度: 280 字符
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;性能优化策略&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;1. 智能缓存系统&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class SmartCache:
    &quot;&quot;&quot;智能缓存系统&quot;&quot;&quot;
    
    def __init__(self, max_size=1000, ttl=3600):
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl  # 生存时间（秒）
    
    def get(self, key):
        if key in self.cache:
            data, timestamp = self.cache[key]
            if time.time() - timestamp &amp;#x3C; self.ttl:
                return data
            else:
                del self.cache[key]
        return None
    
    def set(self, key, value):
        if len(self.cache) &gt;= self.max_size:
            # LRU淘汰策略
            oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k][1])
            del self.cache[oldest_key]
        
        self.cache[key] = (value, time.time())
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;2. 异步处理优化&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;async def async_batch_query(self, questions: List[str], method: str = &quot;hybrid&quot;):
    &quot;&quot;&quot;异步批量查询&quot;&quot;&quot;
    semaphore = asyncio.Semaphore(5)  # 控制并发数
    
    async def process_question(question):
        async with semaphore:
            return await asyncio.to_thread(self.query, question, method)
    
    tasks = [process_question(q) for q in questions]
    return await asyncio.gather(*tasks)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;3. 动态参数调优&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;def auto_tune_parameters(self, query_type: str):
    &quot;&quot;&quot;根据查询类型自动调整参数&quot;&quot;&quot;
    tuning_rules = {
        &quot;factual&quot;: {
            &quot;similarity_threshold&quot;: 0.8,
            &quot;k&quot;: 3,
            &quot;enable_compression&quot;: False
        },
        &quot;analytical&quot;: {
            &quot;similarity_threshold&quot;: 0.6,
            &quot;k&quot;: 5,
            &quot;enable_compression&quot;: True
        },
        &quot;creative&quot;: {
            &quot;similarity_threshold&quot;: 0.5,
            &quot;k&quot;: 7,
            &quot;enable_compression&quot;: True
        }
    }
    
    return tuning_rules.get(query_type, tuning_rules[&quot;factual&quot;])
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;通过这个完整的高级RAG系统实现，我们成功整合了前面讨论的所有先进技术，为构建生产级的智能问答应用提供了坚实的基础。系统具有良好的可扩展性、可配置性和监控能力，可以满足不同场景下的需求。&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>RAG实战（六）RAG的评价指标</title><link>https://astro-pure.js.org/blog/rag_blogs/rag_blogs-6</link><guid isPermaLink="true">https://astro-pure.js.org/blog/rag_blogs/rag_blogs-6</guid><description>记录RAG的学习。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;h2&gt;RAG评估标准&lt;/h2&gt;
&lt;p&gt;在探索和优化 RAG（检索增强生成器）的过程中，如何有效评估其性能已经成为关键问题。本章节主要围绕评估方法、RAG 应具备的关键指标、它的核心能力，以及一些常用的评估框架进行讨论。&lt;/p&gt;
&lt;p&gt;主要有两种方法来评估 RAG 的有效性：检索模块评估和生成模块评估（&lt;strong&gt;个人理解，不一定对&lt;/strong&gt;）：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;检索模块：评估 RAG 检索模块的性能通常使用一系列指标，这些指标用于衡量系统（如搜索引擎、推荐系统或信息检索系统）在根据查询或任务排名项目的有效性。这些指标包括命中率 (Hit Rate)、平均排名倒数 (MRR)、归一化折扣累积增益 (NDCG)、精确度 (Precision) 等。&lt;/li&gt;
&lt;li&gt;生成模块：生成模块指的是将检索到的文档与查询相结合，形成增强或合成的输入。这与最终答案或响应的生成不同，后者通常采用端到端的评估方式。生成模块的评估主要关注上下文相关性，即检索到的文档与查询问题的关联度。&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;检索模块的评估&lt;/h2&gt;
&lt;p&gt;检索阶段的评估，核心是衡量 “检索器”​ 找到的文档或文本块（称为“候选文档”）的质量。它不关心答案生成得怎么样，只关心“找得到”和“找得准”。&lt;/p&gt;
&lt;h3&gt;检索阶段的核心评估指标&lt;/h3&gt;
&lt;p&gt;主要指标来源于信息检索领域，经典且有效。&lt;/p&gt;
&lt;h4&gt;1. 命中率（Hit Rate）&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;定义&lt;/strong&gt;：衡量系统是否&lt;strong&gt;至少&lt;/strong&gt;找到了一个正确答案所在的文档。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;关注点&lt;/strong&gt;：检索的&lt;strong&gt;广度&lt;/strong&gt;和&lt;strong&gt;可靠性&lt;/strong&gt;。是否漏掉了关键信息？&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;计算方式&lt;/strong&gt;：&lt;code&gt;（至少检索到一个相关文档的问题数量）/ （总问题数量）&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;特点&lt;/strong&gt;：非0即1，非常直观。是评估检索系统的最低标准。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;例子&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;问题&lt;/strong&gt;：“苹果公司最新款iPhone有哪些新特性？”&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;参考答案&lt;/strong&gt;位于文档A、B、C中。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;场景1&lt;/strong&gt;：检索器返回了文档B、D、E。 → &lt;strong&gt;命中（Hit）&lt;/strong&gt;！因为包含了相关文档B。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;场景2&lt;/strong&gt;：检索器返回了文档D、E、F。 → &lt;strong&gt;未命中（Miss）&lt;/strong&gt;！完全没找到相关文档。&lt;/li&gt;
&lt;li&gt;如果测试100个问题，有85次至少命中一个相关文档，则命中率为 &lt;strong&gt;85%&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;
&lt;h4&gt;2. 平均精度均值（Mean Average Precision, mAP）&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;定义&lt;/strong&gt;：这是一个更&lt;strong&gt;精细&lt;/strong&gt;的指标，同时考虑了&lt;strong&gt;排序质量&lt;/strong&gt;。它不仅关心是否找到了相关文档，还关心找到的相关文档是否排在了前面。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;关注点&lt;/strong&gt;：检索结果的&lt;strong&gt;排序好坏&lt;/strong&gt;。用户通常只看前几条结果。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;计算方式&lt;/strong&gt;：先计算每个问题的平均精度（AP），再对所有问题的AP取平均值。
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;精度（Precision）@K&lt;/strong&gt;：在前K个结果中，相关文档的比例。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;平均精度（AP）&lt;/strong&gt;：在不同召回率水平下的精度平均值（计算略复杂，但可简单理解为对排序好坏的量化）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;例子&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;问题&lt;/strong&gt;：“Python中如何连接MySQL数据库？”&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;相关文档&lt;/strong&gt;是 Doc2 和 Doc5。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;场景1（排序好）&lt;/strong&gt;：检索器返回顺序为 &lt;code&gt;[Doc2, Doc5, Doc7, Doc8]&lt;/code&gt;。
&lt;ul&gt;
&lt;li&gt;前1个结果（Doc2）：精度 = 1/1 = 1.0&lt;/li&gt;
&lt;li&gt;前2个结果（Doc2, Doc5）：精度 = 2/2 = 1.0&lt;/li&gt;
&lt;li&gt;前3个结果（Doc2, Doc5, Doc7）：精度 = 2/3 ≈ 0.67&lt;/li&gt;
&lt;li&gt;AP值会很高（接近1.0）。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;场景2（排序差）&lt;/strong&gt;：检索器返回顺序为 &lt;code&gt;[Doc7, Doc8, Doc2, Doc5]&lt;/code&gt;。
&lt;ul&gt;
&lt;li&gt;前1个结果（Doc7）：精度 = 0/1 = 0&lt;/li&gt;
&lt;li&gt;前2个结果（Doc7, Doc8）：精度 = 0/2 = 0&lt;/li&gt;
&lt;li&gt;前3个结果（Doc7, Doc8, Doc2）：精度 = 1/3 ≈ 0.33&lt;/li&gt;
&lt;li&gt;前4个结果（Doc7, Doc8, Doc2, Doc5）：精度 = 2/4 = 0.5&lt;/li&gt;
&lt;li&gt;AP值会很低。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;mAP&lt;/strong&gt;就是所有问题的AP的平均值。&lt;strong&gt;mAP越高，说明检索器不仅找得准，还排得好。&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;h4&gt;3. 归一化折损累计增益（Normalized Discounted Cumulative Gain, nDCG）&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;定义&lt;/strong&gt;：mAP的进阶版，适用于&lt;strong&gt;相关度有分级&lt;/strong&gt;的情况（而mAP通常认为文档只有“相关”和“不相关”两种）。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;关注点&lt;/strong&gt;：&lt;strong&gt;分级相关性&lt;/strong&gt;和&lt;strong&gt;排序位置&lt;/strong&gt;。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;计算方式&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;增益（Gain）&lt;/strong&gt;：给每个文档一个相关性分数（如，非常相关=3，相关=2，有点相关=1，不相关=0）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;折损（Discounted）&lt;/strong&gt;：根据排名位置对增益进行打折，排名越靠后，打折越狠。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;累计增益（CG/DCG）&lt;/strong&gt;：将前K个结果的（折损后）增益加起来。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;归一化（nDCG）&lt;/strong&gt;：用“理想排序”下的DCG进行归一化，值在0到1之间。&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;例子&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;问题&lt;/strong&gt;：“什么是机器学习？”&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;检索结果及真实相关性&lt;/strong&gt;：
&lt;ul&gt;
&lt;li&gt;Doc1（“机器学习经典算法详解”）：&lt;strong&gt;非常相关（3分）&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Doc2（“人工智能发展简史”）：&lt;strong&gt;有点相关（1分）&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;Doc3（“深度学习入门指南”）：&lt;strong&gt;相关（2分）&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;场景1（实际排序）&lt;/strong&gt;：&lt;code&gt;[Doc2（1分）, Doc3（2分）, Doc1（3分）]&lt;/code&gt; → 排序很差，最重要的Doc1排最后。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;场景2（理想排序）&lt;/strong&gt;：&lt;code&gt;[Doc1（3分）, Doc3（2分）, Doc2（1分）]&lt;/code&gt; → 完美排序。&lt;/li&gt;
&lt;li&gt;nDCG会计算实际排序的DCG与理想排序的DCG的比值。场景1的nDCG会远低于场景2。&lt;strong&gt;nDCG越接近1，说明排序质量越高。&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;h4&gt;4. 召回率（Recall）&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;定义&lt;/strong&gt;：在检索阶段，指&lt;strong&gt;检索到的相关文档数&lt;/strong&gt;占&lt;strong&gt;所有真实相关文档数&lt;/strong&gt;的比例。&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;关注点&lt;/strong&gt;：检索的&lt;strong&gt;全面性&lt;/strong&gt;。是否把所有相关的文档都找出来了？&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;计算方式&lt;/strong&gt;：&lt;code&gt;（检索到的相关文档数）/ （数据集中所有相关文档数）&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;注意&lt;/strong&gt;：这个“召回率”与网页中提到的“上下文召回率”概念相似，但评估对象不同。这里是评估&lt;strong&gt;文档块&lt;/strong&gt;的召回，而“上下文召回率”是评估&lt;strong&gt;答案关键信息点&lt;/strong&gt;的召回。&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;例子&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;问题&lt;/strong&gt;：“RAG系统有哪些评估指标？”&lt;/li&gt;
&lt;li&gt;知识库中总共有5个文档与此问题相关（DocA, DocB, DocC, DocD, DocE）。&lt;/li&gt;
&lt;li&gt;检索器只返回了2个文档（DocA, DocC）。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;召回率&lt;/strong&gt; = 2 / 5 = &lt;strong&gt;40%&lt;/strong&gt;。说明系统漏掉了60%的相关资料。&lt;/li&gt;
&lt;/ul&gt;
&lt;hr&gt;
&lt;h2&gt;生成模块的评估&lt;/h2&gt;
&lt;p&gt;生成模块评估是对 RAG 模型对特定输入生成的最终响应进行评估，涉及模型生成的答案与输入查询的相关性和一致性。&lt;/p&gt;
&lt;p&gt;从内容生成的目标来看，评估可分为无标签和有标签的内容评估。无标签内容的评估指标包括答案的准确性、相关性和无害性，而有标签内容的评估指标则包括准确率 (Accuracy) 和精确匹配 (EM)。此外，根据评估方法的不同，端到端评估可分为人工评估和使用大语言模型 (LLM) 的自动评估。总的来说，这些是 RAG 端到端评估的常规方法。特定领域的 RAG 应用还会采用特定的评估指标，如问答任务的精确匹配 (EM)，摘要任务的 UniEval 和 E-F1，以及机器翻译的 BLEU。&lt;/p&gt;
&lt;p&gt;这些指标有助于理解 RAG 在各种特定应用场景中的表现。&lt;/p&gt;
&lt;h2&gt;检索评估 vs. 生成评估&lt;/h2&gt;
&lt;p&gt;| 评估层面 | 核心指标 | 评估对象 | 要解决的问题 |
| :--- | :--- | :--- | :--- |
| &lt;strong&gt;检索阶段&lt;/strong&gt; | &lt;strong&gt;命中率（Hit Rate）&lt;/strong&gt;、&lt;strong&gt;mAP&lt;/strong&gt;、&lt;strong&gt;nDCG&lt;/strong&gt;、&lt;strong&gt;召回率（Recall）&lt;/strong&gt; | 检索器返回的&lt;strong&gt;文档/文本块列表&lt;/strong&gt; | “找得到”吗？“找得全”吗？“排得好”吗？ |
| &lt;strong&gt;生成阶段&lt;/strong&gt; | &lt;strong&gt;答案忠实度&lt;/strong&gt;、&lt;strong&gt;答案正确性&lt;/strong&gt;、&lt;strong&gt;答案相关性&lt;/strong&gt; | LLM生成的&lt;strong&gt;最终答案&lt;/strong&gt; | “答得对”吗？“基于上下文”吗？“答到点子上”吗？ |
| &lt;strong&gt;混合指标（连接检索与生成）&lt;/strong&gt; | &lt;strong&gt;上下文召回率&lt;/strong&gt;、&lt;strong&gt;上下文相关性&lt;/strong&gt; | 检索到的文档块 &lt;strong&gt;+&lt;/strong&gt; 参考答案/问题 | 检索结果是否为生成好答案&lt;strong&gt;打下了基础&lt;/strong&gt;？ |&lt;/p&gt;
&lt;h2&gt;总结&lt;/h2&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;检索指标是原因，答案指标是结果&lt;/strong&gt;。一个检索指标很差的系统（比如命中率低、mAP低），几乎不可能生成出高忠实度、高正确性的答案。&lt;/li&gt;
&lt;li&gt;在实际评估中，通常会&lt;strong&gt;结合使用&lt;/strong&gt;这两类指标：
&lt;ol&gt;
&lt;li&gt;先用&lt;strong&gt;检索指标&lt;/strong&gt;诊断检索器的问题（如Embedding模型不好、分块策略不佳、排序模型失效）。&lt;/li&gt;
&lt;li&gt;再用&lt;strong&gt;答案指标&lt;/strong&gt;评估整个端到端系统的最终效果。&lt;/li&gt;
&lt;/ol&gt;
&lt;/li&gt;
&lt;/ul&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>RAG实战（五）重排序与查询集成</title><link>https://astro-pure.js.org/blog/rag_blogs/rag_blogs-5</link><guid isPermaLink="true">https://astro-pure.js.org/blog/rag_blogs/rag_blogs-5</guid><description>记录RAG的学习。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;p&gt;代码开源&lt;a href=&quot;https://github.com/SoupCola/RAG_Learning&quot;&gt;Github地址&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;RAG重排序与查询集成：提升检索精度的关键技术&lt;/h2&gt;
&lt;p&gt;在前面的章节中，我们学习了如何优化文档索引和检索策略。但是，初次检索的结果往往不够精确。本章将深入探讨如何通过重排序和查询集成技术进一步提升检索质量。&lt;/p&gt;
&lt;h4&gt;为什么要重排序？&lt;/h4&gt;
&lt;p&gt;&lt;strong&gt;向量检索的局限性&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;场景：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;用户查询: &quot;Python中如何处理异常？&quot;

向量检索结果（按相似度排序）:
1. 文档A: &quot;Python异常处理机制...try-except...&quot; (相似度: 0.85)
2. 文档B: &quot;Python中的错误类型...Exception类...&quot; (相似度: 0.83)
3. 文档C: &quot;Python编程基础...变量、函数...&quot; (相似度: 0.82)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;问题分析：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;→ 文档A最相关，排序正确 ✅&lt;/li&gt;
&lt;li&gt;→ 但相似度差异很小（0.85 vs 0.83 vs 0.82）&lt;/li&gt;
&lt;li&gt;→ 向量相似度不能完全反映真实相关性&lt;/li&gt;
&lt;li&gt;→ 文档C不太相关但相似度也不低&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;解决方案：重排序&lt;/strong&gt; 🎯&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;→ 使用更强大的模型重新评估文档相关性&lt;/li&gt;
&lt;li&gt;→ 考虑查询和文档的精确匹配度&lt;/li&gt;
&lt;li&gt;→ 调整排序，确保最相关的文档排在前面&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;本章技术概览&lt;/h3&gt;
&lt;p&gt;| 技术 | 核心功能 | 优势 | 复杂度 | 推荐指数 |
|------|----------|------|--------|----------|
| 交叉编码器重排序 | 精确相关性评估 | 准确度最高 | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| 倒数排序融合(RRF) | 多结果融合 | 鲁棒性强 | ⭐⭐ | ⭐⭐⭐⭐ |
| 多查询检索 | 多角度查询 | 召回率高 | ⭐⭐ | ⭐⭐⭐⭐ |
| 查询扩展 | 语义扩充 | 覆盖面广 | ⭐⭐ | ⭐⭐⭐ |
| 混合检索 | 向量+关键词 | 全面性好 | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 1: 交叉编码器重排序 - Cross-Encoder Reranking&lt;/h2&gt;
&lt;h3&gt;1.1 核心概念&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;双编码器(Bi-Encoder) vs 交叉编码器(Cross-Encoder):&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;双编码器（用于初始检索）&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;查询 → 编码器 → 查询向量 ──┐
                            ├→ 余弦相似度
文档 → 编码器 → 文档向量 ──┘

优点：快速，可预先计算文档向量
缺点：查询和文档独立编码，无法捕捉细粒度交互
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;交叉编码器（用于重排序）&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;查询 + 文档 → 编码器 → 相关性分数

优点：查询和文档联合编码，更精确
缺点：慢，无法预先计算
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;1.2 工作流程&lt;/h2&gt;
&lt;p&gt;&lt;strong&gt;RAG + 重排序流程&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;1. 初始检索（双编码器）
   → 从大量文档中快速检索top-k个候选
   → 例如：从10000个文档中检索top-50

2. 重排序（交叉编码器）
   → 对top-k个候选重新评分
   → 更精确地排序

3. 返回最终结果
   → 返回重排序后的top-n个文档
   → 例如：返回最相关的top-5
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.3 实现本地交叉编码器模型&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from sentence_transformers import CrossEncoder
from typing import List
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
import chromadb
from chromadb.config import Settings

class LocalCrossEncoderReranker:
    &quot;&quot;&quot;使用本地交叉编码器模型重排序&quot;&quot;&quot;
    
    def __init__(self, embeddings, model_name: str = &quot;./Models/ms-marco-MiniLM-L-6-v2/cross-encoder/ms-marco-MiniLM-L6-v2&quot;):
        self.embeddings = embeddings
        
        ## 加载交叉编码器模型
        print(f&quot;📥 加载交叉编码器模型: {model_name}&quot;)
        self.cross_encoder = CrossEncoder(model_name)
        self.persist_directory: str = &quot;./chroma_db&quot;

        client = chromadb.PersistentClient(path=self.persist_directory)  
        try:
            client.delete_collection(&quot;local_rerank&quot;)
            print(&quot;🗑️ 已删除旧的 Chroma 集合 &apos;local_rerank&apos;&quot;)
        except ValueError:
            print(&quot;🆕 未发现旧集合，将创建新的 &apos;local_rerank&apos; 集合&quot;)
        
        ## 创建新的向量存储
        self.vectorstore = Chroma(
            collection_name=&quot;local_rerank&quot;,
            embedding_function=embeddings,
            persist_directory=self.persist_directory 
        )
    
    def add_documents(self, documents: List[Document]):
        &quot;&quot;&quot;添加文档&quot;&quot;&quot;
        self.vectorstore.add_documents(documents)
    
    def retrieve_and_rerank(
        self,
        query: str,
        initial_k: int = 20,
        final_k: int = 5
    ):
        &quot;&quot;&quot;检索并重排序&quot;&quot;&quot;
        ## 1. 初始检索
        initial_docs = self.vectorstore.similarity_search(query, k=initial_k)
        
        ## 2. 准备查询-文档对
        query_doc_pairs = [
            [query, doc.page_content] for doc in initial_docs
        ]
        
        ## 3. 使用交叉编码器计算相关性分数
        print(f&quot;🎯 使用交叉编码器重新评分...&quot;)
        scores = self.cross_encoder.predict(query_doc_pairs)
        
        ## 4. 根据分数排序
        scored_docs = [
            {&apos;document&apos;: doc, &apos;score&apos;: score}
            for doc, score in zip(initial_docs, scores)
        ]
        
        ## 按分数降序排序
        scored_docs.sort(key=lambda x: x[&apos;score&apos;], reverse=True)
        
        ## 5. 返回top-k
        reranked_docs = scored_docs[:final_k]
        
        print(&quot;\n重排序结果:&quot;)
        for i, item in enumerate(reranked_docs):
            print(f&quot;{i+1}. [得分: {item[&apos;score&apos;]:.4f}] {item[&apos;document&apos;].page_content[:100]}...&quot;)
        
        return reranked_docs
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.4 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 准备测试文档
documents = [
    Document(
        page_content=&quot;&quot;&quot;
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print(f&quot;发生错误: {e}&quot;)
        
        可以捕获特定异常类型，也可以使用finally子句。
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_exceptions.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        Python错误和异常类型
        
        Python有多种内置异常类型：
        - ValueError: 值错误
        - TypeError: 类型错误
        - KeyError: 键错误
        - IndexError: 索引错误
        
        所有异常都继承自Exception类。
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_error_types.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        Python编程基础教程
        
        本教程涵盖：
        - 变量和数据类型
        - 控制流语句
        - 函数定义
        - 模块导入
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_basics.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        如何在Python中优雅地处理错误
        
        最佳实践：
        1. 只捕获你能处理的异常
        2. 使用具体的异常类型而不是Exception
        3. 提供有用的错误信息
        4. 适当时使用finally清理资源
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_error_best_practices.md&quot;}
    ),
]

## 使用示例
local_reranker = LocalCrossEncoderReranker(embeddings)
local_reranker.add_documents(documents)

query = &quot;Python中如何处理异常？&quot;
results = local_reranker.retrieve_and_rerank(query, initial_k=10, final_k=3)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📥 加载交叉编码器模型: ./Models/ms-marco-MiniLM-L-6-v2/cross-encoder/ms-marco-MiniLM-L6-v2
🗑️ 已删除旧的 Chroma 集合 &apos;local_rerank&apos;

🎯 使用交叉编码器重新评分...

重排序结果:
1. [得分: 8.3246] 
        如何在Python中优雅地处理错误
        
        最佳实践：
        1. 只捕获你能处理的异常
        2. 使用具体的异常类型而不是Excep...
2. [得分: 7.6097] 
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation(...
3. [得分: 7.5047] 
        Python错误和异常类型
        
        Python有多种内置异常类型：
        - ValueError: 值错误
        - TypeErr...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.5 常用交叉编码器模型&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;RERANKER_MODELS = {
    &quot;small_fast&quot;: {
        &quot;name&quot;: &quot;cross-encoder/ms-marco-TinyBERT-L-2-v2&quot;,
        &quot;params&quot;: &quot;~4M&quot;,
        &quot;speed&quot;: &quot;very fast&quot;,
        &quot;quality&quot;: &quot;good&quot;
    },
    &quot;balanced&quot;: {
        &quot;name&quot;: &quot;cross-encoder/ms-marco-MiniLM-L-6-v2&quot;,
        &quot;params&quot;: &quot;~22M&quot;,
        &quot;speed&quot;: &quot;fast&quot;,
        &quot;quality&quot;: &quot;very good&quot;
    },
    &quot;high_quality&quot;: {
        &quot;name&quot;: &quot;cross-encoder/ms-marco-MiniLM-L-12-v2&quot;,
        &quot;params&quot;: &quot;~33M&quot;,
        &quot;speed&quot;: &quot;medium&quot;,
        &quot;quality&quot;: &quot;excellent&quot;
    },
    &quot;multilingual&quot;: {
        &quot;name&quot;: &quot;cross-encoder/mmarco-mMiniLMv2-L12-H384-v1&quot;,
        &quot;params&quot;: &quot;~118M&quot;,
        &quot;speed&quot;: &quot;medium&quot;,
        &quot;quality&quot;: &quot;excellent (多语言)&quot;
    }
}

def choose_reranker(priority: str = &quot;balanced&quot;):
    &quot;&quot;&quot;选择合适的重排序模型&quot;&quot;&quot;
    model_info = RERANKER_MODELS.get(priority, RERANKER_MODELS[&quot;balanced&quot;])
    
    print(f&quot;📊 选择模型: {model_info[&apos;name&apos;]}&quot;)
    print(f&quot;   参数量: {model_info[&apos;params&apos;]}&quot;)
    print(f&quot;   速度: {model_info[&apos;speed&apos;]}&quot;)
    print(f&quot;   质量: {model_info[&apos;quality&apos;]}&quot;)
    
    return CrossEncoder(model_info[&apos;name&apos;])
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;1.6 技术优势与适用场景&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;优势：&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;精确度显著提升&lt;/strong&gt;：交叉编码器能捕捉查询和文档的细粒度交互&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;解决语义鸿沟&lt;/strong&gt;：弥补向量检索的局限性&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;灵活配置&lt;/strong&gt;：可根据需求调整初始检索和重排序的比例&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;质量可控&lt;/strong&gt;：通过分数阈值控制返回文档的质量&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;适用场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;高精度要求的问答系统&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;法律、医疗等专业领域检索&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;需要精确匹配的搜索应用&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;对检索质量要求极高的生产环境&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;性能考虑：&lt;/strong&gt; ⚠️&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;计算开销&lt;/strong&gt;：重排序会增加响应时间&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;资源需求&lt;/strong&gt;：需要加载额外的模型&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;批量优化&lt;/strong&gt;：建议批量处理以提高效率&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;交叉编码器重排序技术通过更精细的相关性评估，显著提升了RAG系统的检索精度，是构建高质量智能问答系统的关键技术之一。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 2: 倒数排序融合 - Reciprocal Rank Fusion (RRF)&lt;/h2&gt;
&lt;h3&gt;2.1 核心概念&lt;/h3&gt;
&lt;p&gt;倒数排序融合(RRF)是一种融合多个排序列表的算法，它不需要知道具体的分数，只需要排名。这种方法特别适合融合来自不同检索系统的结果。&lt;/p&gt;
&lt;h3&gt;2.2 RRF算法原理&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;RRF公式&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;对于文档d，其RRF分数为：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;RRF(d) = Σ (1 / (k + rank_i(d)))
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其中:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;rank_i(d)&lt;/code&gt;: 文档d在第i个排序列表中的排名&lt;/li&gt;
&lt;li&gt;&lt;code&gt;k&lt;/code&gt;: 常数，通常取60&lt;/li&gt;
&lt;li&gt;&lt;code&gt;Σ&lt;/code&gt;: 对所有排序列表求和&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;示例演示&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;假设有2个排序列表，k=60:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;列表1: [DocA, DocB, DocC]  (DocA排名1, DocB排名2, DocC排名3)
列表2: [DocC, DocA, DocD]  (DocC排名1, DocA排名2, DocD排名3)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;RRF分数计算:&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;DocA: 1/(60+1) + 1/(60+2) = 0.0164 + 0.0161 = 0.0325
DocB: 1/(60+2) + 0        = 0.0161
DocC: 1/(60+3) + 1/(60+1) = 0.0159 + 0.0164 = 0.0323
DocD: 0        + 1/(60+3) = 0.0159
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;最终排序: &lt;code&gt;DocA &gt; DocC &gt; DocB &gt; DocD&lt;/code&gt;&lt;/p&gt;
&lt;h3&gt;2.3 实现RRF融合器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from typing import List, Dict
from collections import defaultdict

class ReciprocalRankFusion:
    &quot;&quot;&quot;倒数排序融合&quot;&quot;&quot;
    
    def __init__(self, k: int = 60):
        &quot;&quot;&quot;
        Args:
            k: RRF常数，通常取60
        &quot;&quot;&quot;
        self.k = k
    
    def fuse(self, ranked_lists: List[List[Document]]) -&gt; List[Dict]:
        &quot;&quot;&quot;
        融合多个排序列表
        
        Args:
            ranked_lists: 多个文档排序列表
            
        Returns:
            融合后的文档列表（包含RRF分数）
        &quot;&quot;&quot;
        ## 存储每个文档的RRF分数
        doc_scores = defaultdict(float)
        doc_map = {}  ## 文档ID到文档对象的映射
        
        ## 遍历每个排序列表
        for ranked_list in ranked_lists:
            for rank, doc in enumerate(ranked_list, start=1):
                ## 使用page_content作为文档唯一标识
                doc_id = id(doc)
                
                ## 计算RRF分数
                rrf_score = 1.0 / (self.k + rank)
                doc_scores[doc_id] += rrf_score
                
                ## 保存文档对象
                if doc_id not in doc_map:
                    doc_map[doc_id] = doc
        
        ## 按RRF分数排序
        sorted_docs = sorted(
            doc_scores.items(),
            key=lambda x: x[1],
            reverse=True
        )
        
        ## 构建结果
        fused_results = [
            {
                &apos;document&apos;: doc_map[doc_id],
                &apos;rrf_score&apos;: score
            }
            for doc_id, score in sorted_docs
        ]
        
        return fused_results
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.4 使用示例&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 创建RRF融合器
rrf = ReciprocalRankFusion(k=60)

## 模拟多个检索器的结果
## 检索器1: 向量相似度检索
retriever1_results = [
    Document(page_content=&quot;文档A: Python异常处理...&quot;),
    Document(page_content=&quot;文档B: 错误类型...&quot;),
    Document(page_content=&quot;文档C: 编程基础...&quot;)
]

## 检索器2: BM25关键词检索
retriever2_results = [
    Document(page_content=&quot;文档C: 编程基础...&quot;),
    Document(page_content=&quot;文档A: Python异常处理...&quot;),
    Document(page_content=&quot;文档D: 最佳实践...&quot;)
]

## 融合结果
fused = rrf.fuse([retriever1_results, retriever2_results])

print(&quot;RRF融合结果:&quot;)
for i, item in enumerate(fused):
    print(f&quot;{i+1}. [RRF分数: {item[&apos;rrf_score&apos;]:.4f}] {item[&apos;document&apos;].page_content[:50]}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;RRF融合结果:
1. [RRF分数: 0.0164] 文档A: Python异常处理......
2. [RRF分数: 0.0164] 文档C: 编程基础......
3. [RRF分数: 0.0161] 文档B: 错误类型......
4. [RRF分数: 0.0161] 文档A: Python异常处理......
5. [RRF分数: 0.0159] 文档C: 编程基础......
6. [RRF分数: 0.0159] 文档D: 最佳实践......
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.5 实现完整的RRF检索器&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.retrievers import BM25Retriever, EnsembleRetriever
from typing import List, Dict

class RRFEnsembleRetriever:
    &quot;&quot;&quot;使用RRF的集成检索器&quot;&quot;&quot;
    
    def __init__(self, embeddings, documents: List[Document]):
        self.embeddings = embeddings
        self.persist_directory: str = &quot;./chroma_db&quot;
        
        ## 1. 向量检索器
        self.vectorstore = Chroma(
            collection_name=&quot;rrf_ensemble&quot;,
            embedding_function=embeddings,
            persist_directory=self.persist_directory 
        )
        self.vectorstore.add_documents(documents)
        self.vector_retriever = self.vectorstore.as_retriever(search_kwargs={&quot;k&quot;: 10})
        
        ## 2. BM25关键词检索器
        self.bm25_retriever = BM25Retriever.from_documents(documents)
        self.bm25_retriever.k = 10
        
        ## 3. RRF融合器
        self.rrf = ReciprocalRankFusion(k=60)
    
    def retrieve(self, query: str, k: int = 5) -&gt; List[Dict]:
        &quot;&quot;&quot;使用RRF融合多个检索器的结果&quot;&quot;&quot;
        print(f&quot;🔍 查询: {query}\n&quot;)
        
        ## 1. 向量检索
        print(&quot;📊 向量检索...&quot;)
        vector_results = self.vector_retriever.get_relevant_documents(query)
        print(f&quot;  → 检索到 {len(vector_results)} 个文档&quot;)
        
        ## 2. BM25检索
        print(&quot;🔤 BM25关键词检索...&quot;)
        bm25_results = self.bm25_retriever.get_relevant_documents(query)
        print(f&quot;  → 检索到 {len(bm25_results)} 个文档&quot;)
        
        ## 3. RRF融合
        print(&quot;\n🔀 RRF融合...&quot;)
        fused_results = self.rrf.fuse([vector_results, bm25_results])
        
        ## 返回top-k
        return fused_results[:k]
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.6 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 准备测试文档
documents = [
    Document(
        page_content=&quot;&quot;&quot;
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print(f&quot;发生错误: {e}&quot;)
        
        可以捕获特定异常类型，也可以使用finally子句。
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_exceptions.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        Python错误和异常类型
        
        Python有多种内置异常类型：
        - ValueError: 值错误
        - TypeError: 类型错误
        - KeyError: 键错误
        - IndexError: 索引错误
        
        所有异常都继承自Exception类。
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_error_types.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        如何在Python中优雅地处理错误
        
        最佳实践：
        1. 只捕获你能处理的异常
        2. 使用具体的异常类型而不是Exception
        3. 提供有用的错误信息
        4. 适当时使用finally清理资源
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_error_best_practices.md&quot;}
    ),
]

## 使用示例
rrf_retriever = RRFEnsembleRetriever(embeddings, documents)

results = rrf_retriever.retrieve(&quot;Python异常处理&quot;, k=3)

print(&quot;\n最终结果:&quot;)
for i, item in enumerate(results):
    print(f&quot;\n{i+1}. [RRF分数: {item[&apos;rrf_score&apos;]:.4f}]&quot;)
    print(f&quot;   {item[&apos;document&apos;].page_content[:150]}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;🔍 查询: Python异常处理

📊 向量检索...
  → 检索到 10 个文档
🔤 BM25关键词检索...
  → 检索到 4 个文档

🔀 RRF融合...

最终结果:

1. [RRF分数: 0.0164]
   
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print...

2. [RRF分数: 0.0164]
   
        如何在Python中优雅地处理错误
        
        最佳实践：
        1. 只捕获你能处理的异常
        2. 使用具体的异常类型而不是Exception
        3. 提供有用的错误信息
        4. 适当时使用finally清...

3. [RRF分数: 0.0161]
   
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;2.7 技术优势与适用场景&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;优势：&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;无需分数归一化&lt;/strong&gt;：直接使用排名，避免不同系统的分数尺度问题&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;鲁棒性强&lt;/strong&gt;：对单个检索器的异常结果不敏感&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;简单高效&lt;/strong&gt;：算法简单，计算开销小&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;可扩展性好&lt;/strong&gt;：容易集成新的检索系统&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;适用场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;多检索系统融合&lt;/strong&gt;：结合向量检索、关键词检索等不同方法&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;联邦搜索&lt;/strong&gt;：融合来自不同数据源的结果&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;专家系统&lt;/strong&gt;：结合不同专业领域的检索器&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;容错系统&lt;/strong&gt;：需要鲁棒性的生产环境&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;性能考虑：&lt;/strong&gt; ⚠️&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;排名质量依赖&lt;/strong&gt;：结果质量取决于各个检索器的排名质量&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;常数选择&lt;/strong&gt;：k值需要根据具体场景调整&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;重复文档处理&lt;/strong&gt;：需要处理不同检索器返回的相同文档&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;RRF技术通过简单而有效的方法融合多个检索系统的结果，显著提升了检索的鲁棒性和准确性，是构建复杂检索系统的关键技术之一。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 3: 多查询检索 - Multi-Query Retrieval&lt;/h2&gt;
&lt;h3&gt;3.1 核心概念&lt;/h3&gt;
&lt;p&gt;多查询检索是一种通过生成多个查询变体来提高检索召回率的先进技术。其核心思路是：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;从单个用户查询生成多个相似但不同的查询&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;对每个查询分别进行检索&lt;/strong&gt;&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;融合所有检索结果&lt;/strong&gt;&lt;/li&gt;
&lt;/ol&gt;
&lt;h3&gt;3.2 为什么需要多查询？&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;问题：单一查询可能不够全面&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;用户查询: &quot;Python如何读取文件？&quot;

## 可能的相关文档:
- &quot;Python文件读取open()函数&quot;  ✅ 匹配
- &quot;读写文件的最佳实践&quot;        ❌ 可能不匹配（没有&quot;Python&quot;）
- &quot;使用pathlib处理文件路径&quot;  ❌ 可能不匹配（没有&quot;读取&quot;）
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;解决方案：生成多个查询变体&lt;/strong&gt; ✅&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;原始查询: &quot;Python如何读取文件？&quot;

生成的查询变体:
1. &quot;在Python中打开和读取文件&quot;
2. &quot;Python file I/O操作&quot;
3. &quot;使用open()函数读取文件内容&quot;
4. &quot;Python文件处理方法&quot;

→ 不同变体可能匹配不同的相关文档
→ 融合结果，提高召回率
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.3 实现使用LLM生成查询变体&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.vectorstores import Chroma
from langchain.schema import Document
from typing import List, Dict
from vllm import SamplingParams

class MultiQueryRetriever:
    &quot;&quot;&quot;多查询检索器（适配本地模型）&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.llm = llm
    
    def generate_queries(self, question: str) -&gt; List[str]:
        &quot;&quot;&quot;生成查询变体&quot;&quot;&quot;
        ## 使用ChatML格式
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个搜索查询生成助手。请为用户的查询生成3个不同的变体，这些变体表达相同的意图但用词不同。
            
            请每行输出一个查询，不要编号。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            用户查询：{question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=200,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        response = outputs[0].outputs[0].text.strip() if outputs and outputs[0].outputs else &quot;&quot;
        
        ## 解析生成的查询
        queries = [q.strip() for q in response.split(&apos;\n&apos;) if q.strip()]
        
        ## 加入原始查询
        all_queries = [question] + queries
        
        print(f&quot;📝 生成了 {len(all_queries)} 个查询变体&quot;)
        return all_queries
    
    def simple_rrf_fusion(self, all_results: List[List[Document]], k: int = 5) -&gt; List[Document]:
        &quot;&quot;&quot;简单的RRF融合&quot;&quot;&quot;
        ## 简单的文档去重和排序
        seen_content = set()
        unique_docs = []
        
        for results in all_results:
            for doc in results:
                if doc.page_content not in seen_content:
                    seen_content.add(doc.page_content)
                    unique_docs.append(doc)
        
        return unique_docs[:k]
    
    def retrieve(self, question: str, k: int = 5) -&gt; List[Document]:
        &quot;&quot;&quot;多查询检索&quot;&quot;&quot;
        ## 1. 生成查询变体
        queries = self.generate_queries(question)
        
        ## 2. 对每个查询进行检索
        all_results = []
        for query in queries:
            results = self.vectorstore.similarity_search(query, k=k)
            all_results.append(results)

        ## 3. 简单的结果融合
        fused_results = self.simple_rrf_fusion(all_results, k)
        
        print(f&quot;✅ 最终检索到 {len(fused_results)} 个文档&quot;)
        return fused_results
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.4 实际应用演示&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 创建向量数据库
vectorstore = Chroma(
    collection_name=&quot;multi_query&quot;,
    persist_directory=&quot;./chroma_db&quot;,
    embedding_function=embeddings
)

## 准备测试文档
documents = [
    Document(
        page_content=&quot;&quot;&quot;
        Python文件操作完整指南
        
        使用open()函数读取文件：
        with open(&apos;file.txt&apos;, &apos;r&apos;) as f:
            content = f.read()
        
        写入文件：
        with open(&apos;file.txt&apos;, &apos;w&apos;) as f:
            f.write(&apos;Hello, World!&apos;)
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_file_operations.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        Python I/O操作最佳实践
        
        文件处理建议：
        1. 总是使用with语句确保文件正确关闭
        2. 处理文件编码问题
        3. 使用pathlib进行路径操作
        4. 处理大文件时使用分块读取
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_io_best_practices.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        Python pathlib模块使用
        
        pathlib提供面向对象的文件路径操作：
        from pathlib import Path
        
        ## 读取文件
        content = Path(&apos;file.txt&apos;).read_text()
        
        ## 写入文件
        Path(&apos;file.txt&apos;).write_text(&apos;Hello, World!&apos;)
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_pathlib.md&quot;}
    ),
    Document(
        page_content=&quot;&quot;&quot;
        Python异常处理指南
        
        文件操作中的异常处理：
        try:
            with open(&apos;file.txt&apos;, &apos;r&apos;) as f:
                content = f.read()
        except FileNotFoundError:
            print(&quot;文件不存在&quot;)
        except PermissionError:
            print(&quot;没有权限&quot;)
        &quot;&quot;&quot;,
        metadata={&quot;source&quot;: &quot;python_exception_handling.md&quot;}
    )
]

vectorstore.add_documents(documents)

## 创建多查询检索器
multi_query_retriever = MultiQueryRetriever(vectorstore, llm)

## 检索测试
question = &quot;Python如何读取文件？&quot;
results = multi_query_retriever.retrieve(question, k=3)

print(&quot;\n最终检索结果:&quot;)
for i, doc in enumerate(results):
    print(f&quot;\n{i+1}. {doc.page_content[:100]}...&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;运行结果：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.03it/s]
📝 生成了 4 个查询变体
✅ 最终检索到 3 个文档

最终检索结果:

1. 
        Python文件操作完整指南
        
        使用open()函数读取文件：
        with open(&apos;file.txt&apos;, &apos;r&apos;) as f:
            content = f.read(...

2. 
        Python I/O操作最佳实践
        
        文件处理建议：
        1. 总是使用with语句确保文件正确关闭
        2. 处理文件编码问题...

3. 
        Python pathlib模块使用
        
        pathlib提供面向对象的文件路径操作：
        from pathlib import Path
        
        ## 读取文件...
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;3.5 技术优势与适用场景&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;优势：&lt;/strong&gt; ✅&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;提高召回率&lt;/strong&gt;：多个查询变体覆盖更多相关文档&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;增强鲁棒性&lt;/strong&gt;：对查询表述的变化不敏感&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;语义多样性&lt;/strong&gt;：涵盖不同角度和表达方式&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;易于集成&lt;/strong&gt;：可与现有检索系统无缝集成&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;适用场景：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;复杂查询&lt;/strong&gt;：需要多角度理解的复杂问题&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;专业领域&lt;/strong&gt;：技术、学术等专业内容检索&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;多语言检索&lt;/strong&gt;：支持不同语言表达方式&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;容错系统&lt;/strong&gt;：需要高召回率的应用场景&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;性能考虑：&lt;/strong&gt; ⚠️&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;计算开销&lt;/strong&gt;：需要执行多次检索操作&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;LLM依赖&lt;/strong&gt;：查询生成依赖语言模型质量&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;结果去重&lt;/strong&gt;：需要有效的融合和去重策略&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;多查询检索技术通过生成多样化的查询变体，显著提高了检索系统的召回率和鲁棒性，是构建高质量智能检索系统的重要技术之一。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 4: 查询扩展 - Query Expansion&lt;/h2&gt;
&lt;h3&gt;4.1 核心概念&lt;/h3&gt;
&lt;p&gt;查询扩展是信息检索中的一种关键技术，旨在通过向原始查询中添加相关术语、同义词或上下文信息来增强查询的语义表示，从而提高检索系统的召回率。其核心思想是弥补用户查询与文档集合中相关文档之间可能存在的词汇不匹配问题。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;主要方法：&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;基于同义词词典&lt;/strong&gt;：使用预定义的词典（如 WordNet）添加同义词。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;基于词嵌入&lt;/strong&gt;：利用词向量模型（如 Word2Vec, GloVe）找到语义相近的词汇。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;基于查询日志&lt;/strong&gt;：分析历史查询数据，找到经常一起出现的查询词。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;伪相关反馈&lt;/strong&gt;：假设初次检索返回的顶部文档是相关的，并从中提取扩展词。&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;基于LLM的扩展&lt;/strong&gt;：利用大语言模型强大的语义理解和生成能力，根据查询的意图和上下文生成相关的扩展术语。&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;本部分将重点介绍基于LLM的查询扩展和伪相关反馈两种高级策略。&lt;/p&gt;
&lt;h3&gt;4.2 实现基于LLM的查询扩展&lt;/h3&gt;
&lt;p&gt;LLM能够深入理解查询的意图和上下文，生成高质量、语义相关的扩展词，而不仅仅是机械地添加同义词。&lt;/p&gt;
&lt;p&gt;以下是优化后的 &lt;code&gt;QueryExpander&lt;/code&gt; 类，它更加健壮，并提供了更好的提示词和输出处理。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from vllm import SamplingParams

class QueryExpander:
    &quot;&quot;&quot;查询扩展器（适配本地模型）&quot;&quot;&quot;
    
    def __init__(self, llm):
        self.llm = llm
    
    def expand(self, query: str) -&gt; str:
        &quot;&quot;&quot;扩展查询&quot;&quot;&quot;
        ## 使用ChatML格式
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个查询扩展专家。请为用户查询添加同义词和相关术语。
            
            请使用以下格式：
            原始术语 OR 同义词1 OR 同义词2&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            查询：{query}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            扩展查询：&quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            temperature=0.3,
            top_p=0.9,
            max_tokens=512,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        expanded = outputs[0].outputs[0].text.strip() if outputs and outputs[0].outputs else query
        
        print(f&quot;原始查询: {query}&quot;)
        print(f&quot;扩展查询: {expanded}&quot;)
        
        return expanded
## 创建查询扩展器
expander = QueryExpander(llm)
## 测试扩展
original_query = &quot;Python机器学习&quot;
expanded_query = expander.expand(original_query)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  5.79it/s]
原始查询: Python机器学习
扩展查询: Python数据分析 OR Python统计分析 OR Python数据挖掘
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;4.3 实现伪相关反馈&lt;/h3&gt;
&lt;p&gt;伪相关反馈是一种更高级的查询扩展技术，它利用初次检索的结果来指导查询扩展，特别适合文档库特定的术语和上下文。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from langchain.vectorstores import Chroma
from langchain.schema import Document
from vllm import SamplingParams

class PseudoRelevanceFeedback:
    &quot;&quot;&quot;伪相关反馈查询扩展&quot;&quot;&quot;
    
    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.llm = llm
    
    def expand_query(self, query: str, top_k: int = 3) -&gt; str:
        &quot;&quot;&quot;使用伪相关反馈扩展查询&quot;&quot;&quot;
        ## 1. 初始检索
        initial_docs = self.vectorstore.similarity_search(query, k=top_k)
        
        ## 2. 从top文档提取关键词
        context = &quot;\n\n&quot;.join([doc.page_content for doc in initial_docs])
        
        ## 3. 使用LLM生成扩展查询
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个查询扩展助手。基于相关文档内容，为原始查询添加重要的相关术语。
            
            请保持简洁，只添加最重要的术语。&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            原始查询：{query}
            
            相关文档：
            {context}
            
            请生成扩展查询：&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            扩展查询：&quot;&quot;&quot;
        
        sampling_params = SamplingParams(
            temperature=0.2,
            top_p=0.9,
            max_tokens=512,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        expanded = outputs[0].outputs[0].text.strip() if outputs and outputs[0].outputs else query
        
        print(f&quot;📝 原始查询: {query}&quot;)
        print(f&quot;✨ 扩展查询: {expanded}&quot;)
        
        return expanded
    
    def retrieve_with_expansion(self, query: str, k: int = 5):
        &quot;&quot;&quot;使用扩展查询检索&quot;&quot;&quot;
        ## 1. 扩展查询
        expanded_query = self.expand_query(query)
        
        ## 2. 使用扩展查询检索
        results = self.vectorstore.similarity_search(expanded_query, k=k)
        
        print(f&quot;✅ 检索到 {len(results)} 个文档&quot;)
        return results
## 4. 创建伪相关反馈检索器
prf = PseudoRelevanceFeedback(vectorstore, llm)
## 5. 检索测试
results = prf.retrieve_with_expansion(&quot;Python异常&quot;, k=3)

print(&quot;\n最终检索结果:&quot;)
for i, doc in enumerate(results):
    print(f&quot;{i+1}. {doc.page_content[:100]}...&quot;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.99s/it]
📝 原始查询: Python异常
✨ 扩展查询: Python异常处理

  1. Python异常处理完整指南：https://realpython.com/python-exceptions/
  2. 使用try-except捕获异常：https://realpython.com/python-try-except/
  3. 可以捕获特定异常类型，也可以使用finally子句：https://realpython.com/python-finally/
  4. Python错误和异常类型：https://docs.python.org/3/library/exceptions.html
  5. 如何在Python中优雅地处理错误：https://realpython.com/python-error-handling/

注意：在处理异常时，应只捕获你能处理的异常，使用具体的异常类型而不是Exception，提供有用的错误信息，适当时使用finally清理资源。
✅ 检索到 3 个文档

最终检索结果:
1. 
        如何在Python中优雅地处理错误
        
        最佳实践：
        1. 只捕获你能处理的异常
        2. 使用具体的异常类型而不是Excep...
2. 
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation(...
3. 
        Python错误和异常类型
        
        Python有多种内置异常类型：
        - ValueError: 值错误
        - TypeErr...
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这种查询扩展方法能够显著提高检索系统的召回率，特别是在处理专业领域查询时，能够更好地理解用户的真实信息需求。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;Part 5: 混合检索 - Hybrid Search&lt;/h2&gt;
&lt;h3&gt;5.1 核心概念&lt;/h3&gt;
&lt;p&gt;混合检索结合向量检索(语义搜索)和关键词检索(如BM25)，利用两者的优势。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;向量检索 vs 关键词检索&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;## 向量检索（语义搜索）&lt;/strong&gt;
&lt;strong&gt;优点:&lt;/strong&gt;
✅ 理解语义，能匹配同义词
✅ 能处理模糊查询&lt;br&gt;
✅ 跨语言能力（多语言模型）&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;缺点:&lt;/strong&gt;
❌ 对精确术语匹配不敏感
❌ 对罕见词或专有名词效果差
❌ 计算成本高&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;## 关键词检索（BM25）&lt;/strong&gt;
&lt;strong&gt;优点:&lt;/strong&gt;
✅ 精确匹配关键词
✅ 对专有名词、代码等效果好
✅ 计算快速&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;缺点:&lt;/strong&gt;
❌ 不理解语义
❌ 无法匹配同义词
❌ 对查询措辞敏感&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;## 混合检索 = 向量检索 + 关键词检索 🎯&lt;/strong&gt;
→ 结合两者优势
→ 适用于大多数场景&lt;/p&gt;
&lt;h3&gt;5.2 实现向量+BM25混合检索&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;from rank_bm25 import BM25Okapi
import jieba  ## 中文分词
import numpy as np
from collections import defaultdict

class HybridRetriever:
    &quot;&quot;&quot;混合检索器（向量 + BM25）&quot;&quot;&quot;
    
    def __init__(self, embeddings, documents: List[Document], weights: tuple = (0.5, 0.5)):
        &quot;&quot;&quot;
        Args:
            weights: (向量权重, BM25权重)，两者之和应为1.0
        &quot;&quot;&quot;
        self.embeddings = embeddings
        self.documents = documents
        self.vector_weight, self.bm25_weight = weights
        
        ## 1. 向量存储
        self.vectorstore = Chroma(
            collection_name=&quot;hybrid_search&quot;,
            embedding_function=embeddings,
            persist_directory=&quot;./chroma_db&quot;,
        )
        
        ## 2. BM25索引
        self._build_bm25_index()
    
    def _build_bm25_index(self):
        &quot;&quot;&quot;构建BM25索引&quot;&quot;&quot;
        ## 分词（中文使用jieba，英文可以使用split）
        tokenized_docs = [
            list(jieba.cut(doc.page_content)) for doc in self.documents
        ]
        
        self.bm25 = BM25Okapi(tokenized_docs)
        print(f&quot;✅ BM25索引构建完成，共 {len(self.documents)} 个文档&quot;)
    
    def _vector_search(self, query: str, k: int) -&gt; List[tuple]:
        &quot;&quot;&quot;向量检索，返回 (doc, score)&quot;&quot;&quot;
        try:
            results = self.vectorstore.similarity_search_with_score(query, k=k)
            
            ## 归一化分数到[0, 1]
            if results:
                scores = [score for _, score in results]
                max_score = max(scores)
                min_score = min(scores)
                score_range = max_score - min_score if max_score != min_score else 1
                
                normalized = [
                    (doc, 1 - (score - min_score) / score_range)  ## 距离转相似度
                    for doc, score in results
                ]
                return normalized
            else:
                return []
        except Exception as e:
            print(f&quot;向量检索错误: {e}&quot;)
            return []
    
    def _bm25_search(self, query: str, k: int) -&gt; List[tuple]:
        &quot;&quot;&quot;BM25检索，返回 (doc, score)&quot;&quot;&quot;
        try:
            ## 查询分词
            tokenized_query = list(jieba.cut(query))
            
            ## BM25评分
            scores = self.bm25.get_scores(tokenized_query)
            
            ## 获取top-k
            if len(scores) &gt; 0:
                top_indices = np.argsort(scores)[::-1][:k]
                
                ## 归一化分数
                max_score = max(scores) if max(scores) &gt; 0 else 1
                
                results = [
                    (self.documents[i], scores[i] / max_score)
                    for i in top_indices if scores[i] &gt; 0
                ]
                return results
            else:
                return []
        except Exception as e:
            print(f&quot;BM25检索错误: {e}&quot;)
            return []
    
    def hybrid_search(self, query: str, k: int = 5) -&gt; List[Dict]:
        &quot;&quot;&quot;混合检索&quot;&quot;&quot;
        print(f&quot;🔍 混合检索: {query}&quot;)
        print(f&quot;   权重: 向量={self.vector_weight}, BM25={self.bm25_weight}&quot;)
        
        ## 1. 向量检索
        vector_results = self._vector_search(query, k=k*2)
        print(f&quot;📊 向量检索结果: {len(vector_results)} 个文档&quot;)
        
        ## 2. BM25检索
        bm25_results = self._bm25_search(query, k=k*2)
        print(f&quot;🔤 BM25检索结果: {len(bm25_results)} 个文档&quot;)
        
        ## 3. 合并分数
        combined_scores = defaultdict(float)
        doc_map = {}
        
        ## 处理向量检索结果
        for doc, score in vector_results:
            doc_id = id(doc)
            combined_scores[doc_id] += self.vector_weight * score
            doc_map[doc_id] = doc
        
        ## 处理BM25检索结果
        for doc, score in bm25_results:
            doc_id = id(doc)
            combined_scores[doc_id] += self.bm25_weight * score
            doc_map[doc_id] = doc
        
        ## 4. 排序并取top-k
        sorted_results = sorted(
            combined_scores.items(),
            key=lambda x: x[1],
            reverse=True
        )[:k]
        
        print(f&quot;🎯 最终合并结果: {len(sorted_results)} 个文档&quot;)
        
        ## 5. 构建结果
        final_results = []
        for doc_id, score in sorted_results:
            final_results.append({
                &apos;document&apos;: doc_map[doc_id],
                &apos;hybrid_score&apos;: score,
                &apos;content&apos;: doc_map[doc_id].page_content[:200] + &quot;...&quot; 
                if len(doc_map[doc_id].page_content) &gt; 200 else doc_map[doc_id].page_content
            })
        
        return final_results
    
    def search_comparison(self, query: str, k: int = 3):
        &quot;&quot;&quot;对比三种检索方式的效果&quot;&quot;&quot;
        print(f&quot;\n🔬 检索方式对比: {query}&quot;)
        print(&quot;=&quot; * 60)
        
        ## 向量检索
        vector_results = self._vector_search(query, k=k)
        print(&quot;\n📊 向量检索结果:&quot;)
        for i, (doc, score) in enumerate(vector_results):
            print(f&quot;  {i+1}. [分数: {score:.4f}] {doc.page_content[:80]}...&quot;)
        
        ## BM25检索
        bm25_results = self._bm25_search(query, k=k)
        print(&quot;\n🔤 BM25检索结果:&quot;)
        for i, (doc, score) in enumerate(bm25_results):
            print(f&quot;  {i+1}. [分数: {score:.4f}] {doc.page_content[:80]}...&quot;)
        
        ## 混合检索
        hybrid_results = self.hybrid_search(query, k=k)
        print(&quot;\n🎯 混合检索结果:&quot;)
        for i, result in enumerate(hybrid_results):
            print(f&quot;  {i+1}. [混合分数: {result[&apos;hybrid_score&apos;]:.4f}] {result[&apos;document&apos;].page_content[:80]}...&quot;)
        
        return {
            &apos;vector&apos;: vector_results,
            &apos;bm25&apos;: bm25_results,
            &apos;hybrid&apos;: hybrid_results
        }

## 创建混合检索器
hybrid_retriever = HybridRetriever(
    embeddings=embeddings,
    documents=documents,
    weights=(0.6, 0.4)  ## 60%向量，40% BM25
)

## 测试混合检索
print(&quot;=== 混合检索测试 ===&quot;)
results = hybrid_retriever.hybrid_search(&quot;Python异常处理&quot;, k=3)

print(&quot;\n混合检索结果:&quot;)
for i, item in enumerate(results):
    print(f&quot;\n{i+1}. [混合分数: {item[&apos;hybrid_score&apos;]:.4f}]&quot;)
    print(f&quot;   {item[&apos;content&apos;]}&quot;)

&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
🔍 混合检索: Python异常处理
   权重: 向量=0.6, BM25=0.4

📊 向量检索...
🔤 BM25检索...

🔀 合并结果...
混合检索结果:

1. [混合分数: 0.6000]
   
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print...

2. [混合分数: 0.6000]
   
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print...

3. [混合分数: 0.6000]
   
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation()
        except Exception as e:
            print...
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;混合检索通过结合语义理解和关键词匹配，能够在各种查询场景下提供更稳定、更准确的检索结果。自适应权重调整和LLM重排序进一步提升了检索系统的智能性和实用性。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;综合实战：完整的高级检索系统&lt;/h2&gt;
&lt;p&gt;将所有技术整合成一个系统&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;class AdvancedRetrievalSystem:
    &quot;&quot;&quot;高级检索系统（集成所有技术）&quot;&quot;&quot;
    
    def __init__(self, embeddings, llm, documents: List[Document]):
        self.embeddings = embeddings
        self.llm = llm
        self.documents = documents
        
        ## 向量存储
        self.vectorstore = Chroma(
            collection_name=&quot;multi_query&quot;,
            embedding_function=embeddings,
            persist_directory=&quot;./chroma_db&quot;,
        )
        ## self.vectorstore.add_documents(documents)
        
        ## 组件
        self.multi_query = MultiQueryRetriever(self.vectorstore, llm)
        self.reranker = LocalCrossEncoderReranker(embeddings)
        self.reranker.add_documents(documents)
        self.hybrid = HybridRetriever(embeddings, documents, weights=(0.6, 0.4))
        self.rrf = ReciprocalRankFusion(k=60)
    
    def retrieve(
        self,
        query: str,
        mode: str = &quot;hybrid_multiquery_rerank&quot;,
        k: int = 5
    ):
        &quot;&quot;&quot;
        高级检索
        
        Args:
            mode: 检索模式
                - &quot;simple&quot;: 简单向量检索
                - &quot;hybrid&quot;: 混合检索
                - &quot;multiquery&quot;: 多查询检索
                - &quot;hybrid_multiquery&quot;: 混合+多查询
                - &quot;hybrid_multiquery_rerank&quot;: 混合+多查询+重排序（最强）
        &quot;&quot;&quot;
        print(f&quot;🎯 检索模式: {mode}&quot;)
        print(f&quot;❓ 查询: {query}\n&quot;)
        
        if mode == &quot;simple&quot;:
            ## 简单向量检索
            results = self.vectorstore.similarity_search(query, k=k)
            results = [{&apos;document&apos;: doc, &apos;score&apos;: 0} for doc in results]
        
        elif mode == &quot;hybrid&quot;:
            ## 混合检索
            results = self.hybrid.hybrid_search(query, k=k)
        
        elif mode == &quot;multiquery&quot;:
            ## 多查询检索
            results = self.multi_query.retrieve(query, k=k)
        
        elif mode == &quot;hybrid_multiquery&quot;:
            ## 混合 + 多查询
            ## 1. 生成查询变体
            queries = self.multi_query.generate_queries(query)
            
            ## 2. 对每个查询进行混合检索
            all_results = []
            for q in queries:
                hybrid_results = self.hybrid.hybrid_search(q, k=10)
                all_results.append([item[&apos;document&apos;] for item in hybrid_results])
            
            ## 3. RRF融合
            results = self.rrf.fuse(all_results)[:k]
        
        elif mode == &quot;hybrid_multiquery_rerank&quot;:
            ## 混合 + 多查询 + 重排序（最强模式）
            ## 1. 生成查询变体
            queries = self.multi_query.generate_queries(query)
            
            ## 2. 混合检索
            all_results = []
            for q in queries:
                hybrid_results = self.hybrid.hybrid_search(q, k=10)
                all_results.append([item[&apos;document&apos;] for item in hybrid_results])
            
            ## 3. RRF融合
            fused = self.rrf.fuse(all_results)
            candidate_docs = [item[&apos;document&apos;] for item in fused[:20]]
            
            ## 4. 交叉编码器重排序
            print(&quot;\n🎯 重排序...&quot;)
            results = self.reranker.retrieve_and_rerank(
                query=query,
                initial_k=len(candidate_docs),
                final_k=k
            )
        
        return results
    

    def query(self, question: str, mode: str = &quot;hybrid_multiquery_rerank&quot;):
        &quot;&quot;&quot;执行完整的RAG查询（修复重复代码问题）&quot;&quot;&quot;
        ## 1. 检索
        results = self.retrieve(question, mode=mode, k=3)
    
        ## 2. 提取文档：兼容两种格式
        if results and isinstance(results[0], dict) and &apos;document&apos; in results[0]:
            ## 格式: [{&quot;document&quot;: Document, &quot;score&quot;: ...}, ...]
            docs = [item[&apos;document&apos;] for item in results]
        else:
            ## 格式: [Document, Document, ...] 或其他格式
            docs = results  ## 直接使用结果
        
        ## 3. 生成答案（使用ChatML格式）
        context = &quot;\n\n&quot;.join([doc.page_content for doc in docs])
        
        ## 构建ChatML格式的提示
        prompt = f&quot;&quot;&quot;&amp;#x3C;|im_start|&gt;system
            你是一个智能问答助手。请基于提供的文档内容回答问题。如果文档中没有相关信息，请说明。
            
            文档内容：
            {context}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;user
            问题：{question}&amp;#x3C;|im_end|&gt;
            &amp;#x3C;|im_start|&gt;assistant
            &quot;&quot;&quot;
        
        ## 使用本地模型生成答案
        from vllm import SamplingParams
        sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            max_tokens=512,
            stop=[&quot;&amp;#x3C;|im_end|&gt;&quot;]
        )
        
        outputs = self.llm.generate([prompt], sampling_params)
        answer = outputs[0].outputs[0].text.strip() if outputs and outputs[0].outputs else &quot;未能生成答案&quot;
        
        return {
            &quot;question&quot;: question,
            &quot;documents&quot;: docs,
            &quot;answer&quot;: answer
        }
## 使用高级检索系统
advanced_system = AdvancedRetrievalSystem(embeddings, llm, documents)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出内容：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;📥 加载交叉编码器模型: ./Models/ms-marco-MiniLM-L-6-v2/cross-encoder/ms-marco-MiniLM-L6-v2
🗑️ 已删除旧的 Chroma 集合 &apos;local_rerank&apos;
✅ BM25索引构建完成
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 测试不同模式
modes = [&quot;simple&quot;, &quot;hybrid&quot;, &quot;hybrid_multiquery_rerank&quot;, &quot;multiquery&quot;]

question = &quot;Python中如何处理异常？&quot;

for mode in modes:
    print(f&quot;\n{&apos;=&apos;*60}&quot;)
    print(f&quot;测试模式: {mode}&quot;)
    print(f&quot;{&apos;=&apos;*60}\n&quot;)
    
    result = advanced_system.query(question, mode=mode)
    
    print(f&quot;\n💡 答案:\n{result[&apos;answer&apos;]}\n&quot;)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;输出内容如下：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;============================================================
测试模式: simple
============================================================

🎯 检索模式: simple
❓ 查询: Python中如何处理异常？

Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.40s/it]

💡 答案:
在Python中，可以使用try-except语句来捕获和处理异常。try语句块中包含可能会抛出异常的代码，如果try语句块中的代码抛出了异常，那么程序会立即跳转到与该异常匹配的except语句块中。except语句块中的代码会在异常发生时被执行，可以用来处理异常，例如打印错误信息，或者执行其他恢复操作。此外，还可以使用finally语句块来确保即使在发生异常时，某些代码也会被执行。


============================================================
测试模式: hybrid
============================================================

🎯 检索模式: hybrid
❓ 查询: Python中如何处理异常？

🔍 混合检索: Python中如何处理异常？
   权重: 向量=0.6, BM25=0.4

📊 向量检索...
🔤 BM25检索...

🔀 合并结果...
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.52s/it]

💡 答案:
在Python中，可以使用try-except语句来捕获和处理异常。try块中包含可能会引发异常的代码，如果try块中的代码引发异常，程序会立即跳转到与该异常匹配的except块中。except块中的代码会在发生异常时被执行，可以用来处理异常，例如打印错误信息、记录日志、恢复程序状态等。如果except块中没有匹配的异常，程序会继续执行。此外，还可以使用finally块来确保在try-except语句块中的代码无论如何都会被执行，无论是否发生异常。


============================================================
测试模式: hybrid_multiquery_rerank
============================================================

🎯 检索模式: hybrid_multiquery_rerank
❓ 查询: Python中如何处理异常？

Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.24it/s]
📝 生成了 4 个查询变体
🔍 混合检索: Python中如何处理异常？
   权重: 向量=0.6, BM25=0.4

📊 向量检索...
🔤 BM25检索...

🔀 合并结果...
🔍 混合检索: 1. 如何在Python中处理异常？
   权重: 向量=0.6, BM25=0.4

📊 向量检索...
🔤 BM25检索...

🔀 合并结果...
🔍 混合检索: 2. 如何应对Python中的异常情况？
   权重: 向量=0.6, BM25=0.4

📊 向量检索...
🔤 BM25检索...

🔀 合并结果...
🔍 混合检索: 3. 如何处理Python程序中的错误？
   权重: 向量=0.6, BM25=0.4

📊 向量检索...
🔤 BM25检索...

🔀 合并结果...

🎯 重排序...
🎯 使用交叉编码器重新评分...

重排序结果:
1. [得分: 8.3246] 
        如何在Python中优雅地处理错误
        
        最佳实践：
        1. 只捕获你能处理的异常
        2. 使用具体的异常类型而不是Excep...
2. [得分: 7.6097] 
        Python异常处理完整指南
        
        使用try-except捕获异常：
        try:
            risky_operation(...
3. [得分: 7.5047] 
        Python错误和异常类型
        
        Python有多种内置异常类型：
        - ValueError: 值错误
        - TypeErr...
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.14s/it]

💡 答案:
在Python中，可以使用try-except语句来优雅地处理异常。try块中包含可能会抛出异常的代码，如果try块中的代码抛出异常，程序会立即跳转到与之匹配的except块中。except块中的代码会在异常发生时被执行，可以用来处理异常或捕获特定类型的异常。此外，还可以使用finally子句来清理资源，无论是否发生异常。


============================================================
测试模式: multiquery
============================================================

🎯 检索模式: multiquery
❓ 查询: Python中如何处理异常？

Processed prompts: 100%|██████████| 1/1 [00:00&amp;#x3C;00:00,  2.24it/s]
📝 生成了 4 个查询变体
✅ 最终检索到 2 个文档
Processed prompts: 100%|██████████| 1/1 [00:01&amp;#x3C;00:00,  1.22s/it]

💡 答案:
在Python中，可以使用try-except语句来捕获和处理异常。try块中包含可能会引发异常的代码，如果try块中的代码引发了异常，程序会立即跳转到与之匹配的except块中。except块中的代码会在异常发生时被执行，可以用来处理异常或提供有用的错误信息。此外，还可以使用finally子句来确保在try-except语句块中的代码无论如何都会被执行，即使没有引发异常。
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;实验总结&lt;/h2&gt;
&lt;p&gt;重排序: 使用交叉编码器提升精度
RRF融合: 简单有效的结果融合方法
多查询: 提高召回率
混合检索: 结合向量和关键词检索&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;## 技术选择指南

RETRIEVAL_STRATEGIES = {
    &quot;快速原型&quot;: {
        &quot;策略&quot;: &quot;simple&quot;,
        &quot;说明&quot;: &quot;简单向量检索&quot;,
        &quot;适用&quot;: &quot;快速验证想法，数据量小&quot;
    },
    
    &quot;生产环境基础&quot;: {
        &quot;策略&quot;: &quot;hybrid&quot;,
        &quot;说明&quot;: &quot;混合检索（向量+BM25）&quot;,
        &quot;适用&quot;: &quot;大多数生产场景，平衡速度和质量&quot;
    },
    
    &quot;高召回率&quot;: {
        &quot;策略&quot;: &quot;multiquery&quot;,
        &quot;说明&quot;: &quot;多查询检索&quot;,
        &quot;适用&quot;: &quot;需要全面覆盖，不要遗漏相关文档&quot;
    },
    
    &quot;高精度&quot;: {
        &quot;策略&quot;: &quot;hybrid_multiquery_rerank&quot;,
        &quot;说明&quot;: &quot;混合+多查询+重排序&quot;,
        &quot;适用&quot;: &quot;对质量要求极高，可以牺牲速度&quot;
    },
    
    &quot;实时应用&quot;: {
        &quot;策略&quot;: &quot;hybrid + 缓存&quot;,
        &quot;说明&quot;: &quot;混合检索+结果缓存&quot;,
        &quot;适用&quot;: &quot;需要快速响应的应用&quot;
    }
}

def choose_strategy(priority: str):
    &quot;&quot;&quot;根据优先级选择策略&quot;&quot;&quot;
    strategy = RETRIEVAL_STRATEGIES.get(priority)
    
    if strategy:
        print(f&quot;推荐策略: {strategy[&apos;策略&apos;]}&quot;)
        print(f&quot;说明: {strategy[&apos;说明&apos;]}&quot;)
        print(f&quot;适用场景: {strategy[&apos;适用&apos;]}&quot;)
    
    return strategy
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;技术组合建议&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;检索质量层级:
Level 1: 向量检索 (基础)
  ↓
Level 2: 混合检索 (向量 + BM25)
  ↓
Level 3: 混合检索 + RRF融合
  ↓
Level 4: 多查询 + 混合检索 + RRF
  ↓
Level 5: 多查询 + 混合检索 + RRF + 重排序 (最强)

根据应用场景选择合适的级别:
- 快速原型: Level 1
- 生产基础: Level 2-3
- 高质量应用: Level 4-5
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/rag_blogs/abstract.png?origWidth=1280&amp;origHeight=720&amp;origFormat=png"/></item><item><title>常用命令和工具</title><link>https://astro-pure.js.org/blog/tool_blogs/widely_tools</link><guid isPermaLink="true">https://astro-pure.js.org/blog/tool_blogs/widely_tools</guid><description>记录常用的命令和工具。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;h2&gt;远程服务器设置本地proxy转发&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;首先开启proxy，端口号默认7890，然后在本地命令行中输入ipconfig获得ip地址&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F1.png%3ForigWidth%3D1008%26origHeight%3D359%26origFormat%3Dpng&amp;#x26;w=1008&amp;#x26;h=359&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;打开远程服务器，输入下列代码设置proxy转发端口&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;export http_proxy=http://172.21.***.109:7890
export https_proxy=http://172.21.***.109:7890
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;设置完毕后测试一下，访问网站&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;curl -v http://www.google.com
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;访问成功&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F2.png%3ForigWidth%3D805%26origHeight%3D283%26origFormat%3Dpng&amp;#x26;w=805&amp;#x26;h=283&amp;#x26;f=webp&quot; alt=&quot;&quot;&gt;&lt;/p&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;取消&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;unset http_proxy
unset https_proxy
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;本地免密登录远程服务器&lt;/h2&gt;
&lt;h3&gt;一、生成SSH密钥对&lt;/h3&gt;
&lt;p&gt;在Windows命令行（CMD）或PowerShell中执行：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cmd&quot;&gt;ssh-keygen -t ed25519 -C &quot;your_email@example.com&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;说明：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;-t ed25519&lt;/code&gt;：使用ED25519算法（安全且高效）&lt;/li&gt;
&lt;li&gt;&lt;code&gt;-C&lt;/code&gt;：添加注释（通常是你的邮箱）&lt;/li&gt;
&lt;li&gt;直接按回车使用默认路径（C:\Users\你的用户名.ssh\id_ed25519）&lt;/li&gt;
&lt;li&gt;密码短语直接按回车跳过（实现免密）&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;二、复制公钥到远程服务器&lt;/h3&gt;
&lt;p&gt;&lt;strong&gt;重要说明&lt;/strong&gt;：如果服务器上已有 &lt;code&gt;authorized_keys&lt;/code&gt; 文件（比如之前在其他设备上配置过），以下命令会自动&lt;strong&gt;追加&lt;/strong&gt;新公钥，不会覆盖已有的公钥。这样你可以在多台设备上实现免密登录。&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;方法1：使用PowerShell命令（推荐，最简单）&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;在PowerShell中执行：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;type $env:USERPROFILE\.ssh\id_ed25519.pub | ssh username@remote_ip &quot;mkdir -p ~/.ssh &amp;#x26;&amp;#x26; cat &gt;&gt; ~/.ssh/authorized_keys &amp;#x26;&amp;#x26; chmod 700 ~/.ssh &amp;#x26;&amp;#x26; chmod 600 ~/.ssh/authorized_keys&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;示例：&lt;/strong&gt;&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;type $env:USERPROFILE\.ssh\id_ed25519.pub | ssh root@192.168.1.100 &quot;mkdir -p ~/.ssh &amp;#x26;&amp;#x26; cat &gt;&gt; ~/.ssh/authorized_keys &amp;#x26;&amp;#x26; chmod 700 ~/.ssh &amp;#x26;&amp;#x26; chmod 600 ~/.ssh/authorized_keys&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;strong&gt;方法2：手动复制&lt;/strong&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;查看本地公钥（CMD）：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cmd&quot;&gt;type %USERPROFILE%\.ssh\id_ed25519.pub
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;或（PowerShell）：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;type $env:USERPROFILE\.ssh\id_ed25519.pub
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;复制公钥内容，登录远程服务器：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-cmd&quot;&gt;ssh username@remote_ip
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;在远程服务器上执行：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;mkdir -p ~/.ssh
echo &quot;你的公钥内容&quot; &gt;&gt; ~/.ssh/authorized_keys
chmod 700 ~/.ssh
chmod 600 ~/.ssh/authorized_keys
&lt;/code&gt;&lt;/pre&gt;
&lt;h3&gt;三、测试免密登录&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-cmd&quot;&gt;ssh username@remote_ip
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;如果不需要输入密码直接登录，说明配置成功！&lt;/p&gt;
&lt;h3&gt;四、配置快捷别名（可选）&lt;/h3&gt;
&lt;p&gt;编辑 SSH 配置文件 &lt;code&gt;C:\Users\你的用户名\.ssh\config&lt;/code&gt;：&lt;/p&gt;
&lt;p&gt;使用记事本打开：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-cmd&quot;&gt;notepad %USERPROFILE%\.ssh\config
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;添加以下内容：&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;Host myserver
    HostName 192.168.1.100
    User root
    Port 22
    IdentityFile C:\Users\你的用户名\.ssh\id_ed25519
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;以后直接 &lt;code&gt;ssh myserver&lt;/code&gt; 即可登录。&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;结束GPU进程&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;ps x |grep python|awk &apos;{print $1}&apos;|xargs kill
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;解释：ps grep组合查看python进程，使用awk分割pid，xargs组合kill命令，等价于kill [属于python的pid]&lt;/p&gt;
&lt;h2&gt;使用Tmux挂起会话进程&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;输入tmux即可进入Tmux窗口，但不推荐。因为第一个启动的 Tmux 窗口，编号是0，第二个窗口的编号是1，以此类推。这些窗口对应的会话，就是 0 号会话、1 号会话。使用编号区分会话，不太直观，更好的方法是为会话起名。&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;使用命令创建一个指定名称的会话&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;tmux new -s &amp;#x3C;session-name&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;分离会话&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;tmux detach
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;或者使用Ctrl + b, d
先按下 Ctrl + b（这是 tmux 的前缀键）
然后松开，再按 d。&lt;/p&gt;
&lt;ol start=&quot;4&quot;&gt;
&lt;li&gt;查看所有会话&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;tmux ls
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;接入会话&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;tmux attach -t &amp;#x3C;session-name&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;6&quot;&gt;
&lt;li&gt;杀死会话&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;tmux kill-session -t &amp;#x3C;session-name&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;从远程服务器之间数据拷贝&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;使用&lt;code&gt;sudo apt-get install rsync&lt;/code&gt;来安装rsync工具，支持断点重续。&lt;/li&gt;
&lt;li&gt;执行命令&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;rsync -avzP --exclude=&apos;Multimodal-Cooperation-main/code/log_sample/&apos; -e &quot;ssh -p 2233&quot; tangqian@172.21.201.220:/path/to/source/ /path/to/destination/
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;-v：详细模式输出，显示传输过程中的信息。&lt;/p&gt;
&lt;p&gt;-z：在传输文件时进行压缩。&lt;/p&gt;
&lt;p&gt;-P: -P 参数是 --partial 和 --progress 的组合，其中：partial：告诉 rsync 保持部分传输的文件，这样如果传输中断，下次传输同一文件时不会从头开始，而是从中断处继续。progress：显示传输进度。&lt;/p&gt;
&lt;p&gt;--exclude=&apos;log_sample/&apos;：指定排除不拷贝的文件或目录，这里使用了相对路径。&lt;/p&gt;
&lt;p&gt;user@remote:/path/to/source/：远程服务器上的用户名、主机地址和源路径。&lt;/p&gt;
&lt;p&gt;/path/to/destination/：本地的目标路径。&lt;/p&gt;
&lt;p&gt;-e &quot;ssh -p 2233&quot; 指定了通过SSH连接并且使用端口2233。&lt;/p&gt;
&lt;p&gt;确保源路径和目标路径以斜杠结尾，这取决于你希望如何同步文件夹内容。在你的例子中，源路径 /home/tangqian/my_project/ 以斜杠结尾意味着复制该目录的内容到目标位置；如果不以斜杠结尾，则会创建一个名为 my_project 的子目录并将所有内容放入其中。&lt;/p&gt;
&lt;ol start=&quot;3&quot;&gt;
&lt;li&gt;也可以本地上传至服务&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;rsync -avzP --progress ./upload.zip -e &quot;ssh -p 2233&quot; tangqian@172.21.201.220:/home/tangqian/dataset
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;远程服务器与本地之间数据拷贝&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;首先去微软商店安装ubuntu系统&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;安装完之后，可能出现bug，参考&lt;a href=&quot;https://blog.csdn.net/2301_78094384/article/details/143270215?spm=1001.2101.3001.6650.3&amp;#x26;utm_medium=distribute.pc_relevant.none-task-blog-2~default~YuanLiJiHua~Position-3-143270215-blog-141499189.235%5Ev43%5Epc_blog_bottom_relevance_base5&amp;#x26;depth_1-utm_source=distribute.pc_relevant.none-task-blog-2~default~YuanLiJiHua~Position-3-143270215-blog-141499189.235%5Ev43%5Epc_blog_bottom_relevance_base5&amp;#x26;utm_relevant_index=6&quot;&gt;windows11 启用 wsl, 安装 ubuntu 系统&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;在cmd终端输入ubuntu进入系统&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;进入windows路径&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;cd /mnt
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;5&quot;&gt;
&lt;li&gt;使用rsync传数据&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;Tensorflow-GPU安装&lt;/h2&gt;
&lt;p&gt;python=3.6  2.x才支持GPU加速&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;pip install keras==2.2.4 tensorflow-gpu==1.14.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;pip install keras==2.3.1 tensorflow-gpu==2.2.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;PyTorch-GPU安装&lt;/h2&gt;
&lt;p&gt;python=3.6-3.9&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;pip install torch==1.9.0 torchvision==0.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;pip install torch==1.10.0 torchtext==0.11.0 torchvision==0.11.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;适用于A100 python3.8&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;查看CUDA版本指令，通常torch.version.cuda 显示的版本号要小于等于NVIDIA CUDA 版本&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;import torch
print(torch.cuda.is_available())
print(torch.version.cuda)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;其他镜像&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;-i https://mirrors.aliyun.com/pypi/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;-i https://pypi.douban.com/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;-i https://pypi.mirrors.ustc.edu.cn/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;-i https://repo.huaweicloud.com/repository/pypi/simple
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;服务器上报数据和缓存不在同一个GPU上的错误&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;RuntimeError: module must have its parameters and buffers on device cuda:3 (device_ids[0]) but found one of them on device: cuda:0
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;这是由于在并行训练时，如果已经通过 model.to(device) 或 torch.nn.DataParallel 将模型移动到了特定的 GPU 或多个 GPU 上，再次调用 model.cuda() 是不必要的。&lt;/p&gt;
&lt;p&gt;修改后的代码如下：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;model.to(device)
if len(gpu_ids) &gt; 1:
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)
else:
    model.cuda()
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;Linux常用的命令&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;删除当前文件夹下的所有文件和子目录&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;rm -rf *
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;移动指定目录下的所有文件和子目录&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;mv /path/to/source/* /path/to/destination/
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;zip压缩命令跳过指定文件夹&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;zip -r Multimodal-Cooperation-main-v2.zip Multimodal-Cooperation-main/ -x &quot;Multimodal-Cooperation-main/code/log_sample/*&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;Nvidia-smi在当前页面刷新输出&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;watch -n 0.1 nvidia-smi
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;解释
watch：这是一个 Linux 命令，用于定期执行指定的命令并在终端中显示输出。
-n 0.1：这个选项指定了刷新的间隔时间，这里设置为 0.1 秒。
nvidia-smi：这是你希望定期执行并显示其输出的命令。&lt;/p&gt;
&lt;h2&gt;Ubuntu启动EasyConnect&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-powershell&quot;&gt;sudo /usr/share/sangfor/EasyConnect/EasyConnect
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;Docker的使用&lt;/h2&gt;
&lt;p&gt;&lt;a href=&quot;https://blog.csdn.net/weixin_72137075/article/details/144824918&quot;&gt;参考博客&lt;/a&gt;&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;拉取原始镜像&lt;code&gt;docker pull ubuntu:22.04&lt;/code&gt; 此时镜像名称为 ubuntu，标签为 22.04。
或者使用 docker tag 重命名镜像​​：&lt;code&gt;docker tag ubuntu:22.04 my-ubuntu:v1&lt;/code&gt; 重命名为 my-ubuntu:v1&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;查看镜像&lt;code&gt;docker image&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F3.png%3ForigWidth%3D721%26origHeight%3D70%26origFormat%3Dpng&amp;#x26;w=721&amp;#x26;h=70&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;删除原始镜像&lt;code&gt;docker rmi ubuntu:22.04&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F4.png%3ForigWidth%3D726%26origHeight%3D101%26origFormat%3Dpng&amp;#x26;w=726&amp;#x26;h=101&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;创建容器&lt;code&gt;docker run -it --name hw_mirror ubuntu:22.04&lt;/code&gt;（&lt;strong&gt;一定要带标签&lt;/strong&gt;）
​​&lt;strong&gt;docker run&lt;/strong&gt;：Docker 的核心命令，用于创建并启动一个新容器
&lt;strong&gt;​​-it​​ (组合参数)&lt;/strong&gt;：-i (--interactive)：保持标准输入打开，允许与容器交互
-t (--tty)：分配伪终端（TTY），使容器像本地终端一样工作
→ 组合效果：进入容器的​​交互式命令行模式​​
​​**--name hw_mirror**：为容器指定名称 hw_mirror
如果不指定，Docker 会随机生成名称（如 friendly_curie）
后续可用 docker start hw_mirror 通过名称操作容器
&lt;strong&gt;​​ubuntu:22.04​​&lt;/strong&gt;：使用的镜像名称及标签，若本地不存在，会先自动执行 docker pull ubuntu:22.04&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;查看容器&lt;code&gt;docker ps -a&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F5.png%3ForigWidth%3D879%26origHeight%3D83%26origFormat%3Dpng&amp;#x26;w=879&amp;#x26;h=83&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;删除容器&lt;code&gt;docker rm hw_mirror&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F6.png%3ForigWidth%3D769%26origHeight%3D67%26origFormat%3Dpng&amp;#x26;w=769&amp;#x26;h=67&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;查看ubuntu版本&lt;code&gt;cat /etc/os-release&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F7.png%3ForigWidth%3D590%26origHeight%3D221%26origFormat%3Dpng&amp;#x26;w=590&amp;#x26;h=221&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;容器内安装python&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 1. 更新包列表
apt update

# 2. 安装 Python 3
apt install -y python3

# 3. 安装 pip（Python 包管理器）
apt install -y python3-pip

# 4. 验证安装
python3 --version
pip3 --version
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F8.png%3ForigWidth%3D478%26origHeight%3D67%26origFormat%3Dpng&amp;#x26;w=478&amp;#x26;h=67&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;自定义创建Docker&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;创建Dockerfile以及一些准备的文件
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F9.png%3ForigWidth%3D146%26origHeight%3D87%26origFormat%3Dpng&amp;#x26;w=146&amp;#x26;h=87&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;
Dockerfile的内容：&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 使用官方轻量级镜像
FROM ubuntu:22.04

# 临时禁用交互（构建结束时移除）
ENV DEBIAN_FRONTEND=noninteractive

# 安装系统依赖（清理缓存）
RUN apt-get update &amp;#x26;&amp;#x26; \
    apt-get install -y --no-install-recommends \  
        python3-pip \
        curl \
    &amp;#x26;&amp;#x26; rm -rf /var/lib/apt/lists/*
    
# 不需要安装python3.10（22.04自带python3.10）
# 删除环境变量（防止运行时影响）
RUN unset DEBIAN_FRONTEND

# 安装Python依赖（优先复制依赖文件）
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 设置工作目录并复制代码
WORKDIR /app
# 将当前路径下的所有内容拷贝到工作目录
COPY . .

# 声明端口（文档用途）
EXPOSE 8000

# 启动命令
# CMD [&quot;python3&quot;, &quot;app.py&quot;]
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;2&quot;&gt;
&lt;li&gt;
&lt;p&gt;在包含Dockerfile的目录下构建镜像&lt;code&gt;docker build -t my-custom-app:v1 .&lt;/code&gt;
&lt;strong&gt;docker build&lt;/strong&gt;
​​作用​​：启动 Docker 构建流程
​​行为​​：根据当前目录下的 Dockerfile 创建新镜像
&lt;strong&gt;-t my-custom-app:v1&lt;/strong&gt;
​​-t​​ 参数：--tag 的缩写，用于指定镜像名称和标签
​​my-custom-app​​：自定义的镜像名称（通常按&quot;项目名/用途&quot;格式）
​​:v1​​：镜像标签（推荐使用语义化版本控制）
标签省略时默认为 :latest
有效标签格式：v1.2.3, beta, 20240610等
&lt;strong&gt;.（末尾的点）
​&lt;/strong&gt;​含义​​：指定​​构建上下文​​路径
​​作用​​：Dockerfile 中的 COPY 和 ADD 指令相对此路径工作
​​注意事项​​：
Docker 会把当前目录​​所有内容​​发送给 Docker 守护进程
使用 .dockerignore 文件排除不需要的文件
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F10.png%3ForigWidth%3D985%26origHeight%3D238%26origFormat%3Dpng&amp;#x26;w=985&amp;#x26;h=238&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;使用镜像执行&lt;code&gt;python3 -V&lt;/code&gt;命令：&lt;code&gt;docker run -it my-custom-app:v1 python3 -V &lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F11.png%3ForigWidth%3D614%26origHeight%3D34%26origFormat%3Dpng&amp;#x26;w=614&amp;#x26;h=34&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;创建并使用新的容器&lt;code&gt;docker run -it -d --name my-app my-custom-app:v2 /bin/bash &lt;/code&gt;
&lt;strong&gt;docker run&lt;/strong&gt;：创建并启动新容器
&lt;strong&gt;-it&lt;/strong&gt;：​​组合参数​​：-i：保持标准输入流打开（允许交互）-t：分配伪终端（创建命令行界面）
&lt;strong&gt;-d&lt;/strong&gt;：后台运行（detach 模式）
&lt;strong&gt;--name my-app&lt;/strong&gt;：指定容器名称为 my-app
&lt;strong&gt;my-custom-app:v2&lt;/strong&gt;：使用的镜像和标签
&lt;strong&gt;/bin/bash&lt;/strong&gt;：容器启动后执行的命令（启动 Bash Shell）&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;连接已有容器：&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;# 如果容器已停止：
docker start my-app
# 进入运行中的容器
docker exec -it demo-app /bin/bash
&lt;/code&gt;&lt;/pre&gt;
&lt;ol start=&quot;6&quot;&gt;
&lt;li&gt;修改容器后，将当前容器提交为镜像：&lt;code&gt;docker commit ur5e my-custom-ur5e-image&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F12.png%3ForigWidth%3D732%26origHeight%3D85%26origFormat%3Dpng&amp;#x26;w=732&amp;#x26;h=85&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/li&gt;
&lt;li&gt;保存镜像为tar压缩文件&lt;code&gt;docker save -o ur5e-image.tar my-custom-ur5e-image&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;压缩文件大小&lt;code&gt;gzip ur5e-image.tar&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;解压tar文件&lt;code&gt;gzip -d ur5e-image.tar.gz&lt;/code&gt;&lt;/li&gt;
&lt;li&gt;加载镜像&lt;code&gt;docker load -i ur5e-image.tar&lt;/code&gt;
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F13.png%3ForigWidth%3D881%26origHeight%3D304%26origFormat%3Dpng&amp;#x26;w=881&amp;#x26;h=304&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/li&gt;
&lt;li&gt;&lt;code&gt;docker run -itd --name ros2_gpu   --gpus all   --env=&quot;DISPLAY&quot;   --env=&quot;NVIDIA_DRIVER_CAPABILITIES=all&quot;   --env=&quot;QT_X11_NO_MITSHM=1&quot;   --env=&quot;LIBGL_ALWAYS_INDIRECT=1&quot;   --volume=&quot;/tmp/.X11-unix:/tmp/.X11-unix:rw&quot;   --volume=&quot;$HOME/.Xauthority:/root/.Xauthority:rw&quot;   my-custom-ur5e-image:latest   /bin/bash&lt;/code&gt;
&lt;strong&gt;--gpus all&lt;/strong&gt;：允许容器访问宿主机的所有 NVIDIA GPU
&lt;strong&gt;--env=&quot;NVIDIA_DRIVER_CAPABILITIES=all&quot;&lt;/strong&gt;：定义 NVIDIA 驱动的功能范- all 包含 ​​图形渲染（OpenGL/Vulkan）​​ 和 ​​计算（CUDA）
&lt;strong&gt;--env=&quot;DISPLAY&quot;&lt;/strong&gt;： 传递宿主机的 DISPLAY 环境变量（通常为 :0）到容器
&lt;strong&gt;--volume=&quot;/tmp/.X11-unix:/tmp/.X11-unix:rw&quot;&lt;/strong&gt;： 挂载宿主机的 X11 套接字目录到容器
&lt;strong&gt;--volume=&quot;$HOME/.Xauthority:/root/.Xauthority:rw&quot;&lt;/strong&gt;： 挂载 X11 认证文件（~/.Xauthority）到容器
&lt;strong&gt;--env=&quot;QT_X11_NO_MITSHM=1&quot;&lt;/strong&gt;：禁用 Qt 的共享内存（MIT-SHM）扩展
&lt;strong&gt;--env=&quot;LIBGL_ALWAYS_INDIRECT=1&quot;&lt;/strong&gt;： 强制 OpenGL 使用间接渲染（通过 X11 转发）&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;Ubuntu换源工具&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;wget http://fishros.com/install -O fishros &amp;#x26;&amp;#x26; . fishros
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;Ubuntu使用代理教程&lt;/h2&gt;
&lt;p&gt;&lt;a href=&quot;https://devpn.github.io/docs/start/ubuntu/clash/&quot;&gt;教程&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;云服务器直接下载百度网盘文件&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;先下载bypy&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;pip install bypy
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;登录&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;bypy info
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;登录后会在“我的应用数据”文件夹下生成bypy文件夹，我们只需要把要下载的内容复制到该文件夹中
&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F14.png%3ForigWidth%3D245%26origHeight%3D184%26origFormat%3Dpng&amp;#x26;w=245&amp;#x26;h=184&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;然后下载文件夹到指定目录（下载文件用download）&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;bypy downdir / /path/to/your/target/directory -v
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;上传到百度云命令&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;bypy upload /path/to/local/folder
&lt;/code&gt;&lt;/pre&gt;
&lt;/li&gt;
&lt;/ol&gt;
&lt;h2&gt;清理pip的缓存&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;rm -rf /tmp/*

# 2. 删除 Python 缓存
find /root -name &quot;__pycache__&quot; -type d -exec rm -rf {} + 2&gt;/dev/null
find /root -name &quot;*.pyc&quot; -delete 2&gt;/dev/null

# 3. 清理 pip/conda 残留（即使你迁移了 .cache）
rm -rf /root/.local/lib/python*/site-packages/*  # 谨慎！如果你用 pip install --user
rm -rf /root/.conda

# 4. 删除 Jupyter 缓存
rm -rf /root/.local/share/jupyter
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;将本地代码上传到Github仓库&lt;/h2&gt;
&lt;p&gt;生成新的SSH密钥&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;ssh-keygen -t ed25519 -C &quot;14680426@qq.com&quot;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;将SSH公钥添加到Github&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;cat ~/.ssh/id_ed25519.pub
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;复制公钥内容，登录 GitHub → 点右上角头像 → Settings → 左侧 SSH and GPG keys → New SSH key&lt;/p&gt;
&lt;p&gt;测试SSH连接&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;ssh -T git@github.com
&lt;/code&gt;&lt;/pre&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;# 初始化仓库
git init

# 配置git用户名和邮箱
git config --global user.name &quot;SoupCola&quot;
git config --global user.email &quot;14680426@qq.com&quot;

# 查看改动
git status

# 添加所有改动
git add .

# 提交
git commit -m &quot;feat: support structured output tools&quot;

# 关联远程仓库
git remote add origin git@github.com:你的用户名/仓库名.git

# 推送到 main 分支（GitHub 默认分支可能是 main 或 master）
git push -u origin main
&lt;/code&gt;&lt;/pre&gt;
&lt;h2&gt;WSL2安装可视化桌面&lt;/h2&gt;
&lt;p&gt;安装VcXsrv
&lt;a href=&quot;https://sourceforge.net/projects/vcxsrv/&quot;&gt;下载地址&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;启动VcXsrv
一直点next（注意第一个界面设置Display为0），第三个界面勾选Disable access control&lt;/p&gt;
&lt;p&gt;在wsl环境中安装xfce&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sudo apt install xfce4-terminal
sudo apt install xfce4
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;因为每次启动WSL都会使用新的IP，所以配置自动更新~/.bashrc文件的IP&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;nano ～/.bashrc
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;在末尾添加&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;export DISPLAY=$(awk &apos;/nameserver/{print $2; exit}&apos; /etc/resolv.conf):0
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;启动xeyes测试&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;xeyes
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;更新配置&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;source .bashrc
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;安装中文支持&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;sudo apt install ttf-wqy-zenhei
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;运行xfce4即可查看桌面&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;startxfce4
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;&lt;img src=&quot;https://astro-pure.js.org/_image?href=%2F%40fs%2Froot%2Fmy_blog%2Fsrc%2Fcontent%2Fblog%2Ftool_blogs%2Fwidely_tools%2F15.png%3ForigWidth%3D2560%26origHeight%3D1600%26origFormat%3Dpng&amp;#x26;w=2560&amp;#x26;h=1600&amp;#x26;f=webp&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;h2&gt;Windows 关闭 Service Worker 进程&lt;/h2&gt;
&lt;p&gt;1️⃣ 打开运行：&lt;/p&gt;
&lt;p&gt;Win + R&lt;/p&gt;
&lt;p&gt;输入：&lt;/p&gt;
&lt;p&gt;resmon&lt;/p&gt;
&lt;p&gt;2️⃣ 打开 CPU&lt;/p&gt;
&lt;p&gt;3️⃣ 在 关联的句柄 搜索：&lt;/p&gt;
&lt;p&gt;Service Worker&lt;/p&gt;
&lt;p&gt;或你的文件夹名字&lt;/p&gt;
&lt;p&gt;4️⃣ 找到进程 → 结束进程。&lt;/p&gt;
&lt;h2&gt;windows清理python进程&lt;/h2&gt;
&lt;p&gt;查找进程命令&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;wmic process where &quot;name=&apos;python.exe&apos; or name=&apos;pythonw.exe&apos;&quot; get ProcessId,CommandLine /format:table
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;kill命令&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-bash&quot;&gt;taskkill /PID &amp;#x3C;PID&gt; /F
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/tool_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/tool_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>YOLO处理代码</title><link>https://astro-pure.js.org/blog/tool_blogs/yolo_tools</link><guid isPermaLink="true">https://astro-pure.js.org/blog/tool_blogs/yolo_tools</guid><description>一些YOLO常用的数据预处理代码。</description><pubDate>Fri, 30 Jan 2026 21:23:00 GMT</pubDate><content:encoded>&lt;h2&gt;YOLO数据增强代码&lt;/h2&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;import os
import cv2
import albumentations as A
import random
import shutil
from tqdm import tqdm

def clean_unmatched_files(img_dir, label_dir):
    &quot;&quot;&quot;
    清理YOLO数据集中不匹配的图片和标签文件：
    - 删除没有对应.txt标签的图片
    - 删除没有对应.jpg/.png图片的标签
    &quot;&quot;&quot;
    print(&quot;正在清理不匹配的图片和标签文件...&quot;)
    
    img_extensions = (&apos;.jpg&apos;, &apos;.jpeg&apos;, &apos;.png&apos;)
    img_files = {os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.lower().endswith(img_extensions)}
    label_files = {os.path.splitext(f)[0] for f in os.listdir(label_dir) if f.endswith(&apos;.txt&apos;)}

    # 找出不匹配的文件
    imgs_without_labels = img_files - label_files
    labels_without_imgs = label_files - img_files

    # 删除没有标签的图片
    for stem in imgs_without_labels:
        for ext in img_extensions:
            img_path = os.path.join(img_dir, stem + ext)
            if os.path.exists(img_path):
                os.remove(img_path)
                print(f&quot;已删除无标签图片: {img_path}&quot;)

    # 删除没有图片的标签
    for stem in labels_without_imgs:
        label_path = os.path.join(label_dir, stem + &apos;.txt&apos;)
        if os.path.exists(label_path):
            os.remove(label_path)
            print(f&quot;已删除无图片标签: {label_path}&quot;)

    print(f&quot;清理完成！共删除 {len(imgs_without_labels)} 张无标签图片 和 {len(labels_without_imgs)} 个无图片标签。&quot;)


def augment_images_and_labels(img_dir, label_dir, output_img_dir, output_label_dir, augment_times=3, view_dir=None, view_ratio=0.1):
    &quot;&quot;&quot;
    对YOLO数据进行安全的数据增强，仅针对训练集
    :param img_dir: 原始图片的目录路径
    :param label_dir: YOLO标签目录路径
    :param output_img_dir: 增强后的图片保存目录（可与输入相同）
    :param output_label_dir: 增强后的标签保存目录（可与输入相同）
    :param augment_times: 每张图片的增强次数
    :param view_dir: 查看增强效果的目录路径
    :param view_ratio: 查看增强效果的图片比例
    &quot;&quot;&quot;
    # 确保输出目录存在
    os.makedirs(output_img_dir, exist_ok=True)
    os.makedirs(output_label_dir, exist_ok=True)
    
    if view_dir:
        os.makedirs(view_dir, exist_ok=True)

    # 获取所有原始图片（排除已增强的）
    img_extensions = (&apos;.jpg&apos;, &apos;.jpeg&apos;, &apos;.png&apos;)
    img_files = [f for f in os.listdir(img_dir) 
                 if f.lower().endswith(img_extensions) and &apos;_aug_&apos; not in f]
    
    if not img_files:
        print(&quot;警告：未找到任何原始图片（可能已被清理或路径错误）&quot;)
        return

    num_view_imgs = max(1, int(len(img_files) * view_ratio))
    view_indices = set(random.sample(range(len(img_files)), num_view_imgs))

    for idx, img_file in enumerate(tqdm(img_files)):
        img_path = os.path.join(img_dir, img_file)
        label_path = os.path.join(label_dir, os.path.splitext(img_file)[0] + &apos;.txt&apos;)

        # 读取图像
        image = cv2.imread(img_path)
        if image is None:
            print(f&quot;无法读取图像，跳过: {img_path}&quot;)
            continue
        height, width = image.shape[:2]

        # 读取标签
        bboxes = []
        class_labels = []
        if os.path.exists(label_path):
            with open(label_path, &apos;r&apos;) as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) != 5:
                        continue  # 跳过格式错误的行
                    class_id = int(parts[0])
                    x_center, y_center, w, h = map(float, parts[1:])
                    # 验证坐标合法性（YOLO格式应在 [0,1]）
                    if not (0 &amp;#x3C;= x_center &amp;#x3C;= 1 and 0 &amp;#x3C;= y_center &amp;#x3C;= 1 and 0 &amp;#x3C; w &amp;#x3C;= 1 and 0 &amp;#x3C; h &amp;#x3C;= 1):
                        continue
                    bboxes.append([x_center, y_center, w, h])
                    class_labels.append(class_id)

        if not bboxes:
            print(f&quot;有效标签为空，跳过: {label_path}&quot;)
            continue

        # 动态设置裁剪尺寸（防止 crop &gt; image）
        min_dim = min(height, width)
        crop_h = min(500, min_dim)
        crop_w = min(500, min_dim)

        augmentations = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.GaussianBlur(blur_limit=(3, 7), p=0.2),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.Resize(width=640, height=640, p=0.5),
            A.RandomCrop(width=crop_w, height=crop_h, p=0.5),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.2),
            A.RandomScale(scale_limit=0.2, p=0.2),
        ], bbox_params=A.BboxParams(format=&apos;yolo&apos;, label_fields=[&apos;class_labels&apos;], min_visibility=0.1, min_area=10))

        for i in range(augment_times):
            try:
                augmented = augmentations(image=image, bboxes=bboxes, class_labels=class_labels)
            except Exception as e:
                print(f&quot;增强失败（{img_file} 第{i}次）: {e}&quot;)
                continue

            aug_image = augmented[&apos;image&apos;]
            aug_bboxes = augmented[&apos;bboxes&apos;]
            aug_labels = augmented[&apos;class_labels&apos;]

            if not aug_bboxes:
                continue  # 跳过无有效框的增强结果

            # 保存增强结果
            base_name = os.path.splitext(img_file)[0]
            out_img_path = os.path.join(output_img_dir, f&quot;{base_name}_aug_{i}.jpg&quot;)
            out_label_path = os.path.join(output_label_dir, f&quot;{base_name}_aug_{i}.txt&quot;)

            cv2.imwrite(out_img_path, aug_image)
            with open(out_label_path, &apos;w&apos;) as f:
                for bbox, cls in zip(aug_bboxes, aug_labels):
                    x_center, y_center, w, h = bbox
                    # 再次确保数值合法（防止浮点误差）
                    x_center = max(0.0, min(1.0, x_center))
                    y_center = max(0.0, min(1.0, y_center))
                    w = max(0.0, min(1.0, w))
                    h = max(0.0, min(1.0, h))
                    f.write(f&quot;{int(cls)} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n&quot;)

            # 可视化样本
            if view_dir and idx in view_indices:
                view_img = aug_image.copy()
                h_img, w_img = view_img.shape[:2]
                for bbox, cls in zip(aug_bboxes, aug_labels):
                    xc, yc, bw, bh = bbox
                    x1 = int((xc - bw / 2) * w_img)
                    y1 = int((yc - bh / 2) * h_img)
                    x2 = int((xc + bw / 2) * w_img)
                    y2 = int((yc + bh / 2) * h_img)
                    cv2.rectangle(view_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(view_img, str(cls), (x1, max(0, y1 - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                cv2.imwrite(os.path.join(view_dir, f&quot;{base_name}_aug_{i}.jpg&quot;), view_img)

    print(&quot;✅ 数据增强完成！&quot;)


if __name__ == &quot;__main__&quot;:
    # 配置路径（必须为全英文路径）
    img_dir = &quot;./train/images&quot;
    label_dir = &quot;./train/labels&quot;
    output_img_dir = &quot;./train/images&quot;      # 增强图可追加到原目录
    output_label_dir = &quot;./train/labels&quot;   # 增强标签同理

    # 第一步：清理不匹配文件
    clean_unmatched_files(img_dir, label_dir)

    # 第二步：执行增强
    augment_images_and_labels(
        img_dir=img_dir,
        label_dir=label_dir,
        output_img_dir=output_img_dir,
        output_label_dir=output_label_dir,
        augment_times=4,
        view_dir=&quot;view&quot;,
        view_ratio=0.1
    )
&lt;/code&gt;&lt;/pre&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/tool_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/><enclosure url="/@fs/root/my_blog/src/content/blog/tool_blogs/abstract.png?origWidth=1664&amp;origHeight=928&amp;origFormat=png"/></item><item><title>Markdown 语法支持</title><link>https://astro-pure.js.org/blog/markdown-zh</link><guid isPermaLink="true">https://astro-pure.js.org/blog/markdown-zh</guid><description>Markdown 是一种轻量级的「标记语言」。</description><pubDate>Wed, 26 Jul 2023 08:00:00 GMT</pubDate><content:encoded>&lt;h2&gt;基本语法&lt;/h2&gt;
&lt;p&gt;Markdown 是一种轻量级且易于使用的语法，用于为您的写作设计风格。&lt;/p&gt;
&lt;h3&gt;标题&lt;/h3&gt;
&lt;p&gt;文章内容较多时，可以用标题分段：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;# 标题 1

## 标题 2

## 大标题

### 小标题
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;标题预览会打乱文章的结构，所以在此不展示。&lt;/p&gt;
&lt;h3&gt;粗斜体&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;_斜体文本_

**粗体文本**

**_粗斜体文本_**
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;&lt;em&gt;斜体文本&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;粗体文本&lt;/strong&gt;&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;&lt;em&gt;粗斜体文本&lt;/em&gt;&lt;/strong&gt;&lt;/p&gt;
&lt;h3&gt;链接&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;文字链接 [链接名称](http://链接网址)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;文字链接 &lt;a href=&quot;http://%E9%93%BE%E6%8E%A5%E7%BD%91%E5%9D%80&quot;&gt;链接名称&lt;/a&gt;&lt;/p&gt;
&lt;h3&gt;行内代码&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;这是一条 `单行代码`
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;这是一条 &lt;code&gt;行内代码&lt;/code&gt;&lt;/p&gt;
&lt;h3&gt;代码块&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;```js
// calculate fibonacci
function fibonacci(n) {
  if (n &amp;#x3C;= 1) return 1
  return fibonacci(n - 1) + fibonacci(n - 2)
}
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-js&quot;&gt;// calculate fibonacci
function fibonacci(n) {
  if (n &amp;#x3C;= 1) return 1
  return fibonacci(n - 1) + fibonacci(n - 2)
}
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;当前使用 shiki 作为代码高亮插件，支持的语言请参考 &lt;a href=&quot;https://shiki.matsu.io/languages.html&quot;&gt;shiki / languages&lt;/a&gt;。&lt;/p&gt;
&lt;h3&gt;行内公式&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;这是一条行内公式 $e^{i\pi} + 1 = 0$
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;这是一条行内公式 $e^{i\pi} + 1 = 0$&lt;/p&gt;
&lt;h3&gt;公式块&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;$$
\hat{f}(\xi) = \int_{-\infty}^{\infty} f(x) e^{-2\pi i x \xi} \, dx
$$
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;$$
\hat{f}(\xi) = \int_{-\infty}^{\infty} f(x) e^{-2\pi i x \xi} , dx
$$&lt;/p&gt;
&lt;p&gt;当前使用 KaTeX 作为数学公式插件，支持的语法请参考 &lt;a href=&quot;https://katex.org/docs/supported.html&quot;&gt;KaTeX Supported Functions&lt;/a&gt;。&lt;/p&gt;
&lt;h4&gt;图片&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;![CWorld](https://cravatar.cn/avatar/1ffe42aa45a6b1444a786b1f32dfa8aa?s=200)
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://cravatar.cn/avatar/1ffe42aa45a6b1444a786b1f32dfa8aa?s=200&quot; alt=&quot;CWorld&quot;&gt;&lt;/p&gt;
&lt;h4&gt;删除线&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;~~删除线~~
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;~~删除线~~&lt;/p&gt;
&lt;h3&gt;列表&lt;/h3&gt;
&lt;p&gt;普通无序列表&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;- 1
- 2
- 3
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;1&lt;/li&gt;
&lt;li&gt;2&lt;/li&gt;
&lt;li&gt;3&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;普通有序列表&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;1. GPT-4
2. Claude Opus
3. LLaMa
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;GPT-4&lt;/li&gt;
&lt;li&gt;Claude Opus&lt;/li&gt;
&lt;li&gt;LLaMa&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;列表里可以继续嵌套语法&lt;/p&gt;
&lt;h3&gt;引用&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;&gt; 枪响，雷鸣，剑起。繁花血景。
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;枪响，雷鸣，剑起。繁花血景。&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;引用里也可以继续嵌套语法。&lt;/p&gt;
&lt;h3&gt;换行&lt;/h3&gt;
&lt;p&gt;markdown 分段落是需要空一行的。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;如果不空行
就会在一段

第一段

第二段
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;如果不空行
就会在一段&lt;/p&gt;
&lt;p&gt;第一段&lt;/p&gt;
&lt;p&gt;第二段&lt;/p&gt;
&lt;h3&gt;分隔符&lt;/h3&gt;
&lt;p&gt;如果你有写分割线的习惯，可以新起一行输入三个减号&lt;code&gt;---&lt;/code&gt; 或者星号 &lt;code&gt;***&lt;/code&gt;。当前后都有段落时，请空出一行：&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;---
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;高级技巧&lt;/h2&gt;
&lt;h3&gt;行内 HTML 元素&lt;/h3&gt;
&lt;p&gt;目前只支持部分段内 HTML 元素效果，包括 &lt;code&gt;&amp;#x3C;kdb&gt; &amp;#x3C;b&gt; &amp;#x3C;i&gt; &amp;#x3C;em&gt; &amp;#x3C;sup&gt; &amp;#x3C;sub&gt; &amp;#x3C;br&gt;&lt;/code&gt; ，如&lt;/p&gt;
&lt;h4&gt;键位显示&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;使用 &amp;#x3C;kbd&gt;Ctrl&amp;#x3C;/kbd&gt; + &amp;#x3C;kbd&gt;Alt&amp;#x3C;/kbd&gt; + &amp;#x3C;kbd&gt;Del&amp;#x3C;/kbd&gt; 重启电脑
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;使用 Ctrl + Alt + Del 重启电脑&lt;/p&gt;
&lt;h4&gt;粗斜体&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;&amp;#x3C;b&gt; Markdown 在此处同样适用，如 _加粗_ &amp;#x3C;/b&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt; Markdown 在此处同样适用，如 &lt;em&gt;加粗&lt;/em&gt; &lt;/p&gt;
&lt;h3&gt;其他 HTML 写法&lt;/h3&gt;
&lt;h4&gt;折叠块&lt;/h4&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;&amp;#x3C;details&gt;&amp;#x3C;summary&gt;点击展开&amp;#x3C;/summary&gt;它被隐藏了&amp;#x3C;/details&gt;
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;h3&gt;表格&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;| 表头1 | 表头2 |
| ----- | ----- |
| 内容1 | 内容2 |
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;| 表头1 | 表头2 |
| ----- | ----- |
| 内容1 | 内容2 |&lt;/p&gt;
&lt;h3&gt;注释&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;在引用的地方使用 [^注释] 来添加注释。

然后在文档的结尾，添加注释的内容（会默认于文章结尾渲染之）。

[^注释]: 这里是注释的内容
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;在引用的地方使用 &lt;a href=&quot;%E8%BF%99%E9%87%8C%E6%98%AF%E6%B3%A8%E9%87%8A%E7%9A%84%E5%86%85%E5%AE%B9&quot;&gt;^注释&lt;/a&gt; 来添加注释。&lt;/p&gt;
&lt;p&gt;然后在文档的结尾，添加注释的内容（会默认于文章结尾渲染之）。&lt;/p&gt;
&lt;h3&gt;To-Do 列表&lt;/h3&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;- [ ] 未完成的任务
- [x] 已完成的任务
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;[ ] 未完成的任务&lt;/li&gt;
&lt;li&gt;[x] 已完成的任务&lt;/li&gt;
&lt;/ul&gt;
&lt;h3&gt;符号转义&lt;/h3&gt;
&lt;p&gt;如果你的描述中需要用到 markdown 的符号，比如 _ # * 等，但又不想它被转义，这时候可以在这些符号前加反斜杠，如 &lt;code&gt;\_&lt;/code&gt; &lt;code&gt;\#&lt;/code&gt; &lt;code&gt;\*&lt;/code&gt; 进行避免。&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-markdown&quot;&gt;\_不想这里的文本变斜体\_

\*\*不想这里的文本被加粗\*\*
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;预览：&lt;/p&gt;
&lt;p&gt;_不想这里的文本变斜体_&lt;/p&gt;
&lt;p&gt;**不想这里的文本被加粗**&lt;/p&gt;
&lt;hr&gt;
&lt;h2&gt;内嵌 Astro 组件&lt;/h2&gt;
&lt;p&gt;See &lt;a href=&quot;/docs/integrations/components&quot;&gt;User Components&lt;/a&gt; and &lt;a href=&quot;/docs/integrations/advanced&quot;&gt;Advanced Components&lt;/a&gt; for details.&lt;/p&gt;</content:encoded><h:img src="/@fs/root/my_blog/src/content/blog/markdown-zh/thumbnail.jpg?origWidth=4551&amp;origHeight=1590&amp;origFormat=jpg"/><enclosure url="/@fs/root/my_blog/src/content/blog/markdown-zh/thumbnail.jpg?origWidth=4551&amp;origHeight=1590&amp;origFormat=jpg"/></item></channel></rss>