RAG(检索增强生成)现在基本成了处理长文档的标准流程,但是它问题也很明显:把检索到的所有文本段落统统塞进LLM,这样会导致token数量爆炸,处理速度慢不说,还费钱。
meta提出了一个新的方案REFRAG:与其让LLM处理成千上万个token,不如先用轻量级编码器(比如RoBERTa)把每个固定大小的文本块压缩成单个向量,再投影到LLM的token嵌入空间。
他在论文中说可以提速30倍,我们来看看是怎么做的:
输入序列长度大幅缩短,每个文本块变成一个向量而不是几十个token。计算可以重用,块嵌入在检索阶段就能预计算好,避免重复编码。注意力机制变得更稀疏,LLM现在只需要关注块级别的信息,而不是每个token。
更有意思的是,REFRAG保留了自回归解码的灵活性。它可以在上下文的任意位置进行压缩,通过一个学习策略将压缩嵌入和真实token嵌入混合使用。在实际应用中,基于强化学习的选择策略会挑选少数"重要"块展开为完整token序列,其他块继续保持压缩状态。这和传统RAG形成鲜明对比——后者会把每个检索段落的每个token都完整输入解码器,在无关或冗余文本上浪费大量计算资源。
REFRAG的整体流程非常好理解。检索到的上下文先被拆分成固定大小的块,送入轻量编码器;强化学习策略决定哪些块需要展开为完整token;最后解码器接收查询token和块嵌入的混合输入。
技术架构剖析REFRAG模型结合了decoder-only LLM(LLaMA、GPT这类)和轻量级编码器(RoBERTa)。给定查询(token x1…xq)和检索上下文(xq+1…xT),系统将上下文拆分为L个大小为k的块。编码器处理每个块Ci生成向量ci = M_enc(Ci),线性投影层φ随后将每个ci映射到解码器嵌入空间,产生与词向量同维度的向量ei^(enc)。
解码器输入变为原始查询嵌入{e1, …, eq}加上投影块嵌入集合{e1^enc, …, eL^enc}。效果就是上下文被"压缩"了:s个上下文token变成L=s/k个嵌入。Lin等人的实验表明,解码器输入大约减少k倍。
一个强化学习策略(小型神经网络)接着动态选择部分块,将其完整token嵌入用来替代单向量形式。这种"选择性展开"只把最关键的文本注入解码器,其余部分维持压缩状态。最终结果是内存和注意力中的嵌入数量显著减少,生成速度明显提升。
工作流程解析REFRAG系统的实际运行过程包含六个主要步骤:
检索阶段使用标准检索器(FAISS索引或DRAGON+)从大型语料库获取查询的top-K相关段落,和常规RAG没什么区别。每个段落被切分为固定长度块,通常16-32个。
块编码环节将每个块送入轻量编码器。一般取[CLS]或池化输出作为块向量,这些向量可以预计算或缓存来提升效率。
投影步骤通过可学习线性层将块向量映射到LLM的token嵌入空间(维度d_emb)。投影后每个块用一个d_emb维嵌入表示,解码器可以像处理普通token一样注意这些块向量。
策略感知部分用强化学习训练的策略网络检查所有块嵌入(可能包括查询),选择一个子集(比如25%)进行展开。未选中的块保持单一嵌入形式。策略训练目标是最大化答案质量,使用负的下一token困惑度作为奖励信号。
展开操作针对策略选择的每个块,将其单一嵌入替换为原始token嵌入序列,需要把这些token嵌入传入解码器。
解码生成最后,解码器LLM接收混合序列:查询token + 部分块嵌入 + 展开token。它正常应用自注意和交叉注意,但由于许多块以单一嵌入形式存在,输入序列大大缩短,LLM随后自回归生成答案。
也就是说REFRAG压缩上下文块来减小输入大小,智能选择哪些块真正需要完整token细节,然后展开这少数几个。这套流程在实际应用中效果相当显著:16倍压缩率(k=16)下,REFRAG的首token时间(time-to-first-token)比完整上下文的基线LLaMA快约16.5倍,准确性几乎没有损失。
下面是个简化的PyTorch实现示例,展示如何为文档QA任务构建REFRAG核心逻辑。这里用Hugging Face Transformers来编码上下文块、投影向量、并通过inputs_embeds送入因果LM。这只是概念验证,完整REFRAG系统需要按上述方法训练编码器/投影器和强化学习策略。
import torch from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM # Load models (small examples for illustration) encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") encoder_model = AutoModel.from_pretrained("bert-base-uncased") decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2") decoder_model = AutoModelForCausalLM.from_pretrained("gpt2") # Example retrieved passages (context) for a query context_docs = [ "Albert Einstein developed the theory of relativity in 1905.", "He received the Nobel Prize in 1921 for his services to theoretical physics.", "Later, he introduced the equation E = mc^2 in 1905." ] # 1. Encode each chunk into a vector (CLS token embedding) chunk_embeddings = [] for doc in context_docs: inputs = encoder_tokenizer(doc, return_tensors="pt", truncation=True, max_length=128) with torch.no_grad(): outputs = encoder_model(**inputs) cls_vec = outputs.last_hidden_state[:, 0, :] # [CLS] token embedding chunk_embeddings.append(cls_vec) # shape [1, hidden_size] # 2. Project chunk embeddings to decoder embedding size proj = torch.nn.Linear(chunk_embeddings[0].size(-1), decoder_model.config.n_embd) projected_chunks = [proj(vec) for vec in chunk_embeddings] # list of [1, n_embd] # 3. Prepare query tokens and embeddings query = "Who won the Nobel Prize for physics in 1921?" q_inputs = decoder_tokenizer(query, return_tensors="pt") with torch.no_grad(): q_embeds = decoder_model.transformer.wte(q_inputs["input_ids"]) # [1, len_q, n_embd] # 4. Combine query embeddings and projected chunk embeddings # (Here we pretend *no* RL expansion: use all chunks as embeddings.) combined_embeds = torch.cat([q_embeds] + projected_chunks, dim=1) # [1, total_len, n_embd] # 5. Generate answer with the decoder using inputs_embeds generated = decoder_model.generate(inputs_embeds=combined_embeds, max_length=50) print(decoder_tokenizer.decode(generated[0], skip_special_tokens=True))
代码逻辑很直接:用BERT编码器压缩每个段落,线性层映射到GPT-2嵌入维度,调用generate(..., inputs_embeds=...)让GPT-2处理这些向量。LLM把这些向量当作特殊"伪token"处理。在完整REFRAG设置中,你需要用强化学习策略选择的实际token嵌入来替换或增强projected_chunks。
核心组件详解检索器部分通常使用稠密检索器或向量索引(FAISS)从语料库获取K个相关段落。REFRAG不改变这个步骤,检索文本的方式和RAG一样。Lin等人在Wikipedia/CommonCrawl(4亿段落)上使用DRAGON+稠密检索器。
编码器采用轻量级架构(RoBERTa)处理文本块,输出固定大小向量(通常用[CLS] token或均值池化)作为块嵌入。持续预训练(CPT)期间,编码器会和投影层一起训练,学习压缩信息并最小化损失。
投影层是可学习线性层,将块嵌入转换为LLM的token嵌入空间,确保维度匹配。这让解码器注意块向量时就像处理普通词向量。
解码器(生成器)使用标准decoder-only LLM(LLaMA、GPT-2)自回归生成答案。REFRAG中解码器词汇表和结构不变,只是看到更短的输入:查询token加上块向量(以及展开的token)。
选择性压缩策略通过小型策略网络(MLP)决定哪些块保持单一嵌入,哪些展开为完整token。策略用REINFORCE训练:当展开某些块能降低困惑度(提高下一token准确性)时给予奖励。随时间推移,策略学会将"易压缩"上下文保持嵌入形式,只展开关键部分。
持续预训练(CPT)在部署前用专门的预训练方案对齐编码器与解码器。首先学习从嵌入重建文本:冻结解码器,编码块并训练投影让解码器能恢复原始token。然后通过课程学习逐步增加难度(从1个块到2个,再到更多)。这确保编码器真正捕获块含义。最后对编码器/解码器端到端微调(RAG QA、对话等任务),同时训练强化学习策略。
解码机制在标准RAG中,每个输出token都会注意所有查询和段落token。REFRAG解码时,解码器注意查询token + L个块嵌入 + 展开token。由于许多块以单一向量存在,注意力工作量大幅减少。Lin等人观察到RAG上下文呈现"块对角"稀疏性(不同块间交互少),REFRAG利用这点跳过大部分冗余token注意力。
LangChain集成思路LangChain已经有Embeddings、VectorStore、LLM、Chains等抽象,REFRAG需要在此基础上添加几个新组件:
文档块化功能,将检索文档拆分为固定大小块。编码器+投影器模块,压缩每个块为单个嵌入并投影到LLM嵌入空间。策略网络,决定哪些块展开(传入完整token)或保持压缩。LLM的混合上下文输入处理,包含查询+压缩块嵌入+展开完整token。编码器、投影器、策略的训练流程,让压缩+展开在生成时表现良好。
LangChain默认不直接支持在LLM上下文中混合任意嵌入与完整token嵌入(特别是"可展开"策略)。但HF Transformers支持inputs_embeds,可以构建这个功能。
需要构建的组件包括:ChunkEncoder接收文本块返回嵌入,Projector线性层映射块嵌入空间到LLM token嵌入空间,PolicyNetwork给定查询嵌入与块嵌入输出每个块展开分数或选择掩码,Retriever与VectorStore获取候选文档并可选拆分为块,自定义链/LLM包装器构造混合输入、处理展开、然后用inputs_embeds调用LLM。
from typing import List, Tuple import torch import torch.nn as nn from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM from langchain.embeddings import Embeddings from langchain.vectorstores import FAISS from langchain.schema import Document from langchain.llms.base import LLM # 1. ChunkEncoder + Projector + Policy Network class ChunkEncoder(nn.Module): def __init__(self, encoder_model_name: str, chunk_size: int): super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(encoder_model_name) self.encoder = AutoModel.from_pretrained(encoder_model_name) self.chunk_size = chunk_size def chunkify(self, text: str) -> List[str]: # simple split by whitespace / fixed tokens, you could use tokenizer toks = self.tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0] chunks = [] for i in range(0, toks.size(0), self.chunk_size): chunk_ids = toks[i : i + self.chunk_size] chunk_text = self.tokenizer.decode(chunk_ids, skip_special_tokens=True) chunks.append(chunk_text) return chunks def forward(self, chunk_texts: List[str]) -> torch.Tensor: # returns embeddings of shape (num_chunks, encoder_hidden_size) encodings = self.tokenizer(chunk_texts, padding=True, truncation=True, return_tensors="pt") outputs = self.encoder(**encodings) # e.g. use .pooler_output or CLS token # if model has pooler: if hasattr(outputs, "pooler_output"): return outputs.pooler_output # (batch, hidden_size) else: # fallback: mean pooling last = outputs.last_hidden_state # (batch, seq, hidden) return last.mean(dim=1) class Projector(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() self.linear = nn.Linear(input_dim, output_dim) def forward(self, chunk_embs: torch.Tensor) -> torch.Tensor: return self.linear(chunk_embs) # maps into LLM emb dim class PolicyNetwork(nn.Module): def __init__(self, emb_dim: int, hidden_size: int = 256): super().__init__() self.net = nn.Sequential( nn.Linear(emb_dim * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1) ) def forward(self, query_emb: torch.Tensor, chunk_embs: torch.Tensor) -> torch.Tensor: # query_emb: (emb_dim,), chunk_embs: (num_chunks, emb_dim) # produce a score per chunk q = query_emb.unsqueeze(0).expand(chunk_embs.size(0), -1) # (num_chunks, emb_dim) inp = torch.cat([q, chunk_embs], dim=1) scores = self.net(inp).squeeze(-1) # (num_chunks,) return scores # 2. Using LangChain components class RefragChain: def __init__( self, retriever, # a LangChain retriever chunk_encoder: ChunkEncoder, projector: Projector, policy: PolicyNetwork, llm_model_name: str, llm_tokenizer_name: str, expand_ratio: float = 0.25, # fraction of chunks to expand max_new_tokens: int = 128 ): self.retriever = retriever self.chunk_encoder = chunk_encoder self.projector = projector self.policy = policy self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_name) self.llm = AutoModelForCausalLM.from_pretrained(llm_model_name) self.expand_ratio = expand_ratio self.max_new_tokens = max_new_tokens def answer(self, query: str) -> str: # 1. Retrieve documents docs: List[Document] = self.retriever.get_relevant_documents(query) # 2. Split into chunks chunk_texts = [] chunk_doc_map = [] # to know which chunk came from which doc for d in docs: chunks = self.chunk_encoder.chunkify(d.page_content) for c in chunks: chunk_texts.append(c) chunk_doc_map.append(d) # 3. Encode chunks & project chunk_embs = self.chunk_encoder(chunk_texts) # (C, enc_dim) projected = self.projector(chunk_embs) # (C, llm_emb_dim) # 4. Compute query embedding (using chunk_encoder as proxy or separate) # optionally, use same encoder with torch.no_grad(): q_enc = self.chunk_encoder([query]) # (1, enc_dim) q_proj = self.projector(q_enc).squeeze(0) # (llm_emb_dim,) # 5. Policy: select top-k chunks to expand scores = self.policy(q_proj, chunk_embs) # (C,) k = max(1, int(self.expand_ratio * len(chunk_texts))) topk_idx = torch.topk(scores, k).indices.tolist() # 6. Prepare LLM input embeddings # a) tokenize query q_tok = self.llm_tokenizer(query, return_tensors="pt", truncation=True) q_tok_ids = q_tok["input_ids"] q_embeds = self.llm.get_input_embeddings()(q_tok_ids) # (1, q_len, llm_emb_dim) # b) For each chunk: if in topk, tokenize fully, else use projected embedding chunk_input_embeds_list = [] for i, c_text in enumerate(chunk_texts): if i in topk_idx: # expand fully tok = self.llm_tokenizer(c_text, return_tensors="pt", truncation=True) emb = self.llm.get_input_embeddings()(tok["input_ids"]) # (1, chunk_len, emb_dim) else: # compressed: treat projection as one "special token embedding" emb = projected[i].unsqueeze(0).unsqueeze(1) # (1,1, emb_dim) chunk_input_embeds_list.append(emb) # concatenate embeddings: query + all chunk embeddings/expanded all_chunk_embeds = torch.cat(chunk_input_embeds_list, dim=1) # e.g. (1, total_chunkified_length, emb_dim) full_input_embeds = torch.cat([q_embeds, all_chunk_embeds], dim=1) # 7. Generate out = self.llm.generate( inputs_embeds=full_input_embeds, max_new_tokens=self.max_new_tokens ) answer = self.llm_tokenizer.decode(out[0], skip_special_tokens=True) return answer
from langchain.llms.base import LLM from langchain.schema import LLMResult class RefragLLM(LLM): def __init__(self, refrag_chain: RefragChain): self.refrag_chain = refrag_chain def _call(self, prompt: str, stop: List[str] = None) -> str: return self.refrag_chain.answer(prompt) @property def _identifying_params(self): return {"refrag": True}
实际部署时需要注意几个关键点:训练/加载预训练的ChunkEncoder + Projector + Policy,示例中用的是随机或未训练组件,REFRAG需要预训练让压缩嵌入保留相关信息。对齐嵌入空间,确保投影后的嵌入在LLM注意时表现良好。设计RL奖励信号来训练策略,衡量展开特定块如何提高生成质量或降低困惑度。调整学习和压缩率,平衡k(块大小)和展开比例,理解性能权衡。
性能评估REFRAG在多个长上下文任务上进行了全面评估,涵盖开放域QA、多选推理、对话和摘要等场景。
数据集覆盖相当广泛,包括RAG基准(NaturalQuestions、FEVER、TQA等)、常识推理(HellaSwag、Winogrande等)和对话/摘要任务。实验同时模拟强检索场景(只有真正top-K相关段落)和弱检索场景(大量候选中包含不相关段落)。
基线对比涵盖具有完整上下文或截断到匹配token数的LLaMA-2,以及之前的长上下文方法如CEPE和REPLUG。RAG QA任务上,研究者对LLaMA进行微调确保公平比较。
评估指标从两个维度衡量:推理速度用首token时间(TTFT)和每迭代token时间(TTIT),以及整体吞吐量(token/秒);准确性通过验证集困惑度和QA任务准确率(精确匹配/F1等)来评估。
实验结果相当令人印象深刻。REFRAG在大多数情况下实现巨大加速且准确性无损。在超长上下文的16倍压缩(k=16)下,REFRAG的TTFT比LLaMA快约16.5倍。k=32时TTFT达到约32.9倍LLaMA(≈30.85倍报告值),与论文声称的30.85倍加速基本吻合。困惑度和下游准确性基本保持不变。
在检索质量较弱的任务中,REFRAG甚至超越LLaMA性能,因为相同延迟预算下它能包含更多上下文。论文表3和图4显示REFRAG在16个RAG任务的强/弱设置中都能匹配或超越LLaMA。消融实验还表明,REFRAG的强化学习驱动选择性压缩优于简单启发式方法(比如丢弃低困惑度块)。
总的来说,REFRAG在大上下文场景下实现了超过30倍的生成加速且无准确性下降,有效地在延迟受限情况下将LLM上下文扩展约16倍。
局限性与改进空间REFRAG最大的问题就是训练开销。因为他需要额外预训练和微调,编码器、投影层、策略都要单独训练(通过重建和课程学习任务),复杂度远超即插即用的RAG。
并且系统复杂性明显增加。新增编码器、投影层、策略网络等组件,依赖强化学习进行块选择,工程工作量大且可能需要精细调优(策略学习率等)。参考实现涉及很多超参数(块大小k、展开比例p、学习率等)。
虽然REFRAG在检索不完美时有帮助,但依然需要检索到相对相关的块。如果检索完全失效,单纯压缩无法弥补缺失的知识,也就是说他还是要以来第一步检索的准确性。
最后就是极高压缩率可能导致信息丢失,强化学习策略虽然能通过展开关键部分来缓解,但速度与保真度间仍需要权衡。实验显示REFRAG在16-32倍压缩范围内保持准确性,超出这个范围性能会下降(论文图10显示过高压缩导致更大损失)。
总结REFRAG作为延迟敏感、知识密集任务的专门解决方案,用额外的模型训练和系统复杂性换取显著加速,这个trade-off在很多场景下是值得的。如论文所说,"REFRAG…无需修改LLM架构",意味着一旦编码器/投影对齐,就能与任何解码器模型配合使用。
从长远看,随着编码器预训练技术和强化学习策略的持续改进,REFRAG的训练复杂度可能会降低,而性能收益会进一步提升。这种思路也为其他需要处理长序列的AI应用提供了参考。
论文:
https://avoid.overfit.cn/post/2675cf1a065745c9bc44d755d5141a0d
作者:DhanushKumar