RL Ch5 GRPO
Introduction to GRPO
GRPO (Group Relative Policy Optimization) is a reinforcement learning method proposed by DeepSeek, designed specifically for reinforcement learning on large language models (LLMs). It builds on the ideas of PPO (Proximal Policy Optimization) but eliminates the need for a critic model, which simplifies the overall architecture.
Core Concept of GRPO
The core idea of GRPO is to optimize the policy by comparing the outputs of the new policy against a group of outputs sampled from the old policy. This "group relative" comparison makes the updates more stable and reliable: instead of relying on a value function to estimate the advantage, GRPO uses relative rewards computed from multiple samples.
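As a minimal sketch of this idea (the function name group_relative_advantages and the example rewards are purely illustrative), the group-relative advantage simply standardizes the rewards within one group, with no critic involved:

import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Normalize a group of scalar rewards to zero mean and unit std."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Example: rewards for a group of 4 samples drawn for the same input.
rewards = torch.tensor([1.0, 0.0, 0.5, 2.0])
print(group_relative_advantages(rewards))  # better-than-average samples get positive advantages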
Objective Function of GRPO
The objective function of GRPO can be formulated as:
$$\mathcal{L}_{\text{GRPO}}(\theta) = \mathbb{E}\left[ \frac{1}{G} \sum_{i=1}^{G} \left( \min\!\left( \frac{\pi_{\theta}(o_i)}{\pi_{\theta_{\text{old}}}(o_i)} \hat{A}_i,\; \mathrm{clip}\!\left( \frac{\pi_{\theta}(o_i)}{\pi_{\theta_{\text{old}}}(o_i)},\, 1-\epsilon,\, 1+\epsilon \right) \hat{A}_i \right) - \beta\, D_{\text{KL}}\!\left( \pi_{\theta} \,\|\, \pi_{\text{ref}} \right) \right) \right]$$

This objective is maximized with respect to $\theta$; $o_1, \dots, o_G$ denote a group of $G$ outputs sampled from the old policy for the same input.
Where:
- $\hat{A}_i$ is the normalized advantage computed from a group of samples, $\hat{A}_i = \frac{r_i - \mathrm{mean}(\{r_1,\dots,r_G\})}{\mathrm{std}(\{r_1,\dots,r_G\})}$. Note that GRPO uses this group-normalized reward instead of GAE; the latter requires a critic model, which we do not have in GRPO.
- $\pi_{\theta}$ and $\pi_{\theta_{\text{old}}}$ are the new and old policies, respectively.
- $\epsilon$ is a hyperparameter that controls the range of policy updates (the clipping range).
- $\beta$ is a regularization coefficient that controls the strength of the KL-divergence penalty between the new policy and a reference policy $\pi_{\text{ref}}$. This plays the same role as the entropy-regularization trick from the last chapter; GRPO just uses a KL divergence to a reference policy instead of a plain entropy bonus (KL divergence is also known as relative entropy). A small loss sketch follows this list.
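To make the objective concrete, here is a small sketch of how the loss could be computed from per-sample log-probabilities. The tensor names (logp_new, logp_old, logp_ref) and the particular KL estimator are illustrative assumptions, not the only possible choices:

import torch

def grpo_loss(logp_new, logp_old, logp_ref, advantages, clip_ratio=0.2, beta=0.001):
    """Clipped surrogate loss with a KL penalty toward a reference policy.

    All inputs are 1-D tensors with one entry per sample in the group.
    """
    ratio = torch.exp(logp_new - logp_old)                       # pi_theta / pi_theta_old
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()                # negate: we minimize the loss
    # One common per-sample estimator of KL(pi_theta || pi_ref); non-negative by construction.
    kl = (torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1).mean()
    return policy_loss + beta * kl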
Advantages of GRPO
- No Critic Network Required: By eliminating the need for a critic network, GRPO simplifies the model architecture and reduces computational overhead.
- Stable Updates: The clipping mechanism keeps policy updates stable and prevents large, potentially destabilizing changes.
- Robustness to Noisy Rewards: Because advantages are normalized within each group, GRPO is more robust to noisy reward signals, making it suitable for complex tasks where reward estimation is challenging.
Implementation
GRPO does not perform as well as PPO in our experiments, so we use a simpler scenario: the CartPole-v1 environment. This also demonstrates how to apply policy gradients to discrete action spaces.
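Before the full program, a quick illustration of the discrete-action machinery it relies on: the policy outputs logits, and actions are sampled from a Categorical distribution (the logits below are made-up numbers for a two-action space like CartPole's):

import torch
from torch.distributions import Categorical

logits = torch.tensor([0.2, -0.4])   # unnormalized scores for 2 discrete actions
dist = Categorical(logits=logits)    # softmax over logits gives action probabilities
action = dist.sample()               # sampled action index (0 or 1)
log_prob = dist.log_prob(action)     # log pi(a|s), used later in the policy ratio
print(action.item(), log_prob.item())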
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data  # provides TensorDataset and DataLoader used during training
import numpy as np
from torch.distributions import Categorical
import wandb
from tqdm import trange
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
class ActorNetwork(nn.Module):
    """Policy network mapping a state to logits over discrete actions."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Orthogonal initialization tends to stabilize policy-gradient training.
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=1.0)
            nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        return self.net(x)
class GRPOTrainer:
    def __init__(self, env, actor, optimizer,
                 clip_ratio=0.2, beta=0.001, gamma=0.99,
                 epochs=10, batch_size=32,
                 group_size=200):
        self.env = env
        self.actor = actor
        self.optimizer = optimizer
        self.clip_ratio = clip_ratio    # epsilon in the clipped surrogate objective
        self.beta = beta                # weight of the KL penalty
        self.gamma = gamma              # discount factor for returns
        self.epochs = epochs            # optimization epochs per collected group
        self.batch_size = batch_size
        self.group_size = group_size    # environment steps collected per group/rollout
        self.ep_rewards = []            # completed-episode rewards, for logging
    def _calc_returns(self, rewards, dones):
        # Discounted returns computed backwards; a terminal step resets the running sum.
        returns = []
        R = 0
        for r, done in zip(reversed(rewards), reversed(dones)):
            R = r + self.gamma * R * (1 - done)
            returns.insert(0, R)
        return torch.tensor(returns, dtype=torch.float32).to(device)
    def collect_rollout(self):
        states, acts, rews, dones = [], [], [], []
        old_logits = []
        state, _ = self.env.reset()
        state = torch.FloatTensor(state).to(device)
        ep_rew = 0
        for _ in range(self.group_size):
            with torch.no_grad():
                logits = self.actor(state)
                dist = Categorical(logits=logits)
                act = dist.sample()
            next_state, rew, terminated, truncated, _ = self.env.step(act.item())
            done = terminated or truncated
            ep_rew += rew
            states.append(state)
            acts.append(act)
            rews.append(rew)
            dones.append(done)
            old_logits.append(logits)
            # Reset the environment when an episode ends; otherwise continue from next_state.
            state = torch.FloatTensor(next_state).to(device) if not done else torch.FloatTensor(self.env.reset()[0]).to(device)
            if done:
                self.ep_rewards.append(ep_rew)
                ep_rew = 0
        # Group-relative advantages: normalize the discounted returns within this group,
        # replacing the critic/GAE used by PPO.
        returns = self._calc_returns(rews, dones)
        advantages = (returns - returns.mean()) / (returns.std() + 1e-8)
        return (
            torch.stack(states),
            torch.stack(acts),
            advantages,
            torch.stack(old_logits)
        )
    def train(self, total_updates=300):
        self.actor.train()
        for update in trange(total_updates):
            states, actions, advantages, old_logits = self.collect_rollout()
            dataset = torch.utils.data.TensorDataset(states, actions, advantages, old_logits)
            loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
            policy_losses = []
            kl_divergences = []
            for _ in range(self.epochs):
                for batch in loader:
                    s_batch, a_batch, adv_batch, old_logits_batch = batch
                    new_logits = self.actor(s_batch)
                    old_dist = Categorical(logits=old_logits_batch.detach())
                    new_dist = Categorical(logits=new_logits)
                    logp_new = new_dist.log_prob(a_batch)
                    logp_old = old_dist.log_prob(a_batch).detach()
                    # Probability ratio between the new and old policies.
                    ratio = torch.exp(logp_new - logp_old)
                    # Clipped surrogate objective, as in PPO.
                    surr1 = ratio * adv_batch
                    surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv_batch
                    policy_loss = -torch.min(surr1, surr2).mean()
                    # KL penalty against the rollout (old) policy, which acts as the reference here.
                    kl = torch.distributions.kl.kl_divergence(old_dist, new_dist).mean()
                    loss = policy_loss + self.beta * kl
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
                    self.optimizer.step()
                    policy_losses.append(policy_loss.item())
                    kl_divergences.append(kl.item())
            if self.ep_rewards:
                avg_rew = np.mean(self.ep_rewards[-20:])
                wandb.log({
                    "update": update,
                    "avg_reward": avg_rew,
                    "policy_loss": np.mean(policy_losses),
                    "kl_divergence": np.mean(kl_divergences)
                })
def test(env, actor, episodes=5, render=False):
    actor.eval()
    for ep in range(episodes):
        state, _ = env.reset()
        total_rew = 0
        while True:
            if render:
                env.render()
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).to(device)
                logits = actor(state_tensor)
                # Greedy action selection at test time.
                act = torch.argmax(logits).item()
            state, rew, terminated, truncated, _ = env.step(act)
            total_rew += rew
            if terminated or truncated:
                print(f"Test Episode {ep+1} | Reward: {total_rew}")
                break
def main():
    wandb.init(project="grpo-cartpole")
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    actor = ActorNetwork(state_dim, action_dim).to(device)
    optimizer = optim.Adam(actor.parameters(), lr=3e-4)
    trainer = GRPOTrainer(
        env=env,
        actor=actor,
        optimizer=optimizer,
        clip_ratio=0.2,
        beta=0.001,
        gamma=0.99,
        epochs=10,
        batch_size=32,
        group_size=200
    )
    trainer.train(total_updates=1000)
    test_env = gym.make('CartPole-v1', render_mode='human')
    test(test_env, actor, episodes=3, render=True)
    test_env.close()
    env.close()

if __name__ == "__main__":
    main()
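Running the script requires a Weights & Biases login for wandb.init (or set the environment variable WANDB_MODE=offline to keep logging local). If training goes well, the greedy test episodes should approach CartPole-v1's per-episode reward cap of 500.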