RL Ch3 Actor-Critic
Actor-Critic
Actor-Critic is an algorithm that combines two models: the Actor, a policy network, selects actions given a state, while the Critic, a value network, evaluates the Actor's actions by estimating a value function. This combination allows the algorithm to benefit from both policy-based and value-based methods.
Actor
The Actor represents the policy $\pi_\theta(a \mid s)$, which is usually parameterized by a neural network with parameters $\theta$. Given a state $s$, the network outputs a probability distribution over actions, or the parameters of a distribution (for example, the mean and standard deviation in the case of a Gaussian policy). The objective is to maximize the expected return:

$$J(\theta) = \mathbb{E}_{\pi_\theta}[R],$$

where $R = \sum_t \gamma^t r_t$ is the cumulative (discounted) reward. The policy parameters are updated using gradient ascent based on the policy gradient:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\, Q^{\pi_\theta}(s_t, a_t)\right],$$

with $Q^{\pi_\theta}(s_t, a_t)$ representing the action-value function.
This is the same update as in the plain policy gradient method, except that the value produced by the Critic network is used as the advantage (the weight on the log-probability) in place of the raw return.
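As a minimal sketch (not the full implementation shown later), the Actor's loss for a Gaussian policy can be written as follows; `mean` and `log_std` are assumed to come from the Actor network, and `advantage` is assumed to be supplied by the Critic:

import torch
from torch.distributions import Normal

# Sketch of the Actor loss for a Gaussian policy.
# `mean` and `log_std` are outputs of the Actor network; `advantage` is an
# estimate supplied by the Critic and treated as a constant here.
def actor_loss(mean, log_std, action, advantage):
    dist = Normal(mean, log_std.exp())
    log_prob = dist.log_prob(action).sum(dim=-1)  # sum over action dimensions
    # Gradient ascent on E[log pi(a|s) * A] == gradient descent on its negative.
    return -(log_prob * advantage.detach()).mean()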
Critic
The Critic estimates a value function, which can be either the state-value function $V_\phi(s)$ or the action-value function $Q_\phi(s, a)$. Its main role is to provide feedback to the Actor regarding the quality of actions taken.
When estimating $V_\phi$, the Critic minimizes the mean squared TD error,

$$L(\phi) = \left(r + \gamma V_\phi(s') - V_\phi(s)\right)^2,$$

where:
- $r$ is the immediate reward,
- $\gamma$ is the discount factor,
- $s'$ is the next state,

updating its parameters $\phi$ accordingly.
Of course, if you use $Q_\phi(s, a)$, then the Critic model is identical to Q-learning. That is, instead of estimating the state-value function $V_\phi(s)$, the Critic estimates the action-value function $Q_\phi(s, a)$ directly. In this case, the Critic is trained to minimize the loss,

$$L(\phi) = \left(r + \gamma \max_{a'} Q_\phi(s', a') - Q_\phi(s, a)\right)^2.$$
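As a rough sketch of the $V_\phi$ case (assuming a state-value network `critic` and tensors describing a single transition), the TD target and loss look like this in PyTorch:

import torch
import torch.nn.functional as F

# Sketch of the Critic's TD(0) loss for a state-value network `critic`.
# `state`, `reward`, `next_state`, and `done` are assumed to be tensors for one transition.
def critic_td_loss(critic, state, reward, next_state, done, gamma=0.99):
    value = critic(state)
    with torch.no_grad():  # the bootstrapped target is held fixed
        target = reward + gamma * (1.0 - done) * critic(next_state)
    return F.mse_loss(value, target)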
Combining Actor and Critic
In an Actor-Critic framework, both the Actor and the Critic are trained simultaneously and interact with each other during the learning process. Here's how the two models work together (a one-step update sketch follows the list):

- Action Selection:
  At time step $t$, given the current state $s_t$, the Actor selects an action $a_t \sim \pi_\theta(\cdot \mid s_t)$.

- Environment Interaction:
  The selected action $a_t$ is executed in the environment, which returns a reward $r_t$ and the next state $s_{t+1}$.

- Critic Evaluation:
  The Critic evaluates the quality of the action by estimating the value function, either the state-value $V_\phi(s_t)$ or the action-value $Q_\phi(s_t, a_t)$.

  - If using $V_\phi$, the Temporal Difference (TD) error (the one-step special case of generalized advantage estimation, GAE) is computed as:
    $$\delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t).$$

  - If using $Q_\phi$, a similar TD error (or Bellman error) is computed based on the Q-learning target:
    $$\delta_t = r_t + \gamma \max_{a'} Q_\phi(s_{t+1}, a') - Q_\phi(s_t, a_t).$$

- Critic Update:
  The Critic's parameters are updated to minimize the squared TD error:
  $$L(\phi) = \delta_t^2.$$
  The gradient descent update for the Critic is:
  $$\phi \leftarrow \phi - \alpha_\phi \nabla_\phi \delta_t^2,$$
  where $\alpha_\phi$ is the learning rate for the Critic.

- Actor Update:
  The Actor is updated using the policy gradient method, with the advantage (previously the raw return, now the TD error from the Critic) serving as the weight. The update rule for the Actor's parameters is:
  $$\theta \leftarrow \theta + \alpha_\theta\, \delta_t\, \nabla_\theta \log \pi_\theta(a_t \mid s_t),$$
  where $\alpha_\theta$ is the Actor's learning rate. This update increases the probability of actions that yield a positive advantage (i.e., better-than-expected outcomes) and decreases it for actions with a negative advantage.
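Here is a minimal one-step sketch of that loop in PyTorch. The names are illustrative assumptions: `critic`, `actor_opt`, and `critic_opt` are assumed to exist already, and `log_prob` must retain the graph from the Actor's forward pass when the action was sampled. The full episode-based implementation is given in the next section.

import torch

# One-step (online) actor-critic update, following the steps above.
# (state, reward, next_state, done) describe one environment transition;
# `log_prob` is the log-probability of the sampled action under the Actor.
def one_step_update(critic, actor_opt, critic_opt,
                    state, log_prob, reward, next_state, done, gamma=0.99):
    value = critic(state)
    with torch.no_grad():
        target = reward + gamma * (1.0 - done) * critic(next_state)
    td_error = target - value                    # delta_t, also used as the advantage

    critic_loss = td_error.pow(2).mean()         # minimize the squared TD error
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    actor_loss = -(log_prob * td_error.detach()).mean()  # policy-gradient step
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()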
Implementation
Using the actor-critic algorithm leads to faster convergence. However, the task is still too difficult to learn with gravity enabled, so the environment below is created with g=0.
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
import wandb
from tqdm import trange
device = "cpu"
print(f"Using device: {device}")
class ActorNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 128)
self.fc2 = nn.Linear(128, 128)
self.mean_layer = nn.Linear(128, action_dim)
self.logstd_layer = nn.Linear(128, action_dim)
self.apply(self.init_weights)
@staticmethod
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
nn.init.constant_(m.bias, 0.0)
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
mean = self.mean_layer(x)
logstd = self.logstd_layer(x)
logstd = torch.clamp(logstd, min=-20, max=2)
return mean, logstd
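# Critic network: maps a state to a scalar estimate of the state value V(s).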
class CriticNetwork(nn.Module):
def __init__(self, state_dim):
super(CriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 128)
self.fc2 = nn.Linear(128, 128)
self.value_layer = nn.Linear(128, 1)
self.apply(self.init_weights)
@staticmethod
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
nn.init.constant_(m.bias, 0.0)
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
value = self.value_layer(x)
return value
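# Trainer: collects one full episode per update, then fits the Critic to the
# Monte Carlo returns and updates the Actor with normalized advantages.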
class ActorCriticTrainer:
def __init__(self, env, actor_model, critic_model, actor_optimizer, critic_optimizer, gamma=0.99):
self.env = env
self.actor_model = actor_model
self.critic_model = critic_model
self.actor_optimizer = actor_optimizer
self.critic_optimizer = critic_optimizer
self.gamma = gamma
self.episode_rewards = []
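# Discounted returns G_t = r_t + gamma * G_{t+1}, computed backwards over the episode.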
def compute_returns(self, rewards):
discounted_rewards = []
running_return = 0
for r in reversed(rewards):
running_return = r + self.gamma * running_return
discounted_rewards.insert(0, running_return)
return torch.tensor(discounted_rewards, dtype=torch.float32).to(device)
def train(self, num_episodes=5000):
for episode in trange(num_episodes):
state, _ = self.env.reset()
state = torch.FloatTensor(state).to(device)
done = False
truncated = False
rewards = []
log_probs = []
values = []
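# Roll out one episode: sample actions from the Gaussian policy and record
# rewards, log-probabilities, and the Critic's value estimates.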
while not done and not truncated:
mean, logstd = self.actor_model(state)
std = torch.exp(logstd)
dist = Normal(mean, std)
action = dist.sample()
log_prob = dist.log_prob(action).sum()
action_clamped = torch.clamp(action, min=-2, max=2)
value = self.critic_model(state)
next_state, reward, done, truncated, _ = self.env.step(action_clamped.cpu().numpy())
rewards.append(reward)
log_probs.append(log_prob)
values.append(value.squeeze())
state = torch.FloatTensor(next_state).to(device)
if len(rewards) == 0:
continue
returns = self.compute_returns(rewards)
values = torch.stack(values)
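# Advantage = return minus the Critic's baseline; detached so the actor loss
# does not backpropagate into the Critic.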
advantages = returns - values.detach()
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
log_probs = torch.stack(log_probs)
actor_loss = -torch.mean(log_probs * advantages)
critic_loss = nn.MSELoss()(values, returns)
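# Update the Critic (regression to the returns), then the Actor (policy
# gradient weighted by the advantages), with gradient clipping on both.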
self.critic_optimizer.zero_grad()
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), 1.0)
self.critic_optimizer.step()
self.actor_optimizer.zero_grad()
actor_loss.backward()
torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), 1.0)
self.actor_optimizer.step()
total_reward = sum(rewards)
self.episode_rewards.append(total_reward)
wandb.log({
"episode": episode,
"reward": total_reward,
"actor_loss": actor_loss.item(),
"critic_loss": critic_loss.item(),
"mean_std": std.mean().item()
})
if episode % 100 == 0:
print(f"Episode {episode}, Reward: {total_reward:.1f}")
def main():
wandb.init(project="rl-fixed")
env = gym.make("Pendulum-v1", g=0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
actor_model = ActorNetwork(state_dim, action_dim).to(device)
critic_model = CriticNetwork(state_dim).to(device)
actor_optimizer = optim.Adam(actor_model.parameters(), lr=3e-4)
critic_optimizer = optim.Adam(critic_model.parameters(), lr=1e-3)
trainer = ActorCriticTrainer(env, actor_model, critic_model, actor_optimizer, critic_optimizer, gamma=0.99)
trainer.train()
torch.save(actor_model.state_dict(), "actor_fixed.pth")
torch.save(critic_model.state_dict(), "critic_fixed.pth")
test(actor_model)
def test(actor_model):
env = gym.make("Pendulum-v1", render_mode="human", g=0)
state, _ = env.reset()
total_reward = 0
while True:
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(device)
mean, logstd = actor_model(state_tensor)
action = torch.clamp(Normal(mean, torch.exp(logstd)).sample(), min=-2, max=2)
next_state, reward, done, truncated, _ = env.step(action.cpu().numpy())
total_reward += reward
state = next_state
if done or truncated:
break
print(f"Test Reward: {total_reward:.1f}")
env.close()
if __name__ == "__main__":
main()