RLHF Derivation

Personal derivation of some important RLHF theorems.

Policy Gradient

We aim to maximize the expected return

\[J\left(\pi_\theta\right)=\underset{\tau \sim \pi_\theta}{\mathrm{E}}[R(\tau)].\]

I prefer to write the expectation in integral form:

\[J\left(\pi_ \theta\right)=\int_{\tau}P_\theta \left( \tau \right) R\left( \tau \right),\] \[\begin{align*} \nabla_\theta J\left(\pi_\theta \right) &=\nabla_\theta \int_{\tau}P_\theta\left(\tau \right)R\left(\tau \right) \\ &=\int_{\tau}\nabla_\theta P_\theta\left(\tau \right)R\left(\tau \right) \\ &=\int_{\tau}P_\theta\left(\tau \right)\nabla_\theta \log{P_\theta\left(\tau \right)} R\left(\tau \right). \end{align*}\]

As an aside, the dynamic sampling in DAPO actually alters the sampling distribution, so the samples are no longer drawn exactly from \(P_\theta\left(\tau\right)\) as this derivation assumes.
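As a quick sanity check of the log-derivative trick above, the snippet below compares the score-function form \(\int_x P_\theta(x)\nabla_\theta \log{P_\theta(x)} R(x)\) against a finite-difference gradient of \(J\) on a toy softmax "policy", where the integral is just a finite sum. The 5-outcome setup and all names are illustrative assumptions, not part of the derivation.

```python
# A minimal numerical check of the log-derivative trick on a toy categorical
# "policy" P_theta(x) = softmax(theta)_x with a fixed reward R(x).
# The 5-outcome setup and names are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=5)          # logits of a 5-outcome categorical policy
R = rng.normal(size=5)              # reward of each outcome

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def J(theta):
    # J(theta) = sum_x P_theta(x) R(x): the integral written as a finite sum
    return softmax(theta) @ R

# Score-function form: sum_x P_theta(x) * grad_theta log P_theta(x) * R(x).
# For a softmax, grad_theta log p_x = e_x - p (one-hot minus probabilities).
p = softmax(theta)
score_grad = sum(p[x] * (np.eye(5)[x] - p) * R[x] for x in range(5))

# Finite-difference gradient of J for comparison.
eps = 1e-6
fd_grad = np.array([(J(theta + eps * np.eye(5)[j]) - J(theta - eps * np.eye(5)[j])) / (2 * eps)
                    for j in range(5)])

print(np.allclose(score_grad, fd_grad, atol=1e-6))   # True
```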

Reward-to-Go

\[\begin{align*} \nabla_\theta J\left(\pi_\theta \right) &=\int_{\tau}P_\theta\left(\tau \right)\nabla_\theta \log{\prod_{t=0}^{T} \pi_\theta\left(a_t\mid s_t\right) } R\left(\tau \right) \\ &=\int_{\tau}P_\theta\left(\tau \right)\nabla_\theta \sum_{t=0}^{T} \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(\tau \right) \\ &=\int_{\tau}P_\theta\left(\tau \right)\nabla_\theta \sum_{t=0}^{T} \log{\pi_\theta\left(a_t\mid s_t\right) } \sum_{t^\prime=0}^{T} R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \\ &=\int_{\tau}P_\theta\left(\tau \right)\sum_{t=0}^{T}\nabla_\theta\log{\pi_\theta\left(a_t\mid s_t\right) } \sum_{t^\prime=0}^{T} R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \\ &=\boxed{ \int_{\tau}P_\theta\left(\tau \right)\sum_{t=0}^{T}\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } \sum_{t^\prime=0}^{t-1} R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right)} + \int_{\tau}P_\theta\left(\tau \right)\sum_{t=0}^{T}\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } \sum_{t^\prime=t}^{T} R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right). \end{align*}\]

We will prove that the boxed term, in which each \(\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right)}\) is weighted only by rewards collected before time \(t\), equals \(0\).

Expected Grad-Log-Prob (EGLP) lemma:

\[\int_x P_\theta(x)=1.\] \[\nabla_\theta \int_x P_\theta(x)=\nabla_\theta 1=0.\] \[\begin{aligned} 0 & =\nabla_\theta \int_x P_\theta(x) \\ & =\int_x \nabla_\theta P_\theta(x) \\ & =\int_x P_\theta(x) \nabla_\theta \log P_\theta(x). \end{aligned}\] \[\int_{\tau}P_\theta\left(\tau \right)\sum_{t=0}^{T}\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } \sum_{t^\prime=0}^{t-1} R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) = \int_{\tau}P_\theta\left(\tau \right)\sum_{t=0}^{T} \sum_{t^\prime=0}^{t-1} \nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right).\]
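The lemma itself is easy to verify numerically for a softmax-parameterized distribution, where \(\nabla_\theta \log{p_x} = e_x - p\). A minimal sketch; the toy 4-outcome setup is an assumption:

```python
# A minimal check of the EGLP lemma for a softmax-parameterized distribution:
# E_{x ~ P_theta}[grad_theta log P_theta(x)] = sum_x p_x (e_x - p) = 0.
# The toy 4-outcome logits are an illustrative assumption.
import numpy as np

theta = np.array([0.3, -1.2, 0.7, 2.0])
p = np.exp(theta - theta.max()); p /= p.sum()

expected_score = sum(p[x] * (np.eye(len(p))[x] - p) for x in range(len(p)))
print(np.allclose(expected_score, 0.0))   # True
```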

For a specific \(t\) and \(t^\prime\), we get \(\int_{\tau}P_\theta\left(\tau \right)\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right)\). Trajectories that share the same \(a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}\) have the same value of \(\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right)\), so we can sum their probabilities directly, which gives \(P_\theta\left(a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right)\). In this way, the term above becomes:

\[\int_{a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}}P_\theta\left(a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right)\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right).\]

Here \(\int_{a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}}\) denotes integrating over \(a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}\) under the distribution induced by \(\pi_\theta\). Note that \(a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}\) are still variables, while \(t\) and \(t^\prime\) are constants.

Factorizing the joint distribution (recall that \(t^\prime < t\) in the boxed term), we get:

\[\begin{aligned} & \int_{a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}}P_\theta\left(a_t, s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right)\nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \\ & = \int_{a_t, s_t} \int_{s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}} P_\theta\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) P_\theta\left(a_t, s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) \nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \\ & = \int_{s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}} P_\theta\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) \int_{a_t, s_t} P_\theta\left(a_t, s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) \nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \\ & = \int_{s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}} P_\theta\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \int_{a_t, s_t} P_\theta\left(a_t, s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) \nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } \\ & = \int_{s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}} P_\theta\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right) \int_{s_t} P_\theta\left(s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right)\boxed{ \int_{a_t} \pi_\theta\left(a_t \mid s_t \right) \nabla_\theta \log{\pi_\theta\left(a_t\mid s_t\right) } } \end{aligned}\]

According to EGLP, the boxed part equals \(0\), so the entire expression vanishes.

When \(t^\prime > t\), we cannot factorize \(P_\theta\left(a_t, s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}\right)=\pi_\theta\left(a_t \mid s_t \right) P_\theta\left(s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right)\), because \(a_t\) is not independent of the future variables given \(s_t\). We can only write \(P_\theta\left(a_t, s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1}\right)=P_\theta\left(a_t \mid s_t, s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right) P_\theta\left(s_t \mid s_{t^\prime}, a_{t^\prime}, s_{t^\prime+1} \right)\).
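In practice, the surviving reward-to-go weight \(\sum_{t^\prime=t}^{T} R\left(s_{t^\prime}, a_{t^\prime}, s_{t^\prime + 1}\right)\) is typically computed as a reverse cumulative sum over the per-step rewards. A minimal sketch; the array names are assumptions:

```python
# A small sketch (not from the derivation) of computing the reward-to-go
# sum_{t'=t}^{T} r_{t'} from per-step rewards via a reverse cumulative sum.
import numpy as np

def reward_to_go(rewards):
    # rewards[t] = R(s_t, a_t, s_{t+1}); returns rtg[t] = sum_{t' >= t} rewards[t']
    return np.cumsum(rewards[::-1])[::-1]

print(reward_to_go(np.array([1.0, 0.0, 2.0, 3.0])))   # [6. 5. 5. 3.]
```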

Baseline

When we add a baseline \(b\left(s_t\right)\) to \(R\), the change in the policy gradient is

\[\begin{aligned} \Delta \nabla_\theta J\left(\pi_\theta \right)=\sum_{t=0}^{T} \int_\tau P_\theta \left(\tau\right) \nabla_\theta \log{\pi_\theta \left(a_t \mid s_t \right)} b\left(s_t\right). \end{aligned}\]

Using the same technique as in Reward-to-Go,

\[\begin{aligned} \Delta \nabla_\theta J\left(\pi_\theta \right) & =\sum_{t=0}^{T} \int_{a_t, s_t} P_\theta \left(a_t, s_t\right) \nabla_\theta \log{\pi_\theta \left(a_t \mid s_t \right)} b\left(s_t\right) \\ & = \sum_{t=0}^{T} \int_{s_t} P_\theta\left(s_t \right) \int_{a_t} P_\theta \left(a_t \mid s_t\right) \nabla_\theta \log{\pi_\theta \left(a_t \mid s_t \right)} b\left(s_t\right) \\ & = \sum_{t=0}^{T} \int_{s_t} P_\theta\left(s_t \right) b\left(s_t\right) \boxed{ \int_{a_t} P_\theta \left(a_t \mid s_t\right) \nabla_\theta \log{\pi_\theta \left(a_t \mid s_t \right)}}. \end{aligned}\]

According to EGLP, the boxed part equals \(0\), so the baseline term does not change the gradient.
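This is why we can subtract a state-value baseline from the reward-to-go without biasing the gradient, which gives the familiar advantage-style weight. A minimal sketch, assuming \(b\left(s_t\right)\) comes from some learned value estimate; the names are assumptions:

```python
# A minimal sketch (assumed names) of using a state-value baseline:
# the policy-gradient weight becomes the advantage estimate
# A_t = reward-to-go_t - b(s_t), which is unbiased by the argument above.
import numpy as np

def advantages(rewards, values):
    # rewards[t] = R(s_t, a_t, s_{t+1}); values[t] = b(s_t), e.g. a learned V(s_t)
    rtg = np.cumsum(rewards[::-1])[::-1]
    return rtg - values

print(advantages(np.array([1.0, 0.0, 2.0, 3.0]),
                 np.array([5.0, 4.0, 4.0, 2.0])))     # [1. 1. 1. 1.]
```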

PPO

PPO employs the following objective for policy optimization:

\[J_{\rm PPO}\left(\pi_\theta\right)=\mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_{\theta_{\text {old }}}(\cdot \mid x)}\left[\frac{1}{|y|} \sum_{t=1}^{|y|} \min \left(w_t(\theta) \widehat{A}_t, \operatorname{clip}\left(w_t(\theta), 1-\varepsilon, 1+\varepsilon\right) \widehat{A}_t\right)\right],\]

where the importance ratio of the token \(y_t\) is defined as \(w_t(\theta)=\frac{\pi_\theta\left(y_t \mid x, y_{<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_t \mid x, y_{<t}\right)}\).
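For concreteness, here is a PyTorch-style sketch of this token-level clipped objective written as a loss to minimize; the tensor names, shapes, and masking scheme are assumptions for illustration rather than a reference implementation:

```python
# A sketch of the token-level clipped PPO objective above, written as a loss to
# minimize with PyTorch. Tensor names and shapes are assumptions.
import torch

def ppo_loss(logp, logp_old, advantages, mask, eps=0.2):
    # logp, logp_old: (batch, T) log pi(y_t | x, y_<t) under theta and theta_old
    # advantages:     (batch, T) token-level advantage estimates \hat{A}_t
    # mask:           (batch, T) float, 1 for response tokens, 0 for padding
    ratio = torch.exp(logp - logp_old)                      # w_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    per_token = torch.min(unclipped, clipped)
    # 1/|y| average over each response, then mean over the batch; negate to minimize
    per_seq = (per_token * mask).sum(dim=1) / mask.sum(dim=1)
    return -per_seq.mean()
```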

When dealing with RL, we usually need to analyze the gradient of the objective. This is because in RL we do not compute gradients exactly but estimate them from samples, so we need to make sure the estimated gradients stay close to the theoretical ones. The gradient of the PPO objective is (clipping omitted for brevity):

\[\begin{aligned} \nabla J_{\rm PPO}\left(\pi_\theta\right) &=\mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_{\theta_{\text {old }}}(\cdot \mid x)} \left[\frac{1}{|y|} \sum_{t=1}^{|y|}\nabla_\theta w_t(\theta) \widehat{A}_t\right] \\ &=\mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_{\theta_{\text {old }}}(\cdot \mid x)} \left[\frac{1}{|y|} \sum_{t=1}^{|y|} w_t(\theta) \widehat{A}_t \nabla_\theta \log{w_t(\theta)} \right] \\ &=\mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_{\theta_{\text {old }}}(\cdot \mid x)} \left[\frac{1}{|y|} \sum_{t=1}^{|y|} \frac{\pi_\theta\left(y_t \mid x, y_{<t}\right)}{\pi_{\theta_{\text {old }}}\left(y_t \mid x, y_{<t}\right)} \widehat{A}_t \nabla_\theta \log{\pi_\theta\left(y_t \mid x, y_{<t}\right)} \right]. \end{aligned}\]

According to Reward-to-Go and Baseline, the ideal policy gradient is

\[\begin{aligned} \nabla_\theta J\left(\pi_\theta\right) &=\int_\tau P_\theta(\tau) \sum_{t=0}^T A_t \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right) \\ &=\int_\tau P_{\theta_{old}}(\tau) \frac{P_\theta(\tau)}{P_{\theta_{old}}(\tau)}\sum_{t=0}^T A_t \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right). \end{aligned}\]

However, the token-level ratio \(\frac{\pi_\theta\left(y_t \mid x, y_{<t}\right)}{\pi_{\theta_{\text{old}}}\left(y_t \mid x, y_{<t}\right)}\) does not always equal the trajectory-level ratio \(\frac{P_\theta(\tau)}{P_{\theta_{old}}(\tau)}\), so the PPO gradient estimator can deviate from the ideal one. GRPO has the same problem. That is why GSPO proposed a sequence-level importance ratio.
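To make the gap concrete, the sketch below contrasts the per-token ratio \(w_t(\theta)\) with a sequence-level ratio built from the whole response; the length normalization follows my understanding of GSPO's sequence-level ratio, and the function names are assumptions:

```python
# A sketch contrasting the two ratios discussed above. For a response y given x,
# the trajectory-level ratio P_theta(y|x)/P_theta_old(y|x) is the product of the
# per-token ratios, which a single token-level w_t(theta) generally does not equal.
# GSPO additionally length-normalizes the sequence-level ratio (my understanding).
import torch

def token_ratios(logp, logp_old):
    return torch.exp(logp - logp_old)                       # w_t(theta), shape (T,)

def sequence_ratio(logp, logp_old, length_normalize=True):
    log_ratio = (logp - logp_old).sum()                     # log P_theta(y|x)/P_old(y|x)
    if length_normalize:
        log_ratio = log_ratio / logp.numel()                # GSPO-style 1/|y| normalization
    return torch.exp(log_ratio)

logp = torch.tensor([-1.0, -2.0, -0.5])
logp_old = torch.tensor([-1.2, -1.8, -0.9])
print(token_ratios(logp, logp_old))        # per-token w_t
print(sequence_ratio(logp, logp_old))      # a single sequence-level ratio
```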
