Large language models (LLMs) have been the biggest stars of generative AI; however, applying them to a specific domain requires some adaptation of the pretrained weights. There are two main ways to do so: 1. finetuning, which directly trains the model on domain-specific datasets so that the model weights absorb the new information; 2. RAG (retrieval-augmented generation), which retrieves external domain knowledge during generation, much like an “open book” exam where the model gains access to more accurate external information. The original RAG paper was published in 2020; it used DPR as the retriever and BART as the LLM for generation. With the rapid development of LLMs, the RAG technique can be adapted to bigger and better models such as Llama and GPT, but the underlying principle remains the same. In this post, I will analyze HuggingFace’s original RAG implementation and address some technical questions that are difficult to find answers to online.
HuggingFace implementation of RAG
RAG is part of HuggingFace’s transformers library. Because it lives inside such a big library, with its multitude of functionalities and abstractions, the source code is messy and it is not easy to pinpoint the key lines within it. Roughly speaking, though, the RAG model implementation can be found and studied here, whereas ways to use it can be found here (training the generator only) and here (training both the retriever and the generator). At a high level, RAG consists of two main components, a retriever and a generator, working together to formulate an answer to a prompt. The retriever’s job is to find relevant documents within some indexed corpus based on the prompt. It can use traditional techniques such as BM25 or more modern, embedding-based techniques such as DPR (Dense Passage Retrieval) or ANCE (Approximate Nearest Neighbor Negative Contrastive Estimation). The generator then combines these documents with the original prompt into a new, enriched prompt, which it uses to generate the final answer. Mathematically, there are two ways to implement this idea:
RAG-sequence model, where we first retrieve the $k$ most relevant documents according to the initial prompt, and then use them during our generation by marginalizing over their relevance to the initial prompt:
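$$p_{\text{RAG-sequence}}(y \mid x) \approx \sum_{z \in \text{top-}k\left(p_\eta(\cdot \mid x)\right)} p_\eta(z \mid x) \prod_{i=1}^{N} p_\theta\left(y_i \mid x, z, y_{1:i-1}\right)$$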
where $x$ is the initial prompt, $y$ is the generated output of length $N$ with tokens $y_i$, $z$ ranges over the $k$ retrieved documents weighted by their relevance scores $p_\eta(z \mid x)$, and $\eta$ and $\theta$ are the retriever and generator parameters, respectively.
RAG-token model, where we retrieve the $k$ most relevant documents for the initial prompt and marginalize over their relevance separately for each new token, so that every token is generated conditioned on the previously generated tokens:
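$$p_{\text{RAG-token}}(y \mid x) \approx \prod_{i=1}^{N} \sum_{z \in \text{top-}k\left(p_\eta(\cdot \mid x)\right)} p_\eta(z \mid x)\, p_\theta\left(y_i \mid x, z, y_{1:i-1}\right)$$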
where $x$ is the initial prompt, $y_i$ is the $i$th generated token, $y_{1:i-1}$ are the previously generated tokens, $z$ ranges over the $k$ retrieved documents according to their relevance scores, and $\eta$ and $\theta$ are the retriever and generator parameters, respectively.
How is RAG implemented in HuggingFace?
The HuggingFace implementation of RAG is summarized in the UML diagram below, where I have colored the core classes in blue:
The RagModel class is where most of the magic happens. Internally, it uses an LLM generator and a RagRetriever, and it can be accessed via one of the two APIs associated with the two models above, RagSequenceForGeneration and RagTokenForGeneration. Taking sequence generation as an example, the pipeline can be established with a few lines of code.
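Something along the lines of the following minimal sketch works (modeled on the official transformers examples; the facebook/rag-sequence-nq checkpoint and the dummy index are my illustrative choices, not necessarily what the original post used):

```python
import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# Tokenizer, retriever (backed by a small dummy index here) and generator
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)

# Encode a question, retrieve supporting passages, and generate an answer
inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```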
Internally, RagSequenceForGeneration utilises a RagModel class with the following structure (I skipped the less important code and added some comments):
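The real forward() is long, so what follows is my heavily abridged paraphrase rather than the verbatim transformers source (argument lists trimmed, output packaging omitted, and some field names approximate), just to show where retrieval, document scoring, and generation happen:

```python
# Heavily abridged paraphrase of RagModel.forward() -- not the verbatim transformers source
import torch
from transformers import RagPreTrainedModel


class RagModel(RagPreTrainedModel):
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None,
                decoder_attention_mask=None, n_docs=None, **kwargs):
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # 1. Encode the question with the DPR question encoder
        question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]

        # 2. Ask the retriever for the top n_docs passages (the index lives on CPU / numpy)
        retriever_outputs = self.retriever(
            input_ids,
            question_hidden_states.cpu().detach().to(torch.float32).numpy(),
            prefix=self.generator.config.prefix,
            n_docs=n_docs,
            return_tensors="pt",
        )

        if self.context_encoder_training:
            # Re-encode the retrieved passages with the model's own ctx_encoder so that
            # gradients can flow back into the document embeddings (retriever training)
            retrieved_doc_embeds = self.ctx_encoder(
                retriever_outputs["tokenized_doc_ids"],            # field names approximate
                attention_mask=retriever_outputs["tokenized_doc_attention_mask"],
            ).pooler_output.view(-1, n_docs, question_hidden_states.shape[1])
        else:
            # Generator-only training: use the (frozen) embeddings from the retriever directly
            retrieved_doc_embeds = retriever_outputs["retrieved_doc_embeds"]

        # 3. Relevance score = inner product between question and document embeddings
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
        ).squeeze(1)

        # 4. Run the encoder-decoder generator on the enriched prompt (retrieved docs + question)
        gen_outputs = self.generator(
            input_ids=retriever_outputs["context_input_ids"],
            attention_mask=retriever_outputs["context_attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        # (the real code wraps these in a RetrievAugLMOutput with .logits, .doc_scores, ...)
        return gen_outputs.logits, doc_scores
```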
First, notice how we can train either both the generator and the retriever, or the generator only. In the former case, retrieved_doc_embeds is recomputed using the RagModel’s own ctx_encoder on the retrieved documents instead of taking the encoded results directly from RagRetriever, thereby allowing backpropagation to reach the retriever’s document embeddings. Second, notice how the generator assumes an encoder-decoder architecture (BART, T5, etc.) and accepts decoder_input_ids; in fact, everything apart from input_ids is optional. During inference, decoder_input_ids comes from the decoder’s own generation, while during training it is usually the target labels shifted one position to the right, as illustrated in the sketch below.
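As a small illustration of that last point, here is a generic sketch mirroring the right-shift helper that BART-style models use internally (this is my own example, not code from the RAG source; the token ids are made up):

```python
import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token and drop the last label token,
    # so the decoder predicts token i from tokens < i (teacher forcing).
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)  # ignore-index -> pad
    return shifted

labels = torch.tensor([[42, 7, 15, 2]])  # target tokens, ending with EOS (2)
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0))
# tensor([[ 0, 42,  7, 15]])  -> used as decoder_input_ids during training
```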
Let’s examine the retrieval process more closely. Typically, the retriever is an instance of RagRetriever, with the following key code:
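Again abridged and paraphrased rather than verbatim (some helper names are approximate), the retrieval path looks roughly like this:

```python
# Abridged paraphrase of RagRetriever -- not the verbatim source
from transformers import BatchEncoding


class RagRetriever:
    def retrieve(self, question_hidden_states, n_docs):
        # Nearest-neighbour search over the (FAISS-backed) document index
        doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

    def __call__(self, question_input_ids, question_hidden_states,
                 prefix=None, n_docs=None, return_tensors=None):
        # 1. Find the top n_docs passages for each question embedding
        retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)

        # 2. Concatenate "doc title + doc text + question", then re-tokenize
        #    with the generator's tokenizer to build the enriched prompt
        input_strings = self.question_encoder_tokenizer.batch_decode(
            question_input_ids, skip_special_tokens=True
        )
        context_input_ids, context_attention_mask = self.postprocess_docs(
            docs, input_strings, prefix, n_docs, return_tensors=return_tensors
        )

        # 3. Return everything RagModel needs downstream
        return BatchEncoding(
            {
                "context_input_ids": context_input_ids,
                "context_attention_mask": context_attention_mask,
                "retrieved_doc_embeds": retrieved_doc_embeds,
                "doc_ids": doc_ids,
            },
            tensor_type=return_tensors,
        )
```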
Finally, let’s look at how the marginalization is done in RAG. We start with RagSequenceForGeneration, where the forward method produces generator outputs for every retrieved document; the marginalization happens during loss calculation:
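Here is an abridged paraphrase of get_nll() in RagSequenceForGeneration (target shifting, padding masks, and label smoothing omitted):

```python
# Abridged paraphrase of RagSequenceForGeneration.get_nll()
import torch
from torch import nn


def get_nll(self, seq_logits, doc_scores, target, n_docs=None):
    # seq_logits: (batch_size * n_docs, target_len, vocab_size)
    # doc_scores: (batch_size, n_docs); target: (batch_size, target_len)
    seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
        seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
    )
    doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)

    # RAG-sequence marginalization: add the document log-probability ONCE per sequence,
    # to the first real token after BOS
    first_token_scores = seq_logprobs[:, :, :1, :]
    second_token_scores = seq_logprobs[:, :, 1:2, :]
    remainder = seq_logprobs[:, :, 2:, :]
    rag_logprobs = torch.cat(
        [first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2
    )

    # Gather the log-probabilities of the target tokens for each retrieved document
    target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
    ll = rag_logprobs.gather(dim=-1, index=target).squeeze(-1)

    ll = ll.sum(2)        # sum over tokens -> log p(y | x, z) per document
    ll = ll.logsumexp(1)  # logsumexp over docs -> marginalize over documents
    return -ll            # negative log-likelihood per example
```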
Notice how, in each forward() (and consequently get_nll()) call, the document relevance (doc_logprobs) is added only once per sequence, to the first real generated token (second_token_scores) after BOS (beginning of sentence), rather than at every token. The marginalization over documents then happens at ll = ll.logsumexp(1) # logsumexp over docs when calculating the per-batch losses. A related discussion can be found here. For RagTokenForGeneration, we have:
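Below is my abridged paraphrase of the corresponding RagTokenForGeneration methods (argument lists trimmed, padding and label smoothing omitted):

```python
# Abridged paraphrase of methods on RagTokenForGeneration -- not the verbatim source
import torch
from torch import nn
from transformers import RagPreTrainedModel


class RagTokenForGeneration(RagPreTrainedModel):
    def marginalize(self, seq_logits, doc_scores, n_docs=None):
        # seq_logits: (batch_size * n_docs, target_len, vocab_size)
        seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
        )
        doc_logprobs = torch.log_softmax(doc_scores, dim=1)
        # Add the document log-probability at EVERY token position, then logsumexp over docs:
        # p(y_i | x) = sum_z p(z | x) * p(y_i | x, z, y_<i)
        log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
        return torch.logsumexp(log_prob_sum, dim=1)

    def forward(self, input_ids=None, labels=None, do_marginalize=None, n_docs=None, **kwargs):
        # self.rag is the underlying RagModel shown earlier
        outputs = self.rag(input_ids=input_ids, n_docs=n_docs, **kwargs)

        loss = None
        if labels is not None:
            loss = self.get_nll(outputs.logits, outputs.doc_scores, labels, n_docs=n_docs)

        logits = outputs.logits
        if do_marginalize:
            # generate() consumes these marginalized logits at every decoding step
            logits = self.marginalize(logits, outputs.doc_scores, n_docs)
        return loss, logits, outputs.doc_scores

    def get_nll(self, seq_logits, doc_scores, target, n_docs=None):
        # Marginalize over documents first, then score the target tokens
        rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
        ll = rag_logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        return -ll.sum(1)  # sum over tokens -> negative log-likelihood per example
```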
Notice how marginalize() is called inside forward(), unlike in RAG-sequence generation where it is absent. This is consistent with the fact that in RAG-token generation, the retrieved documents’ relevance plays a role in determining every next-token probability. Also notice how in get_nll(), marginalize() adds the document relevance (doc_scores) not only to the token after BOS but to all the tokens of the generation, again consistent with RAG-token generation’s idea that each token depends on document relevance.