The LLM Journey: From Next-Token Prediction to RLHF/DPO
In this article, we follow the journey of an LLM from pre-training through supervised finetuning (SFT) and RLHF to, finally, DPO. We focus mostly on the phases that come after pre-training.
Here are a few questions this article tries to answer:
- What are LLMs (basic architecture)?
- How does next-token prediction help an LLM learn language and world knowledge?
- What is supervised finetuning, and why is it necessary?
- Why do LLMs hallucinate, and how is supervised finetuning related to hallucination?
- How does RLHF work, and how does it help an LLM align with human values?
- What is DPO?
What are LLMs?
LLMs are generative models that, given a prompt, generate a continuation of it one token at a time. They are typically transformer-based deep learning models, and today most of them use only the decoder part of the transformer architecture. At each step, the model takes all the preceding text as input and generates one token, which is appended to the text for the next step.
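The token-by-token loop can be sketched in a few lines. This is a minimal illustration, not a real LLM: the "model" is a hypothetical hard-coded table (`TOY_TABLE`) that only looks at the last token, whereas a transformer decoder conditions on the whole context. Greedy decoding (always pick the most probable token) is used purely for simplicity.

```python
from typing import Callable, Dict, List

# A "model" here is any function mapping the tokens so far to a
# probability distribution over the next token.
NextTokenFn = Callable[[List[str]], Dict[str, float]]

def generate(model: NextTokenFn, prompt: List[str],
             max_new_tokens: int, eos: str = "<eos>") -> List[str]:
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        dist = model(tokens)               # P(next token | preceding tokens)
        nxt = max(dist, key=dist.get)      # greedy choice: most probable token
        tokens.append(nxt)                 # feed the new token back as context
        if nxt == eos:                     # stop at end-of-sequence
            break
    return tokens

# Hypothetical toy model: a fixed table keyed on the last token only.
TOY_TABLE = {
    "the": {"capital": 0.6, "city": 0.4},
    "capital": {"of": 0.9, "<eos>": 0.1},
    "of": {"India": 0.7, "France": 0.3},
    "India": {"<eos>": 1.0},
}

def toy_model(tokens: List[str]) -> Dict[str, float]:
    return TOY_TABLE.get(tokens[-1], {"<eos>": 1.0})

print(generate(toy_model, ["the"], max_new_tokens=10))
# ['the', 'capital', 'of', 'India', '<eos>']
```

Swapping `toy_model` for a real network changes what produces the distribution, but the outer loop (predict, pick, append, repeat) stays the same.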
What is a language model (LM)?
In NLP, a language model aims to predict the next word given a context. In the pre-deep-learning era, we used statistical models such as n-gram models and HMMs. With deep learning, we pass the context to the model and get a probability distribution over the vocabulary, from which a decoding strategy (for example, greedy decoding or sampling) selects the next word.
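As a concrete example of a statistical language model, here is a sketch of a count-based bigram model, which estimates P(next word | previous word) from a tiny made-up corpus, followed by two decoding strategies: greedy and temperature sampling. The corpus and function names are illustrative assumptions; a neural LM replaces the counting with a network but still outputs a distribution that a decoding strategy picks from.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, List, Optional

def train_bigram(corpus: List[str]) -> Dict[str, Counter]:
    # Count how often each word follows each previous word.
    counts: Dict[str, Counter] = defaultdict(Counter)
    for sentence in corpus:
        words = ["<s>"] + sentence.split() + ["</s>"]
        for prev, nxt in zip(words, words[1:]):
            counts[prev][nxt] += 1
    return counts

def next_word_distribution(counts: Dict[str, Counter], prev: str) -> Dict[str, float]:
    # Maximum-likelihood estimate of P(next word | previous word).
    c = counts.get(prev, Counter())
    total = sum(c.values())
    return {w: n / total for w, n in c.items()}

def greedy(dist: Dict[str, float]) -> str:
    # Decoding strategy 1: always take the most probable word.
    return max(dist, key=dist.get)

def sample(dist: Dict[str, float], temperature: float = 1.0,
           rng: Optional[random.Random] = None) -> str:
    # Decoding strategy 2: sample after re-weighting each p to p ** (1/T);
    # T < 1 sharpens the distribution, T > 1 flattens it.
    # random.choices normalises the weights itself.
    rng = rng if rng is not None else random.Random()
    words = list(dist)
    weights = [dist[w] ** (1.0 / temperature) for w in words]
    return rng.choices(words, weights=weights, k=1)[0]

corpus = ["the cat sat", "the cat ran", "the dog sat"]
counts = train_bigram(corpus)
dist = next_word_distribution(counts, "the")
print(dist)            # P(cat | the) = 2/3, P(dog | the) = 1/3
print(greedy(dist))    # cat
print(sample(dist, temperature=0.7, rng=random.Random(0)))
```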
After the transformer architecture was introduced in 2017, these deep learning models began to scale to billions of parameters, giving rise to what we now call large language models.
How are LLMs trained?
An LLM is first trained on a large corpus of unlabeled text (on the order of trillions of tokens), and nowadays this pre-training data often also includes synthetic text.
In this phase, the LLM is trained on next-token prediction (a token can be a character, a word, or a subword, depending on the tokenizer algorithm). Given a context, the LLM predicts the next token. Sounds simple? It isn't; try predicting it yourself :). To predict the next token well, the model has to learn the syntax and semantics of the language as well as knowledge about the world.
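To make "token" and "given a context, predict the next token" concrete, here is a small sketch. The character and word splits are exact; the subword split is hand-picked for illustration (real LLMs learn their subword vocabularies, e.g. with BPE). The helper `next_token_pairs` is a hypothetical name showing how one sequence yields many (context, target) training examples.

```python
from typing import List, Tuple

text = "the cat sat"

# Three tokenization granularities.
char_tokens = list(text)
word_tokens = text.split()
subword_tokens = ["the", " c", "at", " s", "at"]   # hypothetical, not a learned split

def next_token_pairs(tokens: List[str]) -> List[Tuple[List[str], str]]:
    # Every prefix of the sequence is a context; the token after it is the target.
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

for context, target in next_token_pairs(word_tokens):
    print(context, "->", target)
# ['the'] -> cat
# ['the', 'cat'] -> sat
```

A single sentence therefore supplies one training signal per token, which is part of why pre-training on trillions of tokens works without any labels.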
Now, let’s look at this mathematically.
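In the usual notation (a standard formulation, not necessarily the author’s exact one), for a training sequence of tokens x_1, ..., x_T and a model with parameters θ, pre-training minimizes the average negative log-likelihood of each token given the tokens before it. This is the cross-entropy between the true next token and the model’s predicted distribution p_θ:

$$
\mathcal{L}_{\text{pretrain}}(\theta) = -\frac{1}{T}\sum_{t=1}^{T} \log p_\theta\left(x_t \mid x_{<t}\right)
$$

The loss is low only when the model assigns high probability to the correct continuation, for example to “New Delhi” after “The capital of India is”, which is why minimizing it pushes the model to absorb both grammar and facts.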
Why can’t we use a pre-trained LLM as a chatbot?
In phase 1, we trained the model only for next-token prediction, so given a context, the LLM simply tries to continue it. The model doesn’t know how to answer user questions. For example, if the user asks, “What is the capital of India?”, the model may respond with “What is the capital of India? Answer this question in …”, or it may actually answer. The LLM behaves this way because it has seen both kinds of text during training. The model is good at continuing a given prefix but poor at holding a conversation.
Think of a powerful animal that hasn’t learned to follow its master’s orders. So how do we get a ChatGPT-like model? This is where SFT comes to the rescue.
What is Supervised Finetuning (SFT)?
In the SFT phase, we teach the model to answer user queries. How?
By using a demonstration dataset. For this phase, we collect high-quality data points in the format {question, answer}, where each data point pairs a question or task with its answer. A few examples of data points:
Example 1: “What is the capital of India? New Delhi is the capital of India.”
Example 2: “Summarize the text below: {passage} {summary}”
An SFT dataset also covers many other types of NLP tasks: classification, summarisation, open QA, chat, rephrasing, and more.
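To show what such data points look like when fed to the model, here is a sketch that turns {question, answer} records into a single training string using a hypothetical prompt template (`### Question:` / `### Answer:`). Real chat models each define their own template, so these markers are an assumption. The function also returns where the answer starts, which matters because the SFT loss is computed only over the response.

```python
from typing import Dict, List, Tuple

# Hypothetical template; real chat models each define their own format.
TEMPLATE = "### Question:\n{question}\n### Answer:\n"

def format_example(record: Dict[str, str]) -> Tuple[str, int]:
    """Return the full training text and the character offset where the answer begins."""
    prompt = TEMPLATE.format(question=record["question"])
    return prompt + record["answer"], len(prompt)

sft_data: List[Dict[str, str]] = [
    {"question": "What is the capital of India?",
     "answer": "New Delhi is the capital of India."},
    {"question": "Summarize the text below: {passage}",
     "answer": "{summary}"},
]

for record in sft_data:
    text, answer_start = format_example(record)
    print(repr(text[answer_start:]))
# 'New Delhi is the capital of India.'
# '{summary}'
```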
The dataset for this phase is much smaller than the pre-training corpus. Data quality is crucial here, so we use human-written data or synthetic data generated by large LLMs (GPT-4, Claude, etc.).
Training works the same way as in phase 1: the model still predicts the next token and uses cross-entropy as the loss function. The difference is that the loss is computed only over the response part; the prompt tokens are masked out.
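A minimal sketch of computing cross-entropy only over the response, in plain Python. We assume we already have, for every position, the probability the model assigned to the actual token there (the hypothetical `target_probs`, with made-up numbers); prompt positions get mask 0 and contribute nothing.

```python
import math
from typing import List

def build_loss_mask(prompt_len: int, response_len: int) -> List[int]:
    # 0 = prompt token (ignored by the loss), 1 = response token (trained on)
    return [0] * prompt_len + [1] * response_len

def masked_next_token_loss(target_probs: List[float], mask: List[int]) -> float:
    # target_probs[t]: probability the model gave to the true token at position t.
    # Average negative log-likelihood over response positions only.
    nll = sum(-math.log(p) for p, m in zip(target_probs, mask) if m)
    return nll / sum(mask)

mask = build_loss_mask(prompt_len=3, response_len=2)
probs = [0.1, 0.2, 0.3, 0.5, 0.25]   # made-up numbers for illustration
print(round(masked_next_token_loss(probs, mask), 4))   # 1.0397
```

Changing the first three (prompt) probabilities leaves the loss unchanged, so the model is never trained to reproduce the user’s question. In PyTorch, the same effect is commonly achieved by setting the labels of prompt positions to -100, the default `ignore_index` of `torch.nn.CrossEntropyLoss`.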
This way, we teach the model to answer user queries using the knowledge it acquired in phase 1. Can we now use the model as a chatbot? Yes. But can we deploy it in a real-world application? Perhaps not, because the model still has a few issues, which we discuss in the next section.
Why does the model hallucinate after the SFT phase?
What is a hallucination? It is when a model answers a question with a response that sounds plausible but is false or nonsensical. Why is this a big issue? When we deploy LLMs in critical use cases, hallucinations can be dangerous. For example, a health LLM might give a user a wrong, hallucinated health recommendation.
But why does an LLM hallucinate after the SFT phase?
In SFT, we present the model with a question and its correct answer, and the model learns to reproduce that answer. There are two scenarios:
Scenario 1: The model already has the knowledge needed for the answer. The example helps the model connect the question to that knowledge and shows it how to follow instructions. This is completely fine; it is exactly our training objective.
Scenario 2: The model has no knowledge of the answer, yet we push it to reproduce the answer anyway. This is a kind of cheating, and it leads to hallucinations: we are teaching the model to answer a question even when it knows nothing about the topic (in effect, teaching it to make things up).
So whenever we finetune a model after pre-training, whether for SFT or domain adaptation, we may expose it to new knowledge, and there is always a risk of encouraging hallucinations.
In the paper “Does Finetuning LLMs on New Knowledge Encourage Hallucinations?”, the authors note that when large language models are aligned via supervised finetuning, they may encounter new factual information that was not acquired during pre-training. It is often conjectured that this can teach the model to hallucinate factually incorrect responses, since the model is trained to generate facts that are not grounded in its pre-existing knowledge.
Apart from hallucination, a significant limitation of SFT is that the demonstration dataset tells the model what to do, but not what not to do. Is this a significant issue? Yes: when we deploy an LLM in the real world, we want it to reflect certain human values. For example, we don’t want it to give racist, abusive, or biased answers, and that is hard to teach via SFT. So can we teach the model to answer according to human preferences?
RLHF (Reinforcement Learning from Human Feedback)
This phase has two parts: first, we train a reward model, and then we align the LLM with the help of that reward model.
Reward Model:
We want a model that, given a prompt and an LLM response, returns a score indicating how well that response matches our defined preferences. We can then use this score to train the LLM within an RL framework.
To train such a model, we need training data of the form {(prompt, response), score}, which in NLP is a standard classification or regression setup. The data is collected from human annotators.
However, assigning an integer or float score to a (prompt, response) pair is a very subjective task, and scores vary widely from one annotator to another. To overcome this, instead of asking annotators for a score, we give them (prompt, response1, response2) and ask them to select the better response, which we mark as the winner.
So the training dataset now looks like this: (prompt, winner_response, loser_response).
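The pairwise format extends naturally when annotators rank more than two responses: the InstructGPT paper had labelers rank several responses to the same prompt and used every (better, worse) pair from each ranking, so K ranked responses give K(K-1)/2 pairs. A sketch of that conversion; the tuple layout and function name are our own assumptions.

```python
from itertools import combinations
from typing import List, Tuple

def ranking_to_pairs(prompt: str, ranked: List[str]) -> List[Tuple[str, str, str]]:
    # `ranked` is ordered best-first; combinations() keeps that order,
    # so the first response in every pair is the winner.
    return [(prompt, winner, loser) for winner, loser in combinations(ranked, 2)]

pairs = ranking_to_pairs("What is the capital of India?",
                         ["New Delhi.", "Delhi, I think.", "Mumbai."])
for _, winner, loser in pairs:
    print(winner, ">", loser)
# New Delhi. > Delhi, I think.
# New Delhi. > Mumbai.
# Delhi, I think. > Mumbai.
```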
For the reward model, we can either start from the SFT model of the previous stage or train a separate model. In InstructGPT, the reward model is trained to maximize the difference between the reward scores of the winning and losing responses.
Now, let’s see mathematically how this model is trained.
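In the InstructGPT paper, the reward model r_θ(x, y), which outputs a scalar score for prompt x and response y, is trained with the pairwise loss below. Here y_w and y_l are the winning and losing responses, σ is the sigmoid function, D is the comparison dataset, and K is the number of responses ranked per prompt:

$$
\mathcal{L}_{\text{RM}}(\theta) = -\frac{1}{\binom{K}{2}}\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim D}\left[\log \sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)\right]
$$

Minimizing this loss pushes r_θ(x, y_w) above r_θ(x, y_l). Only the difference between the two scores enters the loss, so the model learns a relative ranking rather than an absolute scale.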
Once trained, the reward model takes a (prompt, response) pair and returns a score, which we use to train the LLM with RLHF.
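A small numeric sketch of the pairwise reward-model loss, -log σ(r_w - r_l), for a single (prompt, winner, loser) triple. The scores are made-up numbers standing in for the reward model’s outputs.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def pairwise_rm_loss(r_winner: float, r_loser: float) -> float:
    # -log sigmoid(r_w - r_l): small when the winner scores well above the loser.
    return -math.log(sigmoid(r_winner - r_loser))

print(round(pairwise_rm_loss(2.0, 2.0), 4))   # 0.6931 (equal scores: log 2)
print(round(pairwise_rm_loss(3.0, 1.0), 4))   # 0.1269 (winner ranked higher)
print(round(pairwise_rm_loss(1.0, 3.0), 4))   # 2.1269 (loser ranked higher)
```

Shifting both scores by the same constant leaves the loss unchanged, which is why the reward is only meaningful relative to other responses.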
RLHF training
In this phase, we use reinforcement learning to finetune our model. Most implementations use the PPO (Proximal Policy Optimization) algorithm, which OpenAI introduced in 2017.
In each step, we randomly sample a prompt (user query) from the prompt distribution, generate a response to it with the model, and pass the (prompt, response) pair to the reward model to get a score. We then use this reward to update the model’s weights.
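The sample, generate, score, update loop can be sketched with a toy bandit. Everything here is a hypothetical stand-in: the "policy" picks one of a few canned responses via softmax over learnable logits (a real LLM generates token by token), `toy_reward` replaces a trained reward model, and the update is plain REINFORCE, a simpler policy-gradient method swapped in for illustration, not PPO (which adds a value function and a clipped objective).

```python
import math
import random
from typing import Dict, List

PROMPTS = ["What is the capital of India?"]
CANDIDATES = ["New Delhi.", "Mumbai.", "I don't know."]

def toy_reward(prompt: str, response: str) -> float:
    # Hypothetical stand-in for a trained reward model.
    return {"New Delhi.": 1.0, "Mumbai.": -1.0, "I don't know.": 0.0}[response]

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def rlhf_step(logits: Dict[str, List[float]], rng: random.Random, lr: float = 0.5) -> None:
    prompt = rng.choice(PROMPTS)                                  # 1. sample a prompt
    probs = softmax(logits[prompt])
    i = rng.choices(range(len(CANDIDATES)), weights=probs, k=1)[0]  # 2. generate
    r = toy_reward(prompt, CANDIDATES[i])                         # 3. score it
    # 4. REINFORCE update: d log softmax_i / d logit_j = 1[j == i] - p_j
    for j in range(len(CANDIDATES)):
        logits[prompt][j] += lr * r * ((1.0 if j == i else 0.0) - probs[j])

rng = random.Random(0)
logits = {p: [0.0] * len(CANDIDATES) for p in PROMPTS}
for _ in range(200):
    rlhf_step(logits, rng)
print(softmax(logits[PROMPTS[0]]))   # most mass should now be on "New Delhi."
```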
Before going into the math, let’s discuss one more topic: reward hacking. In RL, reward hacking refers to a scenario where the agent finds ways to achieve high rewards through unintended or undesirable behaviours rather than by solving the task as intended. For more details, see https://openai.com/index/faulty-reward-functions/.
In our case, the LLM policy might discover particular tokens that make no sense but still earn a high score from the reward model.
Now, let’s go into more detail.
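A deliberately flawed toy reward illustrates the idea. Suppose, hypothetically, the reward model had learned to favour polite phrases and longer answers without checking correctness. A degenerate response that just repeats a polite phrase then outscores a correct, concise answer, and an RL policy optimizing this reward would drift toward it.

```python
from typing import List

def flawed_reward(response: str) -> float:
    # Hypothetical reward model with a blind spot: it rewards polite phrases
    # and length, and never checks whether the answer is correct.
    text = response.lower()
    polite = sum(text.count(p) for p in ("thank", "happy to help"))
    return 1.0 * polite + 0.01 * len(response.split())

candidates: List[str] = [
    "New Delhi is the capital of India.",
    "Thank you! Thank you! Thank you! Happy to help! Thank you!",
]
for c in candidates:
    print(round(flawed_reward(c), 2), c)

best = max(candidates, key=flawed_reward)
print("policy would prefer:", best)
```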
- Action space: the vocabulary of tokens (taking an action means generating a token)
- Observation space: the distribution of prompts
- Policy: the probability distribution over all vocabulary tokens given an observation (a prompt); the LLM itself is the policy
- Reward: The reward score we get from the reward model.
Now, let’s discuss the objective function. We select a prompt from the collection and pass it to LLM_T (the model whose parameters we are training) to get a response y_T. We also pass the same prompt and response through LLM_ref (a frozen copy of the SFT model) to get the reference model’s probabilities for y_T.
The first term of the objective is the reward: we pass (prompt, y_T) to the reward model and get the reward score, which we want to maximize.
The second term is a KL-divergence penalty between the probability distributions of the reference LLM and the training LLM. Why do we need it? Because we don’t want the RLHF model to drift too far from the SFT model. Recall reward hacking: if the objective contained only the reward term, the model might focus solely on getting a high reward and generate tokens that score well but are useless. We want a model that keeps its good knowledge of language and the world while also earning a high reward score (i.e., aligning with human preferences).
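A minimal sketch of how the two terms are combined, following the InstructGPT-style objective r(x, y) - β · log(π_T(y|x) / π_ref(y|x)), where β controls how strongly the policy is pulled back toward the reference model. The per-token log-probabilities are made-up numbers (in practice they come from running both models on the same sampled response), and β = 0.1 is an arbitrary illustrative value.

```python
from typing import List

def kl_penalized_reward(reward: float,
                        logprobs_train: List[float],
                        logprobs_ref: List[float],
                        beta: float = 0.1) -> float:
    # log pi_T(y|x) - log pi_ref(y|x), summed over the tokens of the sampled
    # response y: a single-sample estimate of the KL divergence.
    kl_estimate = sum(lt - lr for lt, lr in zip(logprobs_train, logprobs_ref))
    return reward - beta * kl_estimate

# The trained policy has become much more confident than the reference model
# on these tokens, so part of the reward is taken away.
print(kl_penalized_reward(2.0, [-0.1, -0.2, -0.1], [-1.0, -1.5, -2.0]))  # about 1.59
```

If the policy assigns the same probabilities as the reference model, the penalty is zero and the full reward is kept; the further it drifts, the more reward it must earn to compensate.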
But how does RLHF help the model align with human values (or any other desired values)? Through the reward model: we trained it to give high scores to responses that follow our guidelines and human values (for example, responses free of abuse and bias).
Issues with RLHF
RLHF has two major issues:
- We need to train a separate reward model, which takes human effort to build its training dataset and extra compute to train it.
- Training an RL model to maximize this estimated reward without drifting too far from the original model is complex and unstable.
DPO (Direct Preference Optimization)
- DPO is an alternative to RLHF that also aligns language models with human preferences, but in a different way: it needs neither a separately trained reward model nor an RL training loop.
- To train a model with DPO, we need preference data: each training example is (prompt, chosen answer, rejected answer), just like the reward model’s training dataset.
DPO doesn’t use RL; it optimizes the model directly on the preference data with a simple classification-style loss.
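A minimal per-example sketch of the DPO loss. For a prompt x with chosen answer y_w and rejected answer y_l, it compares how much the trained policy π_θ raises the log-probability of each answer relative to a frozen reference model π_ref (typically the SFT model): loss = -log σ(β[(log π_θ(y_w|x) - log π_ref(y_w|x)) - (log π_θ(y_l|x) - log π_ref(y_l|x))]). The log-probabilities are made-up numbers and β = 0.1 is an illustrative choice.

```python
import math

def dpo_loss(logp_chosen: float, logp_rejected: float,
             ref_logp_chosen: float, ref_logp_rejected: float,
             beta: float = 0.1) -> float:
    # Implicit rewards: how much more likely the policy makes each answer
    # than the frozen reference model does, scaled by beta.
    chosen_reward = beta * (logp_chosen - ref_logp_chosen)
    rejected_reward = beta * (logp_rejected - ref_logp_rejected)
    margin = chosen_reward - rejected_reward
    # Logistic loss on the margin: -log sigmoid(margin) = log(1 + exp(-margin)).
    return math.log1p(math.exp(-margin))

# Policy identical to the reference: margin 0, loss = log 2.
print(dpo_loss(-10.0, -15.0, -10.0, -15.0))
# Policy raised the chosen answer and lowered the rejected one: loss drops.
print(dpo_loss(-10.0, -15.0, -12.0, -13.0))
```

Because the reference log-probabilities play the role that the KL penalty plays in RLHF, no reward model or sampling loop is needed; training is ordinary gradient descent on this loss over the preference dataset.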
This area is changing very fast. I covered the basic versions of these algorithms, but there have been many updates and variants since. Finally, I’ll include a figure from the InstructGPT paper that sums up everything we’ve discussed.
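References
Ouyang et al., “Training language models to follow instructions with human feedback” (InstructGPT): https://arxiv.org/pdf/2203.02155
Chip Huyen, “RLHF: Reinforcement Learning from Human Feedback”: https://huyenchip.com/2023/05/02/rlhf.html
OpenAI, “Faulty reward functions in the wild”: https://openai.com/index/faulty-reward-functions/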