RL Ch4 PPO

PPO

Actor-Critic improves the performance of a policy network, but it is not the only way to do so. PPO (Proximal Policy Optimization) is a further improvement built on top of Actor-Critic.

PPO is designed to optimize the policy more efficiently by introducing a constraint to ensure that each policy update does not change the policy too drastically. This is done by clipping the objective function to prevent large, unstable updates, while still allowing for enough flexibility for effective learning.

A Recapitulation

For a typical policy network, we would use the following loss:

$$L(\theta) = -\mathbb{E}_t\left[\log \pi_\theta(a_t \mid s_t)\, A_t\right]$$

where $A_t$ is the advantage. In the vanilla policy gradient, this is the (discounted) return, and in the actor-critic architecture, it is derived from the value given out by the critic network.
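
The loss above is the negative mean of log-probability times advantage, which can be sketched in a few lines of plain Python. The probabilities and advantages here are made-up numbers, and `policy_gradient_loss` is a hypothetical helper name, not part of the implementation later in this chapter.

```python
import math

def policy_gradient_loss(log_probs, advantages):
    """Return -mean(log_prob * advantage) over a batch of sampled actions."""
    assert len(log_probs) == len(advantages)
    weighted = [lp * adv for lp, adv in zip(log_probs, advantages)]
    return -sum(weighted) / len(weighted)

# Hypothetical batch: the probability the policy assigned to each action it took,
# and the advantage credited to that action.
action_probs = [0.5, 0.25, 0.8]
advantages = [1.0, -2.0, 0.5]

log_probs = [math.log(p) for p in action_probs]
loss = policy_gradient_loss(log_probs, advantages)
print(f"loss = {loss:.4f}")
```

Minimizing this loss raises the log-probability of actions with positive advantage and lowers it for actions with negative advantage.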

PPO differs from this method by introducing a clipped objective that prevents excessively large policy updates. The idea is to ensure the new policy doesn't deviate too far from the old policy, even if the objective function suggests a large update.

The Clipped Surrogate Objective

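The key modification in PPO is the use of a clipped surrogate objective to prevent large updates. The clipped objective, which is maximized (so the loss is its negative), is defined as:

$$L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\, A_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A_t\right)\right]$$

Where:

  • $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)}$ is the probability ratio between the new policy and the old policy.
  • $A_t$ is the advantage estimate at time step $t$.
  • $\epsilon$ is a hyperparameter that defines the clipping range (usually small, such as 0.1 or 0.2).
  • $\operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)$ is the clipped ratio that limits how much the probability ratio can change.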

This modification ensures that if the new policy changes drastically, clipping caps the objective: there is no further gain from pushing the ratio beyond the clipping bounds, which discourages large updates. On the other hand, if the change is small (i.e., the ratio is within the clipping bounds), the objective behaves as a standard policy gradient loss.
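
To make the clipping behavior concrete, here is a minimal per-sample sketch in plain Python. The function name `clipped_surrogate` and the sample ratios are assumptions chosen for illustration; `eps=0.2` matches the clipping range mentioned above.

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped_ratio * advantage)

# Hypothetical (ratio, advantage) pairs covering the interesting cases.
cases = [(1.1, 1.0), (1.5, 1.0), (0.5, -1.0), (1.5, -1.0)]
for ratio, advantage in cases:
    value = clipped_surrogate(ratio, advantage)
    print(f"ratio={ratio}, advantage={advantage:+.1f} -> objective={value:.2f}")
```

With a ratio of 1.1 the objective is unchanged, while at 1.5 with a positive advantage it is capped at 1.2 times the advantage. In the last case (ratio 1.5, negative advantage) the `min` keeps the unclipped term, so a move that makes a bad action more likely is still fully counted against the policy.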

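How to Obtain $\pi_{\theta_{old}}$?

An engineering question here is how to obtain the old distribution. We first run a batch of inferences without gradients and record the log-probability of each action taken, then update the model on each entry of that batch. For the first entry, the ratio between the new and old probabilities stays at one, because the policy has not been updated yet; for the later ones, the ratio takes a meaningful value, because the policy has moved away from the recorded one.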
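
The procedure can be sketched with a tiny softmax policy in plain Python. Everything here is an illustrative assumption: the three-action policy, the made-up rollout data, the hand-written gradient, and the omission of clipping. The point is only to show that the ratio is exactly one on the first mini-batch and drifts away from one afterwards.

```python
import math

def softmax(logits):
    """Convert logits into action probabilities (numerically stable)."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_prob(logits, action):
    """Log-probability of `action` under the softmax policy."""
    return math.log(softmax(logits)[action])

# Hypothetical rollout data (made up for illustration).
logits = [0.0, 0.0, 0.0]            # policy parameters
actions = [0, 2, 1, 0]              # actions taken during the rollout
advantages = [1.0, -0.5, 0.3, 2.0]  # advantage credited to each action

# 1. Rollout phase: record the log-probabilities once and treat them as constants.
old_log_probs = [log_prob(logits, a) for a in actions]

learning_rate = 0.1
batch_size = 2
ratios_per_batch = []
for start in range(0, len(actions), batch_size):
    batch = list(range(start, start + batch_size))
    # 2. Ratio = exp(new log-prob - old log-prob), using the current parameters.
    ratios = [math.exp(log_prob(logits, actions[i]) - old_log_probs[i]) for i in batch]
    ratios_per_batch.append(ratios)
    # 3. One gradient-ascent step on mean(ratio * advantage); clipping omitted for brevity.
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for ratio, i in zip(ratios, batch):
        for k in range(len(logits)):
            indicator = 1.0 if k == actions[i] else 0.0
            grad[k] += ratio * advantages[i] * (indicator - probs[k]) / batch_size
    logits = [w + learning_rate * g for w, g in zip(logits, grad)]

print(ratios_per_batch)
```

The first inner list contains only ones, since the parameters still match the ones used during the rollout. The second list differs from one because the first update has already changed the policy.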

Entropy Regularization

A common trick for policy gradient learning is to add an entropy bonus: the term $-\beta H(\pi_\theta)$ is added to the loss, where $\beta$ is a small number, typically $0.01$ (the value used in the implementation below), and $H$ is the entropy. The entropy term measures the randomness or uncertainty of the policy. A higher entropy means the policy is more stochastic, encouraging the agent to explore a wider range of actions. On the other hand, a lower entropy indicates a more deterministic policy.

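The entropy term is defined as (for continuous actions, the sum becomes an integral):

$$H(\pi_\theta(\cdot \mid s)) = -\sum_a \pi_\theta(a \mid s) \log \pi_\theta(a \mid s)$$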

This term quantifies how uncertain the policy is about which action to choose in a given state. In reinforcement learning, encouraging exploration by increasing entropy can prevent the agent from converging to suboptimal deterministic policies too early, ensuring that it continues to explore different actions to discover better strategies.
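
As a small illustration, the entropy can be computed directly for a discrete policy and, in closed form, for a one-dimensional Gaussian policy like the one used in the implementation below. The helper names and the example distributions are assumptions made for this sketch.

```python
import math

def categorical_entropy(probs):
    """H = -sum_a p(a) * log p(a); actions with p(a) == 0 contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def gaussian_entropy(std):
    """Differential entropy of a 1-D Normal: 0.5 * log(2 * pi * e * std^2)."""
    return 0.5 * math.log(2.0 * math.pi * math.e * std ** 2)

# Hypothetical discrete policies over four actions.
uniform = [0.25, 0.25, 0.25, 0.25]   # maximally uncertain
peaked = [0.97, 0.01, 0.01, 0.01]    # almost deterministic

print(f"uniform: {categorical_entropy(uniform):.4f}")
print(f"peaked:  {categorical_entropy(peaked):.4f}")
print(f"gaussian std=0.5: {gaussian_entropy(0.5):.4f}")
print(f"gaussian std=2.0: {gaussian_entropy(2.0):.4f}")
```

The uniform policy has the highest entropy and the peaked one is close to zero. For the Gaussian, entropy grows with the standard deviation, so the bonus pushes against the standard deviation collapsing too early.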

Implementation

PPO is generally stable and converges reliably, so we train this model on Pendulum-v1 with gravity enabled (the environment's default setting).

import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
import wandb
from tqdm import trange

device = "cpu"
print(f"Using device: {device}")

class ActorNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.mean_layer = nn.Linear(128, action_dim)
        self.logstd_layer = nn.Linear(128, action_dim)
        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            nn.init.constant_(m.bias, 0.0)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean_layer(x)
        logstd = self.logstd_layer(x)
        logstd = torch.clamp(logstd, min=-20, max=2)
        return mean, logstd

class CriticNetwork(nn.Module):
    def __init__(self, state_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.value_layer = nn.Linear(128, 1)
        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            nn.init.constant_(m.bias, 0.0)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.value_layer(x)
        return value

class PPOTrainer:
    def __init__(self, env, actor_model, critic_model, actor_optimizer, critic_optimizer,
                 gamma=0.99, gae_lambda=0.95, clip_param=0.2, n_epochs=4, batch_size=64):
        self.env = env
        self.actor_model = actor_model
        self.critic_model = critic_model
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_param = clip_param
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.episode_rewards = []

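    def compute_gae(self, rewards, values, dones, next_value):
        """Generalized Advantage Estimation over one rollout of transitions."""
        # Append the bootstrap value of the state that follows the last transition.
        values = values + [next_value]
        advantages = []
        gae = 0
        # Walk backwards so each step can reuse the accumulated advantage of the next one.
        for t in reversed(range(len(rewards))):
            # One-step TD error; (1 - dones[t]) drops the bootstrap term at episode ends.
            delta = rewards[t] + self.gamma * values[t+1] * (1 - dones[t]) - values[t]
            # Exponentially weighted sum of TD errors, reset at episode boundaries.
            gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
            advantages.insert(0, gae)
        return torch.tensor(advantages, dtype=torch.float32).to(device)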

    def train(self, num_updates=1000, n_steps=2048):
        for update in trange(num_updates):
            states, actions, rewards, dones, old_log_probs, values = [], [], [], [], [], []
            state, _ = self.env.reset()
            state = torch.FloatTensor(state).to(device)
            episode_reward = 0

            for _ in range(n_steps):
                with torch.no_grad():
                    mean, logstd = self.actor_model(state)
                    std = torch.exp(logstd)
                    dist = Normal(mean, std)
                    action = dist.sample()
                    log_prob = dist.log_prob(action).sum()
                    value = self.critic_model(state).squeeze()

                next_state, reward, done, truncated, _ = self.env.step(action.cpu().numpy())
                episode_reward += reward

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                dones.append(done or truncated)
                old_log_probs.append(log_prob)
                values.append(value)

                state = torch.FloatTensor(next_state).to(device) if not (done or truncated) else torch.FloatTensor(self.env.reset()[0]).to(device)

                if done or truncated:
                    self.episode_rewards.append(episode_reward)
                    episode_reward = 0

            with torch.no_grad():
                next_value = self.critic_model(state).squeeze().item()

            states = torch.stack(states)
            actions = torch.stack(actions)
            old_log_probs = torch.stack(old_log_probs)
            values = torch.stack(values).cpu().numpy()

            advantages = self.compute_gae(rewards, values.tolist(), dones, next_value)
            returns = advantages + torch.tensor(values, dtype=torch.float32).to(device)
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            dataset = torch.utils.data.TensorDataset(states, actions, old_log_probs, returns, advantages)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

            for _ in range(self.n_epochs):
                for batch in dataloader:
                    b_states, b_actions, b_old_log_probs, b_returns, b_advantages = batch

                    mean, logstd = self.actor_model(b_states)
                    std = torch.exp(logstd)
                    dist = Normal(mean, std)
                    new_log_probs = dist.log_prob(b_actions).sum(dim=1)
                    # Sum over action dimensions (as for new_log_probs), then average over the batch.
                    entropy = dist.entropy().sum(dim=1).mean()

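                    # Probability ratio pi_new / pi_old, computed from the log-probabilities.
                    ratio = torch.exp(new_log_probs - b_old_log_probs)
                    surr1 = ratio * b_advantages
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * b_advantages
                    # Negated clipped surrogate (we minimize), minus the entropy bonus.
                    actor_loss = -torch.min(surr1, surr2).mean() - 0.01 * entropy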

                    # squeeze(-1) keeps the batch dimension even when a batch holds one sample.
                    current_values = self.critic_model(b_states).squeeze(-1)
                    critic_loss = nn.MSELoss()(current_values, b_returns)

                    self.actor_optimizer.zero_grad()
                    actor_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), 0.5)
                    self.actor_optimizer.step()

                    self.critic_optimizer.zero_grad()
                    critic_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), 0.5)
                    self.critic_optimizer.step()

            if len(self.episode_rewards) > 0:
                avg_reward = np.mean(self.episode_rewards[-10:])
                wandb.log({
                    "update": update,
                    "avg_reward": avg_reward,
                    "actor_loss": actor_loss.item(),
                    "critic_loss": critic_loss.item(),
                    "entropy": entropy.item()
                })
                if update % 10 == 0:
                    print(f"Update {update}, Avg Reward: {avg_reward:.1f}")

def main():
    wandb.init(project="rl")
    env = gym.make("Pendulum-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    actor_model = ActorNetwork(state_dim, action_dim).to(device)
    critic_model = CriticNetwork(state_dim).to(device)

    actor_optimizer = optim.Adam(actor_model.parameters(), lr=3e-4)
    critic_optimizer = optim.Adam(critic_model.parameters(), lr=3e-4)

    trainer = PPOTrainer(
        env=env,
        actor_model=actor_model,
        critic_model=critic_model,
        actor_optimizer=actor_optimizer,
        critic_optimizer=critic_optimizer,
        gamma=0.99,
        gae_lambda=0.95,
        clip_param=0.2,
        n_epochs=4,
        batch_size=64
    )
    trainer.train(num_updates=500, n_steps=2048)

    torch.save(actor_model.state_dict(), "ppo_actor.pth")
    torch.save(critic_model.state_dict(), "ppo_critic.pth")
    test(actor_model)

def test(actor_model):
    env = gym.make("Pendulum-v1", render_mode="human")
    state, _ = env.reset()
    total_reward = 0
    while True:
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(device)
            mean, logstd = actor_model(state_tensor)
            action = torch.clamp(Normal(mean, torch.exp(logstd)).sample(), min=-2, max=2)
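        next_state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
        total_reward += reward
        state = next_state
        # Pendulum-v1 never terminates on its own; the episode ends by truncation at the time limit.
        if terminated or truncated:
            break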
    print(f"Test Reward: {total_reward:.1f}")
    env.close()

if __name__ == "__main__":
    main()