DeepSpeed for ZeRO Optimization

Introduction

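ZeRO (Zero Redundancy Optimizer) is a memory optimization technique for data-parallel training. Instead of keeping a full copy of every training state on every GPU, it partitions those states across the GPUs, which makes it possible to train models with billions of parameters that no single card could hold. It comes in three cumulative stages, plus an optional offload mechanism.

Stage 1 partitions the optimizer states across the GPUs.

Stage 2 additionally partitions the gradients across the GPUs.

Stage 3 additionally partitions the model parameters across the GPUs.

ZeRO-Offload is a technique that offloads optimizer states, and the optimizer computation, to the host CPU.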

DeepSpeed is a library that implements ZeRO for PyTorch. It provides a simple API to enable ZeRO in your PyTorch model, and it also supports other optimizations such as mixed precision training and gradient accumulation.

Construct Another Model

The previous model was not very complex, and it trained quickly even without any optimization. Let's construct a more complex model to see the effect of DeepSpeed.

The task here is IMDB sentiment analysis. The input is a movie review and the output is the sentiment of the review, positive or negative.

Boot Up Logger

We will use Weights and Biases for logging.

import wandb
from pytorch_lightning.loggers import WandbLogger

run = wandb.init(
    project="demo",
    name="imdb-demo",
    tags=["demo"],
    config={
        "batch_size": 32,
        "lr": 2e-4,
        "max_epochs": 5,
        "dropout": 0.2,
        "gemma": 0.6,
        "lr_step": 1,
        "embed_dim": 256,
        "num_heads": 8,
        "num_layers": 6,
        "max_len": 1024,
        "hidden_dim": 1024,
        "num_workers": 8,
    }
)
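# The first positional argument of WandbLogger is the run name,
# so an existing run has to be passed through `experiment`.
logger = WandbLogger(experiment=run)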

Data Preparation

We will use the Hugging Face NLP toolchain. First, initialize the tokenizer; we use FacebookAI/roberta-large.

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")

We define the dataset in a separate file, imdb_dataset.py; otherwise the worker processes of the multi-process data loader may fail to load it.

import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

class ImdbDataset(Dataset):

    def __init__(self, hf_dataset, tokenizer: PreTrainedTokenizer):
        super().__init__()
        self.data = hf_dataset
        self.tokenizer = tokenizer

    def __getitem__(self, index: int):
        item = self.data[index]
        # The raw reviews contain HTML line breaks written as "<br />".
        text = item["text"].replace("<br />", " ")
        # Build integer tensors directly instead of converting from float.
        token_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
        label = torch.tensor([item["label"]], dtype=torch.long)
        return token_ids, label

    def __len__(self) -> int:
        return len(self.data)

Then define the Lightning data module, in the same file as the dataset.

import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader

class ImdbDataModule(pl.LightningDataModule):

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int,
        num_workers: int,
        max_len: int = 1024,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_len = max_len

    def prepare_data(self):
        # Download once; setup() then reads from the local cache.
        load_dataset("stanfordnlp/imdb", cache_dir="./data/")

    def setup(self, stage: str) -> None:
        dataset = load_dataset("stanfordnlp/imdb", cache_dir="./data/")
        # IMDB ships no validation split, so we halve the test split.
        # The fixed seed (an arbitrary choice) keeps the split identical in every process.
        test = dataset["test"].train_test_split(test_size=0.5, seed=42)
        dataset["test"] = test["train"]
        dataset["validation"] = test["test"]

        if stage == "fit" or stage is None:
            self.train_data = ImdbDataset(dataset["train"], self.tokenizer)
            self.val_data = ImdbDataset(dataset["validation"], self.tokenizer)
        else:
            self.test_data = ImdbDataset(dataset["test"], self.tokenizer)

    def collate_fn(self, batch):
        x, y = zip(*batch)
        # Truncate to max_len so the positional encoding is never exceeded.
        x = [seq[: self.max_len] for seq in x]
        # Pad with the tokenizer's pad id so that the padding mask built in
        # the model (x == pad_token_id) marks exactly the padded positions.
        x = torch.nn.utils.rnn.pad_sequence(
            x, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        return x, torch.cat(y)

    def test_dataloader(self):
        return DataLoader(
            self.test_data,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            shuffle=True,
        )
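To make the collate step concrete, here is a plain-Python sketch of what padding a batch and deriving the padding mask amount to. `pad_batch` is a hypothetical helper written only for illustration, the token ids are made up, and the pad id 1 assumes RoBERTa's vocabulary. The point is that the model steps build their mask by comparing against `tokenizer.pad_token_id`, so the padding value has to be that same id.

```python
from typing import List, Tuple

def pad_batch(
    seqs: List[List[int]], pad_id: int, max_len: int
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Truncate to max_len, right-pad to the longest sequence, and build the mask."""
    clipped = [s[:max_len] for s in seqs]
    width = max(len(s) for s in clipped)
    ids = [s + [pad_id] * (width - len(s)) for s in clipped]
    # True marks a padded position, mirroring `x == pad_token_id` in the model.
    mask = [[tok == pad_id for tok in row] for row in ids]
    return ids, mask

# Two made-up token sequences; the second one is longer than max_len.
ids, mask = pad_batch([[0, 713, 2], [0, 713, 822, 16, 2]], pad_id=1, max_len=4)
print(ids)
print(mask)
```

If the batch were padded with a different value than the one the mask checks for, the padded positions would be treated as real tokens by the attention layers.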

Model Setup

This is a simple encoder-only transformer model for classification.

import torch
import wandb
from torch import nn
import pytorch_lightning as pl

class ImdbModel(pl.LightningModule):

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        embed_dim = 512,
        hidden_dim = 2048,
        max_len = 1024,
        num_heads = 8,
        num_layers = 8,
        dropout = 0.1,
    ):
        super(ImdbModel, self).__init__()
        self.d_model = embed_dim
        self.tokenizer = tokenizer
        self.embed = nn.Sequential(
            nn.Embedding(tokenizer.vocab_size, embed_dim),
            PositionalEncoding(embed_dim, dropout, max_len),
        )
        encoder_norm = nn.LayerNorm(embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, src_key_padding_mask=None):
        x = self.embed(x)
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        x = x.mean(dim=1)
        x = self.classifier(x)
        return x

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=wandb.config.lr)
        sched = torch.optim.lr_scheduler.StepLR(optim, step_size=wandb.config.lr_step, gamma=wandb.config.gemma)
        return [optim], [sched]

    def training_step(self, batch, batch_idx):
        x, y = batch
        padding_mask = (x == self.tokenizer.pad_token_id)
        y_hat = self(x, padding_mask)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        padding_mask = (x == self.tokenizer.pad_token_id)
        y_hat = self(x, padding_mask)
        loss = self.loss_fn(y_hat, y)
        self.log("val_loss", loss)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        padding_mask = (x == self.tokenizer.pad_token_id)
        y_hat = self(x, padding_mask)
        loss = self.loss_fn(y_hat, y)
        self.log("test_loss", loss)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("test_acc", acc)
        return loss
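As a quick check of what the scheduler in `configure_optimizers` does with the logged config, the sketch below evaluates the closed form of a step decay, `lr * gamma ** (epoch // step_size)`. The function `step_lr` is a hypothetical stand-in written for this illustration, not the PyTorch class, and it assumes the scheduler is stepped once per epoch, which is Lightning's default.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Learning rate after `epoch` epochs of step decay."""
    return base_lr * gamma ** (epoch // step_size)

# Values from the wandb config above: lr=2e-4, decay factor 0.6, lr_step=1, 5 epochs.
schedule = [step_lr(2e-4, 0.6, 1, epoch) for epoch in range(5)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.3e}")
```

With these settings the learning rate in the last epoch is roughly 13% of the initial one, so most of the learning happens in the first few epochs.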

Speed Monitor

Define a callback that measures the time spent on each training batch.

from typing import Any
from time import time

import pytorch_lightning as pl
import wandb
from pytorch_lightning import LightningModule, Trainer

class SpeedCounterCallback(pl.Callback):

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        self.batch_start_time = time()

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None:
        self.batch_end_time = time()
        wandb.log({"second_per_batch": self.batch_end_time - self.batch_start_time})

Train the Model

Since we are using a multi-process data loader, we need to wrap the training code in Python's usual main guard.

if __name__ == "__main__":
    trainer = pl.Trainer(
        max_epochs=wandb.config.max_epochs,
        logger=logger,
        callbacks=[
            SpeedCounterCallback(),
        ],
    )
    model = ImdbModel(
        tokenizer,
        run.config.embed_dim,
        run.config.hidden_dim,
        run.config.max_len,
        run.config.num_heads,
        run.config.num_layers,
        run.config.dropout,
    )
    from imdb_dataset import ImdbDataModule
    data_module = ImdbDataModule(
        tokenizer,
        run.config.batch_size,
        run.config.num_workers,
        run.config.max_len
    )
    trainer.fit(model, data_module)

Since DeepSpeed only supports CUDA, the model has to be trained on a GPU. On a single RTX 3090, training takes roughly three minutes per epoch, or 0.2 seconds per batch, which is far slower than the previous model because of the larger size.
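The two timings can be cross-checked with a little arithmetic. The sketch below assumes the 25,000 training reviews of the IMDB dataset and the batch size of 32 from the config; `epoch_seconds` is a hypothetical helper, and it counts training batches only, so validation time comes on top.

```python
import math
from typing import Tuple

def epoch_seconds(
    num_samples: int, batch_size: int, seconds_per_batch: float
) -> Tuple[int, float]:
    """Number of batches per epoch and the time they take in seconds."""
    batches = math.ceil(num_samples / batch_size)
    return batches, batches * seconds_per_batch

batches, seconds = epoch_seconds(25_000, 32, 0.2)
print(batches, "batches,", round(seconds / 60, 1), "minutes")
```

That is 782 batches and about 2.6 minutes of pure training time, which is consistent with the rough three minutes per epoch once validation is included.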

ZeRO, when used at a higher stage, lowers the training speed while decreasing the memory usage. For later comparison, note that without any optimization the GPU memory used is roughly 7.4 GB.

The model reaches an accuracy of about 0.7 (and can reach nearly 0.9 if more epochs are allowed), which is not bad.

Use Deepspeed with PyTorch Lightning

Using DeepSpeed in PyTorch Lightning is still an experimental feature, as is every method we are going to introduce. This essay focuses on using these methods through PyTorch Lightning because the original libraries are sometimes absurdly complex to use and not user friendly. DeepSpeed is an example: to work in a multi-GPU environment with it directly, you have to use a special launcher to start the training.

To use DeepSpeed in PyTorch Lightning, you need to install the deepspeed library alongside the torch library.

One thing to note, however, is that DeepSpeed cannot be used in a Jupyter notebook environment, so use the following command to convert the notebook to a Python script.

jupyter nbconvert --to script main.ipynb

The only change you need to make with PyTorch Lightning is setting the strategy argument of the trainer to deepspeed_stage_1; then run the script.
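A minimal sketch of that change follows. `deepspeed_strategy_name` is a hypothetical helper that only assembles the strategy string; the strings it returns are, to our knowledge, the aliases Lightning registers (`deepspeed_stage_1` through `deepspeed_stage_3`, with `_offload` variants for stages 2 and 3). The trainer call itself is left as a comment because it needs deepspeed and a CUDA device.

```python
def deepspeed_strategy_name(stage: int, offload: bool = False) -> str:
    """Build the Lightning strategy alias for a given ZeRO stage."""
    if stage not in (1, 2, 3):
        raise ValueError("ZeRO stage must be 1, 2 or 3")
    if offload and stage == 1:
        # Assumption: Lightning ships no built-in offload alias for stage 1.
        raise ValueError("no built-in offload alias for stage 1")
    name = f"deepspeed_stage_{stage}"
    return name + "_offload" if offload else name

trainer_kwargs = {
    "accelerator": "gpu",
    "devices": 1,
    "strategy": deepspeed_strategy_name(1),
}
print(trainer_kwargs)

# In the training script, these would be merged into the existing call:
# trainer = pl.Trainer(logger=logger, callbacks=[SpeedCounterCallback()], **trainer_kwargs)
```

Keeping the stage in one place makes it easy to rerun the same script at stage 2 or 3 and compare speed and memory.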

You may find that the training is actually a tiny bit slower. Indeed, DeepSpeed is optimized for multi-GPU training, and on a single GPU the extra overhead is not worth it.

With two cards, the trainer becomes:

trainer = pl.Trainer(
    max_epochs=wandb.config.max_epochs,
    logger=logger,
    callbacks=[
        SpeedCounterCallback(),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
    ],
    accelerator="gpu",
    devices=2,
    strategy="deepspeed_stage_1"
)

The speed almost doubled without any harm to the performance, although the total memory usage also doubled because each card holds a copy of the model. If you set the stage higher, you may notice that the speed goes down, but the memory usage goes down as well.

For reference, it is now about one and a half minutes per epoch, with 7.3 GB used on the main card and 6.1 GB on the other card.

So if you have multiple cards, DeepSpeed can put them to good use. In addition, for training very large models such as modern LLMs, DeepSpeed is a must: it allows everything to be split among the cards, so models with billions of parameters can be trained without needing one super card with tons of memory.

DeepSpeed also supports many other techniques, but passing only the strategy string is enough for most cases. If you need more advanced techniques, you can pass a config dictionary.

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

deepspeed_config = {
    "zero_allow_untested_optimizer": True,
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 3e-5,
            "betas": [0.998, 0.999],
            "eps": 1e-5,
            "weight_decay": 1e-9,
            "cuda_aware": True,
        },
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "last_batch_iteration": -1,
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 100,
        },
    },
    "zero_optimization": {
        "stage": 2,  # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
        "offload_optimizer": {"device": "cpu"},  # Enable Offloading optimizer state/calculation to the host CPU
        "contiguous_gradients": True,  # Reduce gradient fragmentation.
        "overlap_comm": True,  # Overlap reduce/backward operation of gradients for speed.
        "allgather_bucket_size": 2e8,  # Number of elements to all gather at once.
        "reduce_bucket_size": 2e8,  # Number of elements we reduce/allreduce at once.
    },
}

trainer = Trainer(accelerator="gpu", devices=4, strategy=DeepSpeedStrategy(config=deepspeed_config), precision=16)

For the valid fields and their respective functions, refer to the DeepSpeed documentation.

Now we will introduce perhaps the most important concept in DeepSpeed: ZeRO, the Zero Redundancy Optimizer.

ZeRO Optimization

ZeRO is a method proposed in the paper "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models". To put it simply, ZeRO removes the redundant copies of the optimizer states, gradients, and model parameters that data-parallel training keeps on every card, thus reducing the memory usage; the freed memory can in turn be spent on larger models or larger batches. In addition, offload is also implemented, which moves less frequently used data to the host memory or even to disk, so as to reduce the need for larger GPU memory.

ZeRO is divided into three stages. As the stage number goes up, the optimization becomes more aggressive: the memory usage decreases, but the training speed decreases as well.

ZeRO Stage 0

ZeRO Stage 0 means no ZeRO optimization at all, that is, the traditional way of doing data-parallel multi-GPU training.

Traditionally, each card holds a full copy of the model, the gradients, and the optimizer states. After each backward pass, the gradients are averaged across the cards, and every card then applies the same optimizer update to its own copy. This consumes a lot of memory, since every state is replicated on every card.

ZeRO Stage 1

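ZeRO Stage 1 reduces the memory usage by partitioning the optimizer states among the cards. Each card holds only a part of the optimizer states. After the backward pass, the gradients are reduced to the card that owns the corresponding part, that card updates its share of the parameters, and the updated parameters are then gathered back to all cards.

ZeRO-Offload was originally built on top of stage two, and Lightning's built-in strategy names offer offload variants only from stage two onward, so we do not use offload with stage one here. Recent DeepSpeed versions document CPU offload of the optimizer for stage one as well.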

ZeRO Stage 2

Compared with stage one, stage two also partitions the gradients. A card holds only a part of the optimizer states, the gradients that this part of the optimizer is responsible for, and a full copy of the model parameters.

From stage two, Lightning exposes offload through the deepspeed_stage_2_offload strategy: the optimizer states and the optimizer step are moved to the host memory and the CPU, so they no longer need to occupy GPU memory. Offloading to disk (NVMe) requires stage three.
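The data flow of stage two can be simulated on the CPU with plain Python lists. This is a toy sketch under stated assumptions, not DeepSpeed's implementation: SGD with momentum stands in for Adam, the "cards" are just list indices, the parameter count is assumed to divide evenly by the number of cards, and all function names are hypothetical. It shows that partitioning the gradients and optimizer states gives the same result as the traditional replicated update.

```python
from typing import List, Tuple

Vector = List[float]

def averaged(local_grads: List[Vector], lo: int, hi: int) -> Vector:
    """Average the gradients of all cards for parameter indices [lo, hi)."""
    world = len(local_grads)
    return [sum(g[i] for g in local_grads) / world for i in range(lo, hi)]

def baseline_step(params: Vector, local_grads: List[Vector], momentum: Vector,
                  lr: float = 0.1, beta: float = 0.9) -> Tuple[Vector, Vector]:
    """Stage 0: every card all-reduces the full gradient and keeps full momentum."""
    grad = averaged(local_grads, 0, len(params))
    momentum = [beta * m + g for m, g in zip(momentum, grad)]
    return [p - lr * m for p, m in zip(params, momentum)], momentum

def zero2_step(params: Vector, local_grads: List[Vector], momentum_shards: List[Vector],
               lr: float = 0.1, beta: float = 0.9) -> Tuple[Vector, List[Vector]]:
    """Stage 2: each card owns one shard of the gradients and of the momentum."""
    world = len(local_grads)
    shard = len(params) // world  # assumes an even split
    new_params: Vector = []
    new_shards: List[Vector] = []
    for rank in range(world):
        lo, hi = rank * shard, (rank + 1) * shard
        # Reduce-scatter: this card receives only the averaged gradient of its shard.
        grad = averaged(local_grads, lo, hi)
        m = [beta * mi + g for mi, g in zip(momentum_shards[rank], grad)]
        new_shards.append(m)
        # All-gather: the updated shards are concatenated back into full parameters.
        new_params.extend(p - lr * mi for p, mi in zip(params[lo:hi], m))
    return new_params, new_shards

params = [1.0, 2.0, 3.0, 4.0]
local_grads = [[0.1, 0.2, 0.3, 0.4], [0.3, 0.2, 0.1, 0.0]]  # one gradient per card
base_p, base_m = list(params), [0.0] * 4
zero_p, zero_m = list(params), [[0.0, 0.0], [0.0, 0.0]]
for _ in range(3):
    base_p, base_m = baseline_step(base_p, local_grads, base_m)
    zero_p, zero_m = zero2_step(zero_p, local_grads, zero_m)
print(base_p)
print(zero_p)
```

Each card in the partitioned version stores two momentum values instead of four, yet both versions end with the same parameters; the price is the extra gather at the end of every step.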

ZeRO Stage 3

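Everything is partitioned, including the model parameters. Each card has to gather the parameters of a layer from the other cards right before using them, so the forward and backward passes take more time, but the memory usage is the lowest.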
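To get a feeling for the savings, the sketch below follows the accounting used in the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 bytes for fp16 gradients, and 12 bytes for the optimizer states (fp32 weights, momentum, and variance). `model_state_gb` is a hypothetical helper, and it counts model states only; activations and temporary buffers are not included.

```python
def model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Per-GPU memory for model states in GB (1e9 bytes) at a given ZeRO stage."""
    params = 2.0 * num_params   # fp16 weights
    grads = 2.0 * num_params    # fp16 gradients
    optim = 12.0 * num_params   # fp32 weights + Adam momentum + Adam variance
    if stage >= 1:
        optim /= num_gpus       # stage 1 partitions the optimizer states
    if stage >= 2:
        grads /= num_gpus       # stage 2 also partitions the gradients
    if stage >= 3:
        params /= num_gpus      # stage 3 also partitions the parameters
    return (params + grads + optim) / 1e9

# A 7.5B-parameter model on 64 GPUs, the running example of the ZeRO paper.
usage = {stage: model_state_gb(7.5e9, 64, stage) for stage in range(4)}
for stage, gb in usage.items():
    print(f"stage {stage}: {gb:.1f} GB per GPU")
```

The per-GPU model states shrink from 120 GB at stage 0 to about 31.4, 16.6, and 1.9 GB at stages 1, 2, and 3. For a small model like ours, the model states are only a small part of the measured GPU memory, so the effect is far less dramatic.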

ZeRO Infinity

The so-called ZeRO-Infinity is essentially stage three combined with aggressive offloading to the host memory and to disk (NVMe). When a calculation needs some data, the data is loaded from the disk into memory, the calculation is done, and the data is offloaded to the disk again.
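In config-dictionary form, this corresponds to a stage three setup with both offload targets pointed at NVMe. The fragment below is a sketch, assuming the `offload_optimizer` and `offload_param` keys of DeepSpeed's `zero_optimization` section; the path `/local_nvme` is a placeholder for wherever a fast local disk is mounted, and a real run would need further settings such as an optimizer section.

```python
zero_infinity_config = {
    "zero_optimization": {
        "stage": 3,  # NVMe offload is only available on top of stage 3
        # Placeholder path: point this at a fast local NVMe mount.
        "offload_optimizer": {"device": "nvme", "nvme_path": "/local_nvme"},
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},
    },
}

def offload_devices(config: dict) -> dict:
    """Hypothetical helper: report where each offloadable state is placed."""
    zero = config["zero_optimization"]
    return {
        key: zero[key]["device"]
        for key in ("offload_optimizer", "offload_param")
        if key in zero
    }

print(offload_devices(zero_infinity_config))

# It would be passed to Lightning the same way as the config dictionary above:
# strategy = DeepSpeedStrategy(config=zero_infinity_config)
```

Replacing "nvme" with "cpu" (and dropping `nvme_path`) gives the milder variant that only offloads to the host memory.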

Conclusion

In this part, we first built a more complex model that performs IMDB sentiment classification, as the benchmark for our future runs. Then we introduced how to use DeepSpeed in PyTorch Lightning, and finally we introduced the basic mechanism of ZeRO.

DeepSpeed is a very powerful tool for training large models, especially in a multi-GPU environment, where it can reduce the memory usage and increase the training speed. However, it is not recommended in a single-GPU environment, since the overhead is not worth it.

It has many other useful techniques, which you can look up in its documentation; in this part we only focused on ZeRO.