Distributed Training in FairScale
Introduction
The previous part presented DeepSpeed, a tool for single-node, multi-GPU training and other optimizations. This part introduces distributed training, a more general and scalable way to train deep learning models.
We will cover three popular distributed training mechanisms: DataParallel (DP), DistributedDataParallel (DDP), and FullyShardedDataParallel (FSDP). Their basic ideas are similar, but they differ in implementation and optimizations. We will introduce them one by one.
Rank
Before diving into different distributed training frameworks, we need to introduce the concept of rank. In distributed training, each process is assigned a unique rank. The rank is used to identify the process and to communicate with other processes. The rank is usually an integer ranging from 0 to the number of processes minus 1.
In multi-node, multi-GPU training, the rank is usually described by a pair (node_rank, local_rank): node_rank is the rank of the node, and local_rank is the rank of the GPU within that node. The same convention applies to the previously introduced DeepSpeed.
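As a concrete illustration, here is a minimal sketch of how a process can query its global rank, local rank, and world size. It assumes the process group has already been initialized by a launcher such as torchrun, which also sets the LOCAL_RANK environment variable.
import os
import torch.distributed as dist

# Assumes torch.distributed is already initialized (e.g. the script was launched with torchrun).
global_rank = dist.get_rank()        # unique id across all processes, 0 .. world_size - 1
world_size = dist.get_world_size()   # total number of processes
# Launchers such as torchrun pass the GPU index on this node via the LOCAL_RANK environment variable.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
print(f"global rank {global_rank} of {world_size}, local rank {local_rank}")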
DataParallel (DP)
Mechanism
DataParallel (DP) is the simplest way to implement distributed training. It replicates the model on each GPU and splits each input batch into chunks, one per GPU. Each GPU computes the forward and backward pass on its own chunk. The gradients are then reduced onto the primary GPU, where the parameters are updated; the replicas are re-created from the updated weights on the next forward pass.
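Conceptually, a single DP forward pass does something like the following. This is a simplified sketch built from the functional helpers that nn.DataParallel itself uses; it assumes the module already lives on device_ids[0].
import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(module, batch, device_ids, output_device=0):
    # Split the batch along dim 0 and send one chunk to each GPU.
    inputs = scatter(batch, device_ids)
    # Copy the model onto every GPU (this replication happens on every forward pass).
    replicas = replicate(module, device_ids[:len(inputs)])
    # Run each replica on its own chunk, in parallel threads.
    outputs = parallel_apply(replicas, [(inp,) for inp in inputs])
    # Gather all outputs back onto the output device.
    return gather(outputs, output_device)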
Implementation
To use DP in PyTorch Lightning, you just need to set the Trainer's strategy to dp,
trainer = Trainer(gpus=2, strategy='dp')
In raw PyTorch, you can also simply wrap the model with torch.nn.DataParallel,
model = nn.DataParallel(model)
When doing so, do not move tensors to a specific device with to and an explicit device id; just call the cuda method,
input_tensor = input_tensor.cuda()
DataParallel will split each input tensor along its first (batch) dimension and scatter the chunks across the GPUs.
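Putting it together, a minimal DP training step might look like the sketch below; the model, optimizer, and random data are placeholders.
import torch
import torch.nn as nn

# Placeholder model; any nn.Module works.
model = nn.Linear(32, 10)
model = nn.DataParallel(model).cuda()   # replicas are created on every forward pass
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

inputs = torch.randn(64, 32).cuda()     # dim 0 (the batch) is split across GPUs
targets = torch.randint(0, 10, (64,)).cuda()

optimizer.zero_grad()
outputs = model(inputs)                 # outputs are gathered back onto GPU 0
loss = loss_fn(outputs, targets)
loss.backward()                         # gradients are reduced onto the original module
optimizer.step()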
DistributedDataParallel (DDP)
Mechanism
DistributedDataParallel (DDP) is a more advanced distributed training framework. It follows the same data-parallel idea as DP but with more optimizations: it launches one process per GPU and uses the torch.distributed
package to communicate between processes, averaging gradients with an all-reduce that overlaps with the backward pass. Because no single GPU has to gather all outputs and hold the master copy of the model, memory usage is also more balanced across GPUs than with DP.
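To make the gradient averaging concrete, the sketch below shows roughly what DDP does for you after the backward pass. Real DDP buckets gradients and overlaps the communication with backward; this is only a simplified illustration.
import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module):
    # Sum each parameter's gradient across all processes, then divide by the world size.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size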
Implementation
In PyTorch Lightning, similar to DP, you just need to set the Trainer's strategy to ddp,
trainer = Trainer(gpus=2, strategy='ddp')
When using raw PyTorch, to be honest, there is more boilerplate: you need to initialize the process group, figure out the rank, move the model to its GPU, and wrap it with torch.nn.parallel.DistributedDataParallel. Here is an example,
import torch
import torch.nn as nn

# Initialize the process group (reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Get this process's rank and move the model to its GPU (assuming one GPU per process on a single node)
rank = torch.distributed.get_rank()
model = model.to(rank)

# Wrap the model; gradients are all-reduced across processes during backward
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
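Each process also needs to see a different shard of the data; with DDP this is typically handled by a DistributedSampler. A sketch, assuming a dataset object and num_epochs are defined elsewhere:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Each process gets a different, non-overlapping subset of the dataset.
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)   # reshuffle differently every epoch
    for batch in dataloader:
        ...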
Using DDP directly in raw PyTorch is not really recommended: it is easy to introduce subtle bugs, and they are hard to debug. In practice it is usually easier to go through PyTorch Lightning, Accelerate, Fabric, FairScale, or another high-level library.
FullyShardedDataParallel (FSDP)
Mechanism
FullyShardedDataParallel (FSDP) is a more advanced distributed training framework. It builds on the idea of DDP but adds further optimizations: instead of keeping a full copy of the model on every GPU, it shards the model parameters, gradients, and optimizer states across GPUs, gathering a layer's full parameters only when they are needed for its forward or backward pass. This makes it possible to train models that do not fit in the memory of a single GPU, and it works across multiple nodes as well.
Implementation
In PyTorch Lightning, you just need to set the Trainer's strategy to fsdp,
trainer = Trainer(gpus=2, strategy='fsdp')
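If you prefer raw PyTorch, FairScale exposes the same idea directly. A minimal sketch, assuming the process group is already initialized and model and rank are defined as in the DDP example above:
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized and `model` / `rank` exist.
model = FSDP(model.to(rank))
# Create the optimizer AFTER wrapping, so it only sees this rank's shard of the parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# The training loop itself is unchanged: forward, backward, step.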
Using FairScale for Distributed Training
FairScale is a high-level library that provides many useful tools for distributed training. It supports sharded DDP, FSDP, and other distributed training schemes, and it also ships useful utilities such as activation (gradient) checkpointing and sharded mixed-precision gradient scaling. It is similar to DeepSpeed but covers more distributed training schemes.
The reason we did not show how to use DeepSpeed directly earlier is that, frankly, I find its API rather clunky; FairScale is more elegant and easier to use.
We have in fact already used FairScale through PyTorch Lightning: Lightning's sharded DDP and FSDP integrations are built on top of FairScale.
To use FairScale directly with PyTorch, you wrap a base optimizer in OSS (optimizer state sharding, FairScale's implementation of ZeRO stage 1) and the model in ShardedDataParallel, which reduces each gradient only to the rank that owns the corresponding optimizer state,
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

def train(
        rank: int,
        world_size: int,
        epochs: int):

    # process group init (a placeholder helper; a minimal implementation is sketched after the spawn call below)
    dist_init(rank, world_size)

    # Problem statement: your model, dataloader and loss (placeholders here)
    model = MyModel().to(rank)
    dataloader = my_dataloader()
    loss_fn = my_loss()

    # optimizer specific arguments e.g. LR, momentum, etc...
    base_optimizer_arguments = {"lr": 1e-4}

    # Wrap a base optimizer into OSS, which shards the optimizer state across ranks
    base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)

    # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
    model = ShardedDDP(model, optimizer)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)
            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
Then use a multiprocessing library to spawn one training process per GPU, for example torch.multiprocessing:
import torch.multiprocessing as mp

WORLD_SIZE = 2   # example value: number of GPUs / processes
EPOCHS = 10      # example value

# mp.spawn passes the process index (used as the rank) as the first argument to train()
mp.spawn(
    train,
    args=(WORLD_SIZE, EPOCHS),
    nprocs=WORLD_SIZE,
    join=True
)
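The dist_init helper used above is not part of FairScale; it is just a placeholder for standard process-group setup. A minimal version might look like this (the MASTER_ADDR and MASTER_PORT values are single-node examples):
import os
import torch.distributed as dist

def dist_init(rank: int, world_size: int):
    # Tell every process where rank 0 is listening (single-node example values).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # NCCL is the usual backend for multi-GPU training; use "gloo" on CPU.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)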
You can also activate other FairScale features, such as activation (gradient) checkpointing and sharded mixed-precision training (the OSS/ShardedDDP combination above is FairScale's take on ZeRO). For example,
import torch
import torch.nn as nn
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

class CheckpointModel(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        self.ffn_module = nn.Sequential(
            nn.Linear(32, 128),
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        # Wrap the block so its activations are recomputed during backward instead of being stored
        self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs)
        self.last_linear = nn.Linear(32, 1)

    def forward(self, input):
        output = self.ffn_module(input)
        return self.last_linear(output)
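The wrapped model is used like any other module; it trades a second forward computation of the wrapped block for lower activation memory. A small usage sketch, assuming a CUDA device is available:
model = CheckpointModel(offload_to_cpu=False).cuda()
x = torch.rand(8, 32).cuda()
loss = model(x).sum()
loss.backward()   # the ffn_module forward is recomputed here instead of being stored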
from torch import optim
from torch.cuda.amp import autocast
from fairscale.optim.grad_scaler import ShardedGradScaler

# Creates model and optimizer in default precision
model = CheckpointModel().cuda()
optimizer = optim.SGD(model.parameters(), ...)  # fill in your hyperparameters

# Creates a ShardedGradScaler once at the beginning of training.
scaler = ShardedGradScaler()

for epoch in range(epochs):
    for input, target in data:   # `data` is your dataloader
        optimizer.zero_grad()

        # Runs the forward pass under autocast so eligible ops use float16
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss, backpropagates, then unscales, steps, and updates the scale factor
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
Refer to the official documentation for more information.
Conclusion
In this part, we introduced distributed training with FairScale and the concepts behind DP, DDP, and FSDP.