# An introduction to the RL behind DeepSeek-R1
DeepSeek-R1 took us all by surprise. This would have been considered AGI 10 years ago, and not only did DeepSeek release the weights, they also blessed us with a technical report. What is perhaps most interesting is that DeepSeek did not use anything flashy. They just did the common-sense thing of scaling up CoT + RL, and even their RL is not that weird. They use GRPO, a slight modification of PPO that makes the algorithm more compute- and memory-efficient.
> When you're GPU poor, ingenuity becomes the key to unraveling the mysteries of AGI
In their technical report^[https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf], they actually talk about two models: *DeepSeek-R1-zero* and *DeepSeek-R1*. The latter follows essentially the same recipe, except that it also uses SFT on the base model before doing RL. Since this post is about the RL behind DeepSeek-R1, we can just take a look at how DeepSeek-R1-zero was trained.
Since GRPO is the main RL algorithm that we want to understand in the end and GRPO is based on PPO, we should start with PPO. However, to properly understand PPO, we need to go back one step further, because, believe it or not, PPO itself is a slight modification of an even older and lesser-known RL algorithm called TRPO^[And TRPO is a slight modification of Advantage Actor-Critic, which is itself a slight modification of subtracting a baseline, which is itself a slight modification of REINFORCE. I probably missed a few things here, but still. We stand on the shoulders of giants, don't forget that.]. So let's start there.
## Trust Region Policy Optimization
*Trust Region Policy Optimization* (**TRPO**) is an RL algorithm that optimizes the policy directly and makes sure that policy updates are constrained. Let's take a look at the formula^[It may confuse you that we are not writing this as an objective function such as $J_{TRPO}(\theta)$. The reason is that TRPO is a constrained, not an unconstrained, optimization problem. While it can be rewritten into a "similar" unconstrained problem via Lagrange multipliers, which could then be phrased as $J_{TRPO}$, that form obscures how TRPO is actually implemented, so we omit it.]:
$\max_{\theta} ~\mathbb{E}_{\pi_\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} \hat{A}(s, a) \right] \quad \text{subject to} \quad \mathbb{E}_{\pi_\text{old}} \left[ \text{KL} \left( \pi_\text{old}(\cdot|s) \| \pi_\theta(\cdot|s) \right) \right] \leq \delta$
This seems really scary at first, but it is actually pretty simple if you break it down. First, let's realize that this is a constrained optimization problem, so we want to find the parameters $\theta$ which maximize the objective $\mathbb{E}_{\pi_\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} \hat{A}(s, a) \right]$, while the constraint $\mathbb{E}_{\pi_\text{old}} \left[ \text{KL} \left( \pi_\text{old}(\cdot|s) \| \pi_\theta(\cdot|s) \right) \right] \leq \delta$ still holds true.
Let's first focus on our objective. $\pi_\theta$ is our policy, which is just a probability distribution over actions, conditioned on the state. It is parameterized by $\theta$, and it is what we update at each step. Since we are interested in LLMs, you can think of $\pi_\theta$ as an LLM, the state $s$ as the input, and the action $a$ as the output of the LLM. $\pi_{old}$ then just represents a frozen snapshot of our LLM. At the start of each training iteration, we set $\pi_{old} \gets \pi_\theta$. As a consequence, the ratio in the objective is always $1$ in the first step of each iteration.
$A(s, a)$ is the *[[advantage function]]*^[A note on Generalized Advantage Estimation (GAE), where this idea is explained more deeply, is on the list.], which quantifies how much better it is to choose action $a$ in state $s$ than to take the average action our policy would take there. Formally, $A(s, a) = Q(s, a) - V(s)$, where $Q(s, a)$ is the expected return after taking action $a$ in state $s$, and $V(s)$ is the expected return of simply following the policy from $s$. If $A(s,a)$ is positive, taking action $a$ leads to a better outcome than we would expect on average.
Notice how in the formula we have $\hat{A}$ instead of $A$. This is not a speck of dirt on your screen, but indicates that $\hat{A}$ is an estimation of $A$, as computing the real $A$ would be computationally infeasible.
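To make "compared to the average action" concrete, here is a minimal Python sketch for a single state with a discrete set of actions. It assumes we already know the true Q-values (in practice we don't, which is exactly why we need the estimate $\hat{A}$); the function name and numbers are made up for illustration.

```python
def advantages(policy_probs, q_values):
    """A(s, a) = Q(s, a) - V(s) for every action a in a single state s."""
    # V(s) is the policy-weighted average of Q(s, a), i.e. the value of an "average" action
    v = sum(p * q for p, q in zip(policy_probs, q_values))
    return [q - v for q in q_values]


# Hypothetical state with three actions
probs = [0.5, 0.3, 0.2]   # pi(a|s)
q = [1.0, 2.0, 0.0]       # Q(s, a)
adv = advantages(probs, q)
print([round(a, 3) for a in adv])  # [-0.1, 0.9, -1.1]
```

Note that the probability-weighted average of these advantages is zero: under its own policy, the "average" action has an advantage of zero by construction.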
Now that we know what these symbols stand for, let's try to understand the objective function: $\max_{\theta} ~\mathbb{E}_{\pi_\text{old}} \left[ \frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} \hat{A}(s, a) \right]$
Okay, that $\max_{\theta}$ is pretty obvious: it just denotes which parameters we are tuning. For our LLM, these are just its weights.
Next up is that expectation $\mathbb{E}_{\pi_{old}}$, which simply indicates that we are taking a weighted average over all possible state-action pairs. The $\pi_{old}$, our current policy (i.e. the current version of the LLM), determines the weights for this average, reflecting the probabilities of taking action $a$ in state $s$ (or in our case, the probability of the output given the input to the LLM). In practice, averaging over all state-action pairs is of course intractable, so we simply approximate it by sampling.
But what is inside the expectation? Ah, that might seem a bit tricky, but it is actually quite intuitive. The ratio $\frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)}$ is greater than $1$ if choosing the action $a$ in state $s$ is more likely under $\pi_\theta$, i.e. the fine-tuned version of the LLM as opposed to our current LLM $\pi_{old}$. Obviously, it is smaller than $1$ when the opposite is true. Note that the ratio is always positive.
Now let's recall that the estimated advantage $\hat{A}(s,a)$ takes on positive values if action $a$ is better than the average action that we could take when starting from state $s$, and negative values when the opposite is true. This sign is the key to making sense of the product inside the expectation. Equipped with this knowledge, we can finally understand what $\frac{\pi_\theta(a|s)}{\pi_\text{old}(a|s)} \hat{A}(s, a)$ does: it scales the advantage by how much more (or less) likely the new policy makes the action that we took.
If $\hat{A}(s, a) > 0$, maximizing the objective encourages the new policy $\pi_\theta$ to assign a higher probability to action $a$ in state $s$ than $\pi_{old}$ does, i.e. $\pi_\theta(a|s) > \pi_{old}(a|s)$. If $\hat{A}(s,a) < 0$, the opposite is true: the objective rewards making $a$ less likely.
The ratio of probabilities thus effectively modulates how much weight the advantage function carries in updating the policy. That's not so hard, is it? Well, there is a problem if we just leave it at that. Actually, there are several, but they all boil down to the fact that our update steps might be too large, taking us into a region of the policy space where we don't know what will happen if we apply that policy in our environment. Our samples and advantage estimates were all collected with $\pi_{old}$, so the objective is only a trustworthy proxy for real improvement as long as $\pi_\theta$ stays close to $\pi_{old}$.
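Putting this together, here is a minimal sketch of how the objective can be estimated from samples, assuming that for each sampled state-action pair we have the log-probabilities under $\pi_\theta$ and $\pi_{old}$ plus an advantage estimate. The function name and inputs are illustrative, not from any particular library. Exponentiating the difference of log-probabilities is a common way to compute the ratio without numerical underflow.

```python
import math


def surrogate_objective(logp_new, logp_old, advantages):
    """Monte Carlo estimate of E_{pi_old}[ pi_theta(a|s) / pi_old(a|s) * A_hat(s, a) ].

    All three lists are indexed by sample; the samples are assumed to be drawn from pi_old,
    so a plain average approximates the expectation.
    """
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)  # pi_theta(a|s) / pi_old(a|s)
        total += ratio * adv
    return total / len(advantages)


# Two hypothetical samples: one good action (A_hat > 0), one bad action (A_hat < 0)
logp_old = [math.log(0.5), math.log(0.25)]
logp_new = [math.log(0.6), math.log(0.2)]  # good action more likely, bad action less likely
adv = [1.0, -1.0]
print(surrogate_objective(logp_old, logp_old, adv))            # 0.0, every ratio is 1 when pi_theta == pi_old
print(round(surrogate_objective(logp_new, logp_old, adv), 6))  # 0.2
```

Shifting probability mass toward the good action and away from the bad one raises the estimate from 0 to 0.2, which is exactly the behavior the objective is meant to reward.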
Luckily, we can solve this by introducing a constraint that basically tells our policy not to update too much in one batch. To quantify the difference between the old policy and the new policy (remember, they are probability distributions), we use the KL divergence. You did study information theory, didn't you, anon? Well, the KL divergence can be computed like this:
$D_{\text{KL}}\left(\pi_\text{old}(\cdot|s) \| \pi_\theta(\cdot|s)\right) = \mathbb{E}_{a \sim \pi_\text{old}} \left[ \log \frac{\pi_\text{old}(a|s)}{\pi_\theta(a|s)} \right]$
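For a single state with a discrete action space (think: a next-token distribution over a tiny vocabulary), this formula can be evaluated exactly. A minimal sketch, with made-up distributions:

```python
import math


def kl_divergence(p_old, p_new):
    """D_KL(p_old || p_new) = sum_a p_old(a) * log(p_old(a) / p_new(a)) for discrete distributions."""
    # Actions with p_old(a) == 0 contribute nothing; p_new must be > 0 wherever p_old > 0
    return sum(p * math.log(p / q) for p, q in zip(p_old, p_new) if p > 0)


p_old = [0.5, 0.5]
p_new = [0.9, 0.1]
print(kl_divergence(p_old, p_old))            # 0.0, identical distributions
print(round(kl_divergence(p_old, p_new), 4))  # 0.5108
print(round(kl_divergence(p_new, p_old), 4))  # 0.3681, KL is not symmetric
```

The last two lines are a reminder that the order of the arguments matters; the formula above measures $\pi_\theta$ against $\pi_\text{old}$ with expectations taken under $\pi_\text{old}$.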
In practice, TRPO approximates the KL divergence locally with a second-order Taylor expansion (a quadratic form involving the Fisher information matrix), though we don't need to go into that now. Let's just assume that we can compute it. Since the formula demands the expected KL divergence, we approximate this expectation through sampling. We can invoke the law of large numbers to justify doing this^[Technically, there are conditions for the LLN to hold, but we can handwave those away, as the LLN is quite strong and still provides practically meaningful guarantees even when some of its assumptions are weakened.]. We can now constrain our optimization process by demanding that the expected KL divergence (think of it as quantifying the difference between the old and new policy on average) stays below some threshold $\delta$, which is a hyperparameter that we can choose. Essentially, we are telling the optimizer to choose the best possible $\theta$ while not straying too far from the original policy.
Believe it or not, that is all there is to the math behind TRPO. There is of course a lot of implementation detail that I am skipping over, but your attention span is only so long and we still have a few more formulas to go over. Don't worry though, they are just slight modifications of the formula of TRPO, which you can hopefully parse now. Let's not waste any time and jump into PPO.
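The sampling step can be sketched the same way. This toy example assumes a single state with known discrete distributions $\pi_\text{old}$ and $\pi_\theta$ (values made up), draws actions from $\pi_\text{old}$, averages $\log \frac{\pi_\text{old}(a|s)}{\pi_\theta(a|s)}$, and checks the result against $\delta$. The helper names are hypothetical. It does not show the Taylor-expansion machinery TRPO uses to actually enforce the constraint during optimization, and in a real setting we would also average over the sampled states.

```python
import math
import random


def sampled_kl(p_old, p_new, n_samples, rng):
    """Estimate D_KL(p_old || p_new) by sampling a ~ p_old and averaging log(p_old(a) / p_new(a))."""
    actions = rng.choices(range(len(p_old)), weights=p_old, k=n_samples)
    return sum(math.log(p_old[a] / p_new[a]) for a in actions) / n_samples


def within_trust_region(kl_estimate, delta):
    """The TRPO-style constraint: the expected KL must not exceed the hyperparameter delta."""
    return kl_estimate <= delta


rng = random.Random(0)  # fixed seed for reproducibility
p_old = [0.5, 0.5]
p_new = [0.9, 0.1]
estimate = sampled_kl(p_old, p_new, n_samples=100_000, rng=rng)
print(estimate)                                   # close to the exact value of about 0.5108
print(within_trust_region(estimate, delta=0.01))  # False: this update is far outside the trust region
```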
## Proximal Policy Optimization
*Proximal Policy Optimization*^[https://arxiv.org/abs/1707.06347] (**PPO**) is a reinforcement learning algorithm that was introduced in 2017 by Schulman et al., back when OpenAI was still mostly known for RL. It is basically a simplification of TRPO, both conceptually and computationally. Actually, let me show you the formula, so you can see for yourself:
$\mathcal{L}^\text{PPO}(\theta) = \mathbb{E}_{t} \left[ \min \left( r_t \hat{A}_t, ~ \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) ~\hat{A}_t \right) \right],$ where
- $r_t$ is our good old ratio $\frac{\pi_\theta(a_t | s_t)}{\pi_\text{old}(a_t | s_t)}$
- $\text{clip}(x,n_{lower},n_{upper})$ is the clipping function, which forces $x$ into the interval $[n_{lower}, n_{upper}]$
- $\epsilon$ is a hyperparameter for clipping (among the values tried in the PPO paper, $\epsilon = 0.2$ worked best)
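Here is a minimal, framework-free sketch of the clipped objective for a batch of tokens, assuming per-token log-probabilities under $\pi_\theta$ and $\pi_\text{old}$ plus per-token advantage estimates are given (all names are illustrative). A real implementation would compute this on tensors and let autodiff take the gradient; the default $\epsilon = 0.2$ is just a common choice.

```python
import math


def clip(x, lower, upper):
    """Force x into the interval [lower, upper]."""
    return max(lower, min(x, upper))


def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Sample average of min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) over tokens t."""
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        r = math.exp(lp_new - lp_old)  # r_t = pi_theta(a_t|s_t) / pi_old(a_t|s_t)
        total += min(r * adv, clip(r, 1 - eps, 1 + eps) * adv)
    return total / len(advantages)


# Two hypothetical tokens, both with A_t = 1: one ratio of 1.5 (gets clipped) and one of 0.5 (does not)
logp_old = [math.log(0.2), math.log(0.2)]
logp_new = [math.log(0.3), math.log(0.1)]
print(round(ppo_clip_objective(logp_new, logp_old, [1.0, 1.0]), 6))  # 0.85 = (1.2 + 0.5) / 2
```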
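What you may immediately notice is that we are back to an unconstrained optimization problem. But wasn't the whole point to prevent the model from deviating too far in each step? Well, the constraint is actually still there. It's just implicit! Think about what $\text{clip}$ does: it keeps $r_t$ inside $[1 - \epsilon, 1 + \epsilon]$, so once the ratio has moved more than $\epsilon$ away from $1$ in the direction the advantage favors, the clipped term stops rewarding further movement. The $\min$ then makes the objective even more conservative by acting as a pessimistic bound: it always picks the smaller of the clipped and unclipped terms. For instance, if $\hat{A}_t$ is negative and $r_t$ has already dropped below $1 - \epsilon$, the $\min$ selects the clipped term, which is constant, so there is no incentive to push that probability down any further. But if $\hat{A}_t$ is negative and $r_t$ has grown above $1 + \epsilon$ (the update made a bad action more likely), the $\min$ selects the unclipped $r_t \hat{A}_t$, so the mistake is still penalized in full.
You may also notice that we are taking the expectation over timesteps $t$. For our purposes (as we are talking about LLMs), this means that we are now operating on a token-by-token level: each generated token is an action, and the state is the prompt plus all tokens generated before it.
That's it. That's the whole deal with PPO. Okay, there are many implementation details that I have skipped. If you want to know more, read the original paper or a blog post^[https://huggingface.co/blog/deep-rl-ppo] that focuses on them.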
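The interplay of $\text{clip}$ and $\min$ is easy to check numerically. There are four cases: the advantage is positive or negative, and the ratio has moved above $1 + \epsilon$ or below $1 - \epsilon$. The sketch below evaluates the per-token PPO term as a function of the ratio and approximates its slope with a central finite difference (a stand-in for autodiff, chosen only to keep the example dependency-free). A slope of zero means the gradient gives no incentive to move the ratio any further.

```python
def ppo_term(r, adv, eps=0.2):
    """Per-token PPO objective min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped = max(1 - eps, min(r, 1 + eps))
    return min(r * adv, clipped * adv)


def slope(r, adv, eps=0.2, h=1e-6):
    """Central finite-difference approximation of d(ppo_term)/dr."""
    return (ppo_term(r + h, adv, eps) - ppo_term(r - h, adv, eps)) / (2 * h)


cases = [
    (1.5, +1.0),  # good action already made much more likely -> clipped, no further push
    (0.5, +1.0),  # good action made less likely -> unclipped, still pushed up
    (0.5, -1.0),  # bad action already made much less likely -> clipped, no further push
    (1.5, -1.0),  # bad action made more likely -> unclipped, still penalized
]
for r, adv in cases:
    print(f"r={r}, A={adv:+}: slope={slope(r, adv):+.3f}")
```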
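A tiny sketch of that token-level view, using made-up string tokens instead of real token IDs and a hypothetical helper name: every generated token becomes one (state, action) pair, so a response of $n$ tokens contributes $n$ terms to the expectation.

```python
def token_transitions(prompt_tokens, response_tokens):
    """Split one LLM response into (state, action) pairs for token-level RL.

    The state is the prompt plus every token generated so far; the action is the next token.
    """
    return [
        (prompt_tokens + response_tokens[:t], response_tokens[t])
        for t in range(len(response_tokens))
    ]


prompt = ["2", "+", "2", "="]
response = ["4", "<eos>"]
for state, action in token_transitions(prompt, response):
    print(state, "->", action)
```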
## PPO for LLMs (RLHF)
- TODO
## Making PPO go brr, even if you are GPU poor
LLMs are... large. Training them requires a lot of compute and a lot of memory. Sadly, most of us are relatively GPU poor. Luckily, there are many people who want to train LLMs via RL, and some of them are GPU poor too. Sometimes they invent algorithms that we GPU poors can use to not feel quite as poor. One such algorithm was recently introduced by Shao et al. and is called *Group Relative Policy Optimization* (**GRPO**)^[https://arxiv.org/abs/2402.03300]. Thanks to this algorithm, DeepSeek was able to scale up RL for LLM reasoning to the point where the finished model is a serious contender for the most capable model to date.
So let's look at the GRPO objective then:
$J_\text{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_\theta^\text{old}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^G \left( \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \left( r_{i,t} \hat{A}_{i,t}, \text{clip}(r_{i,t}, 1 - \epsilon, 1 + \epsilon) \hat{A}_{i,t} \right) \right) - \beta D_\text{KL}(\pi_\theta \| \pi_\text{ref}) \right]$
where:
- $q$: A question sampled from the dataset $P(Q)$.
- $\{o_i\}_{i=1}^G$: A group of $G$ outputs sampled from the old policy $\pi_\theta^\text{old}(O|q)$.
- $r_{i,t} = \frac{\pi_\theta(o_{i,t} | q, o_{i,<t})}{\pi_\theta^\text{old}(o_{i,t} | q, o_{i,<t})}$: Probability ratio for token $t$ in output $o_i$.
- $\hat{A}_{i,t}$: The advantage estimate for token $t$ in output $o_i$, derived from group-normalized rewards.
- $\epsilon$: The clipping parameter.
- $\beta$: The KL regularization coefficient.
> 1) What
Be not afraid. We will make sense of all of this. Let's analyze what this means step by step.
First, some good news. This algorithm was specifically designed with LLMs in mind, so we can mostly use terminology that we are more familiar with.
Next, let's look at what we take the expectation over: questions $q$ are sampled from some dataset $P(Q)$, and for each question we generate $G$ responses with the old policy $\pi_\theta^\text{old}$, i.e. a frozen snapshot of our current model. Okay, that's manageable. So what's inside the expectation then?
Well, it's basically just a term that contains PPO as well as another KL divergence as a penalty. Let's zoom in on that first term to understand what's going on.
$\frac{1}{G} \sum_{i=1}^G \left( \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \left( r_{i,t} \hat{A}_{i,t}, \text{clip}(r_{i,t}, 1 - \epsilon, 1 + \epsilon) \hat{A}_{i,t} \right) \right)$
Well, the outermost average just averages over the group of responses that we have generated for our question. Remember, we generated $G$ outputs, so that definitely makes sense. So why is there a second, inner average? Well, remember that each output consists of multiple tokens and PPO works on tokens, so naturally we need to compute the PPO term for each token and take the average over the $|o_i|$ tokens of that output. Wow, that was way simpler than it looked.
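The two nested averages can be sketched directly, assuming per-token log-probabilities and advantages are stored as one list per output (a hypothetical layout, not any particular framework's). The KL penalty is left out here; in the objective above it is scaled by $\beta$ and subtracted.

```python
import math


def grpo_policy_term(logp_new, logp_old, advantages, eps=0.2):
    """(1/G) * sum_i (1/|o_i|) * sum_t min(r_it * A_it, clip(r_it, 1 - eps, 1 + eps) * A_it).

    Each argument is a list of G per-output lists, one entry per token of that output.
    """
    per_output = []
    for lp_new_i, lp_old_i, adv_i in zip(logp_new, logp_old, advantages):
        token_terms = []
        for lp_new, lp_old, adv in zip(lp_new_i, lp_old_i, adv_i):
            r = math.exp(lp_new - lp_old)            # r_{i,t}
            clipped = max(1 - eps, min(r, 1 + eps))  # clip(r_{i,t}, 1 - eps, 1 + eps)
            token_terms.append(min(r * adv, clipped * adv))
        per_output.append(sum(token_terms) / len(token_terms))  # inner average over the |o_i| tokens
    return sum(per_output) / len(per_output)                  # outer average over the G outputs


# G = 2 hypothetical outputs: a 3-token output with advantage +1 and a 1-token output with advantage -1
logp = [[-1.0, -2.0, -0.5], [-1.5]]
adv = [[1.0, 1.0, 1.0], [-1.0]]
print(grpo_policy_term(logp, logp, adv))  # 0.0: each output counts equally, regardless of its length
```

A flat average over all four tokens would instead give $(1 + 1 + 1 - 1) / 4 = 0.5$, letting the longer output dominate; the per-output normalization by $|o_i|$ is what prevents that.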
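"Now why is the KL term here again? I thought we eliminated it when going from TRPO to PPO," you may ask. Well, the reason is that we are fine-tuning a base model that is really big. This base model has been trained on data on the order of trillions of tokens. We have put a lot of compute into learning meaningful representations, and we want to avoid forgetting them by changing our model too much through fine-tuning. Hence, we add another KL term as a penalty to constrain how much our model can change. Note that, unlike the TRPO constraint, this penalty measures the distance to a fixed reference model $\pi_\text{ref}$ (typically the model we started fine-tuning from), not to $\pi_\theta^\text{old}$, which changes every iteration. Pretty reasonable, don't you think?^[Before you ask: yes, we also do this when fine-tuning LLMs via PPO. I will add a chapter on how to do RLHF via PPO soon™.]
Okay, so what do we gain from all of this? The gain comes from how we estimate $\hat{A}_{i, t}$. PPO normally estimates advantages with a learned value model (the critic), which for LLMs is typically another model of comparable size to the policy. GRPO drops it: we score each output $o_i$ with a single scalar reward $r_i$ and normalize the rewards within the group of outputs sampled for the same question, $\hat{A}_{i,t} = \frac{r_i - \text{mean}(\{r_1, \dots, r_G\})}{\text{std}(\{r_1, \dots, r_G\})}$, assigning that value to every token $t$ of $o_i$. In the DeepSeekMath paper, these rewards come from an outcome reward model (another LLM); for DeepSeek-R1-zero, DeepSeek instead used rule-based rewards that check answer accuracy and output format. Either way, we no longer need to train and keep a value model in memory alongside the policy for each PPO iteration.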
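How is that KL term computed in practice? The DeepSeekMath paper (where GRPO was introduced) estimates it per token with $\frac{\pi_\text{ref}}{\pi_\theta} - \log\frac{\pi_\text{ref}}{\pi_\theta} - 1$, evaluated at the sampled token, which is always non-negative. Below is a minimal sketch from log-probabilities (the function name is hypothetical). The toy check at the end, on a made-up 3-token vocabulary, confirms that the expectation of this quantity under $\pi_\theta$ equals $D_\text{KL}(\pi_\theta \| \pi_\text{ref})$.

```python
import math


def kl_estimate_per_token(logp_theta, logp_ref):
    """Per-token KL estimate pi_ref/pi_theta - log(pi_ref/pi_theta) - 1 for one sampled token."""
    log_ratio = logp_ref - logp_theta  # log(pi_ref / pi_theta)
    return math.exp(log_ratio) - log_ratio - 1.0  # e^x - x - 1 >= 0 for every real x


# Toy check on a 3-token vocabulary: the expectation under pi_theta equals the exact KL
pi_theta = [0.7, 0.2, 0.1]
pi_ref = [0.5, 0.3, 0.2]
expected_estimate = sum(
    p * kl_estimate_per_token(math.log(p), math.log(q)) for p, q in zip(pi_theta, pi_ref)
)
exact_kl = sum(p * math.log(p / q) for p, q in zip(pi_theta, pi_ref))
print(round(expected_estimate, 6), round(exact_kl, 6))  # the two values agree
```

Unlike the plain $\log \frac{\pi_\theta}{\pi_\text{ref}}$ sample, which can be negative for individual tokens, every per-token value here is non-negative, so the penalty never rewards the policy for a single lucky sample.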
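The group normalization is short enough to write out. A minimal sketch, assuming one scalar reward per output (for example 1 for a correct final answer and 0 otherwise, in the spirit of a rule-based accuracy reward). The use of the population standard deviation and the small constant `eps` that avoids dividing by zero are my assumptions, and the helper names are hypothetical. Every token of an output then receives that output's advantage.

```python
import statistics


def group_advantages(rewards, eps=1e-8):
    """A_i = (r_i - mean(r)) / std(r), computed within one group of G outputs for the same question."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; eps guards against std == 0
    return [(r - mean) / (std + eps) for r in rewards]


def per_token_advantages(rewards, output_lengths):
    """Outcome supervision: broadcast each output's advantage to all of its tokens."""
    return [[a] * n for a, n in zip(group_advantages(rewards), output_lengths)]


rewards = [1.0, 0.0, 0.0, 1.0]            # hypothetical: two of four sampled answers were correct
print(group_advantages(rewards))          # roughly [1, -1, -1, 1]
print(group_advantages([1.0, 1.0, 1.0]))  # all zeros: a group with identical rewards gives no signal
print(per_token_advantages(rewards, [3, 1, 2, 2]))
```

The group mean plays the role of the baseline that PPO would otherwise get from its value model: an output is only "good" relative to the other answers sampled for the same question.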
![[ppo vs grpo.png]]^[https://arxiv.org/pdf/2402.03300]
I hope you enjoyed this blog post. I learned a lot creating this and I hope that you can take something away from it too.
# To Be Done
- explain figure comparing PPO and GRPO
- add chapter on how PPO is used for finetuning LLMs (RLHF).
- add algorithms in pseudocode for PPO and GRPO?
- talk about some of the implementation details of GRPO?