RL Ch3 Actor-Critic

Actor-Critic

Actor-Critic is an algorithm that trains two models together. The Actor, the policy network, is responsible for selecting actions given a state, while the Critic, the value network, evaluates the actions taken by the Actor by estimating the value function. This combination allows the algorithm to benefit from both policy-based and value-based methods.

Actor

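The Actor represents the policy $\pi_\theta(a \mid s)$, which is usually parameterized by a neural network with parameters $\theta$. Given a state $s$, the network outputs a probability distribution over actions, or the parameters of a distribution (for example, the mean and standard deviation in the case of a Gaussian policy). The objective is to maximize the expected return:

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[ R(\tau) \right]$$

where $R(\tau)$ is the cumulative reward of a trajectory $\tau$. The policy parameters are updated using gradient ascent based on the policy gradient:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[ \nabla_\theta \log \pi_\theta(a \mid s) \, Q^{\pi}(s, a) \right]$$

with $Q^{\pi}(s, a)$ representing the action-value function.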

This is the same update as in the policy gradient method, except that the weight on each log-probability gradient now comes from the Critic network's value estimate, which plays the role of the advantage function.
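As a minimal sketch of the policy gradient update described above, the snippet below builds the Actor loss for a Gaussian policy in PyTorch. The tensors `mean`, `log_std`, and `weights` are made-up stand-ins (assumptions for illustration) for the Actor's outputs and the Critic's signal; they are not taken from the implementation later in this chapter.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical batch: 4 states, 1-dimensional continuous action.
mean = torch.zeros(4, 1, requires_grad=True)     # stand-in for the Actor's mean output
log_std = torch.zeros(4, 1, requires_grad=True)  # stand-in for the Actor's log-std output

dist = Normal(mean, log_std.exp())
actions = dist.sample()                          # a ~ pi_theta(. | s); sample() carries no gradient
log_probs = dist.log_prob(actions).sum(dim=-1)   # log pi_theta(a | s), summed over action dims

# Stand-in for the Critic's signal (Q-value or advantage), treated as a constant.
weights = torch.tensor([1.0, -0.5, 2.0, 0.0])

# Minimizing this loss performs gradient ascent on E[log pi_theta(a | s) * weight].
actor_loss = -(log_probs * weights.detach()).mean()
actor_loss.backward()
```

Samples with a weight of zero contribute no gradient, positive weights push the policy toward the sampled action, and negative weights push it away.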

Critic

The Critic estimates the value function, which can be either the state-value function $V(s)$ or the action-value function $Q(s, a)$. Its main role is to provide feedback to the Actor regarding the quality of actions taken.

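The Critic minimizes the mean squared error between its estimate and the one-step target,

$$L(\phi) = \mathbb{E}\left[ \left( r + \gamma V_\phi(s') - V_\phi(s) \right)^2 \right]$$

where:

  • $r$ is the immediate reward,
  • $\gamma$ is the discount factor,
  • $s'$ is the next state,

updating its parameters $\phi$ accordingly.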
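The Critic loss above can be sketched in PyTorch as follows. This is an illustration under stated assumptions: the tiny network, the random batch, and the helper names `td_target` and `critic_loss` are hypothetical, and the `dones` mask (no bootstrapping after a terminal step) is a common convention added here rather than something the text specifies.

```python
import torch
import torch.nn as nn

def td_target(rewards, next_values, dones, gamma):
    """One-step target r + gamma * V(s'), with no bootstrap after a terminal step."""
    return rewards + gamma * (1.0 - dones) * next_values

def critic_loss(critic, states, rewards, next_states, dones, gamma=0.99):
    values = critic(states).squeeze(-1)
    with torch.no_grad():  # the target is treated as a constant
        next_values = critic(next_states).squeeze(-1)
        targets = td_target(rewards, next_values, dones, gamma)
    return nn.functional.mse_loss(values, targets)

torch.manual_seed(0)
# Hypothetical critic: 3-dimensional state -> scalar value.
critic = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.Adam(critic.parameters(), lr=1e-3)

# Made-up batch of 8 transitions (s, r, s', done).
states, next_states = torch.randn(8, 3), torch.randn(8, 3)
rewards = torch.randn(8)
dones = torch.zeros(8)
dones[-1] = 1.0  # the last transition ends its episode

# One gradient descent step on the squared TD error.
loss = critic_loss(critic, states, rewards, next_states, dones)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Computing the target under `torch.no_grad()` means only $V_\phi(s)$ receives a gradient, which is the usual semi-gradient form of this update.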

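Of course, if you use $Q(s, a)$, then the Critic is trained just as in Q-learning. That is, instead of estimating the state-value function $V(s)$, the Critic estimates the action-value function $Q(s, a)$ directly. In this case, the Critic is trained to minimize the loss

$$L(\phi) = \mathbb{E}\left[ \left( r + \gamma \max_{a'} Q_\phi(s', a') - Q_\phi(s, a) \right)^2 \right]$$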

Combining Actor and Critic

In an Actor-Critic framework, both the Actor and the Critic are trained simultaneously and interact with each other during the learning process. Here’s how the two models work together:

  1. Action Selection:
    At time step $t$, given the current state $s_t$, the Actor selects an action $a_t$ according to its policy $\pi_\theta(a_t \mid s_t)$.

  2. Environment Interaction:
    The selected action $a_t$ is executed in the environment, which then returns a reward $r_t$ and the next state $s_{t+1}$.

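  3. Critic Evaluation:
    The Critic evaluates the quality of the action by estimating the value function. This can be either the state-value $V_\phi(s_t)$ or the action-value $Q_\phi(s_t, a_t)$.

    • If using $V$, the Temporal Difference (TD) error is computed as:

      $$\delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)$$

      This is a one-step estimate of the advantage. Generalized advantage estimation (GAE) extends it to an exponentially weighted sum of TD errors over later time steps.

    • If using $Q$, a similar TD error (or Bellman error) is computed based on the Q-learning target:

      $$\delta_t = r_t + \gamma \max_{a'} Q_\phi(s_{t+1}, a') - Q_\phi(s_t, a_t)$$

  4. Critic Update:
    The Critic’s parameters are updated to minimize the squared TD error:

    $$L(\phi) = \delta_t^2$$

    The gradient descent update for the Critic is:

    $$\phi \leftarrow \phi - \alpha_\phi \nabla_\phi L(\phi)$$

    where $\alpha_\phi$ is the learning rate for the Critic.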

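  5. Actor Update:
    The Actor is updated using the policy gradient method, with the advantage serving as the weight (previously this weight was the reward; now it comes from the Critic's value estimate). The update rule for the Actor’s parameters is:

    $$\theta \leftarrow \theta + \alpha_\theta \, \delta_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)$$

    where $\delta_t$ is the TD error from step 3, used as the advantage estimate, and $\alpha_\theta$ is the Actor’s learning rate. This update increases the probability of actions that yield a positive advantage (i.e., better-than-expected outcomes) and decreases it for actions with a negative advantage.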
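The five steps above can be sketched as a single one-step update in PyTorch. This is a simplified illustration, not the implementation used in this chapter: the linear `actor` and `critic`, the fixed standard deviation of 1, the SGD optimizers, and the made-up transition are all assumptions chosen to keep the example short.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

torch.manual_seed(0)
state_dim, action_dim, gamma = 3, 1, 0.99

# Hypothetical minimal networks; the Actor outputs only the mean (std fixed to 1).
actor = nn.Linear(state_dim, action_dim)
critic = nn.Linear(state_dim, 1)
actor_opt = torch.optim.SGD(actor.parameters(), lr=1e-2)    # learning rate alpha_theta
critic_opt = torch.optim.SGD(critic.parameters(), lr=1e-2)  # learning rate alpha_phi

def select_action(state):
    # Step 1: sample a_t ~ pi_theta(. | s_t) and keep its log-probability.
    dist = Normal(actor(state), torch.ones(action_dim))
    action = dist.sample()
    return action, dist.log_prob(action).sum()

def update(state, log_prob, reward, next_state, done):
    # Step 3: TD error delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
    value = critic(state).squeeze(-1)
    with torch.no_grad():
        next_value = critic(next_state).squeeze(-1)
        target = reward + gamma * (1.0 - float(done)) * next_value
    delta = target - value

    # Step 4: the Critic descends the squared TD error.
    critic_loss = delta.pow(2)
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # Step 5: the Actor ascends delta_t * grad log pi_theta(a_t | s_t).
    actor_loss = -log_prob * delta.detach()
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()
    return delta.item()

state = torch.randn(state_dim)
action, log_prob = select_action(state)
# Step 2: made-up environment response standing in for env.step(action).
next_state, reward = torch.randn(state_dim), 1.0
delta = update(state, log_prob, reward, next_state, done=False)
```

Detaching `delta` in the Actor loss keeps the policy update from backpropagating into the Critic, so each network is trained only by its own loss.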

Implementation

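Using the actor-critic algorithm leads to faster convergence than the plain policy gradient method. However, the Pendulum-v1 task is still too difficult with gravity enabled, so the code below sets g=0. Note that this implementation trains the Critic on full-episode discounted returns rather than on the one-step TD target described above.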
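The training loop below relies on discounted returns, $G_t = r_t + \gamma G_{t+1}$, computed backward over an episode. As a small worked example of that recursion, here is a standalone sketch; the function name `discounted_returns` and the toy rewards are illustrative assumptions.

```python
def discounted_returns(rewards, gamma):
    """Return G_t = r_t + gamma * G_{t+1} for every step of one episode."""
    returns = []
    running = 0.0
    # Walk backward so each step can reuse the return of the step after it.
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns

print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```

The last step's return is just its reward (1.0), the middle step adds half of that (1.5), and the first step adds half of the middle return (1.75).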

import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
import wandb
from tqdm import trange

device = "cpu"
print(f"Using device: {device}")

class ActorNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.mean_layer = nn.Linear(128, action_dim)
        self.logstd_layer = nn.Linear(128, action_dim)
        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            nn.init.constant_(m.bias, 0.0)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean_layer(x)
        logstd = self.logstd_layer(x)
        logstd = torch.clamp(logstd, min=-20, max=2)
        return mean, logstd

class CriticNetwork(nn.Module):
    def __init__(self, state_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.value_layer = nn.Linear(128, 1)
        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            nn.init.constant_(m.bias, 0.0)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.value_layer(x)
        return value

class ActorCriticTrainer:
    def __init__(self, env, actor_model, critic_model, actor_optimizer, critic_optimizer, gamma=0.99):
        self.env = env
        self.actor_model = actor_model
        self.critic_model = critic_model
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.gamma = gamma
        self.episode_rewards = []

    def compute_returns(self, rewards):
        discounted_rewards = []
        running_return = 0
        for r in reversed(rewards):
            running_return = r + self.gamma * running_return
            discounted_rewards.insert(0, running_return)
        return torch.tensor(discounted_rewards, dtype=torch.float32).to(device)

    def train(self, num_episodes=5000):
        for episode in trange(num_episodes):
            state, _ = self.env.reset()
            state = torch.FloatTensor(state).to(device)
            done = False
            truncated = False
            rewards = []
            log_probs = []
            values = []

            while not done and not truncated:
                mean, logstd = self.actor_model(state)
                std = torch.exp(logstd)
                dist = Normal(mean, std)
                action = dist.sample()
                log_prob = dist.log_prob(action).sum()
                action_clamped = torch.clamp(action, min=-2, max=2)
                value = self.critic_model(state)

                next_state, reward, done, truncated, _ = self.env.step(action_clamped.cpu().numpy())

                rewards.append(reward)
                log_probs.append(log_prob)
                values.append(value.squeeze())
                state = torch.FloatTensor(next_state).to(device)

            if len(rewards) == 0:
                continue

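            # Monte Carlo discounted returns serve as the Critic's regression target.
            returns = self.compute_returns(rewards)
            values = torch.stack(values)

            # Advantage = return minus the Critic's baseline; detach so the
            # actor loss does not backpropagate into the Critic.
            advantages = returns - values.detach()
            # Normalize advantages within the episode to stabilize the update.
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            log_probs = torch.stack(log_probs)

            # Policy gradient loss: minimize -E[log pi(a|s) * advantage].
            actor_loss = -torch.mean(log_probs * advantages)

            # Critic loss: MSE between predicted values and observed returns.
            critic_loss = nn.MSELoss()(values, returns)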

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), 1.0)
            self.critic_optimizer.step()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), 1.0)
            self.actor_optimizer.step()

            total_reward = sum(rewards)
            self.episode_rewards.append(total_reward)
            wandb.log({
                "episode": episode,
                "reward": total_reward,
                "actor_loss": actor_loss.item(),
                "critic_loss": critic_loss.item(),
                "mean_std": std.mean().item()
            })

            if episode % 100 == 0:
                print(f"Episode {episode}, Reward: {total_reward:.1f}")

def main():
    wandb.init(project="rl-fixed")
    env = gym.make("Pendulum-v1", g=0)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    actor_model = ActorNetwork(state_dim, action_dim).to(device)
    critic_model = CriticNetwork(state_dim).to(device)

    actor_optimizer = optim.Adam(actor_model.parameters(), lr=3e-4)
    critic_optimizer = optim.Adam(critic_model.parameters(), lr=1e-3)

    trainer = ActorCriticTrainer(env, actor_model, critic_model, actor_optimizer, critic_optimizer, gamma=0.99)
    trainer.train()

    torch.save(actor_model.state_dict(), "actor_fixed.pth")
    torch.save(critic_model.state_dict(), "critic_fixed.pth")
    test(actor_model)

def test(actor_model):
    env = gym.make("Pendulum-v1", render_mode="human", g=0)
    state, _ = env.reset()
    total_reward = 0
    terminated = False
    truncated = False
    # Pendulum-v1 never terminates on its own; the episode ends only when the
    # time limit truncates it, so both flags must be checked to avoid an
    # infinite loop.
    while not (terminated or truncated):
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(device)
            mean, logstd = actor_model(state_tensor)
            action = torch.clamp(Normal(mean, torch.exp(logstd)).sample(), min=-2, max=2)
        next_state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
        total_reward += reward
        state = next_state
    print(f"Test Reward: {total_reward:.1f}")
    env.close()

if __name__ == "__main__":
    main()