PyTorch Lightning as a Framework

Introduction

This part presents PyTorch Lightning, a comprehensive solution built upon PyTorch that provides a more convenient and efficient way to train deep learning models.

PyTorch Lightning is a systematic framework for machine learning, whereas PyTorch itself is only a library.

This distinction is justified in part by its convenient integration of many accelerators, which is key to our series of tutorials.

If you are familiar with Keras in TensorFlow, you will find that PyTorch Lightning plays a similar role for PyTorch.

Basic Pipeline of PyTorch Lightning

PyTorch Lightning segments the training process into several steps to better organize the code. The basic pipeline of PyTorch Lightning is as follows:

  1. Model Definition: Define the model.
  2. Data Preparation: Prepare the dataset and dataloader.
  3. Training: Train the model.
  4. Validation: Validate the model during training.
  5. Testing: Test the trained model.

The model in PyTorch Lightning is self-sufficient, meaning it contains all the information necessary for training, validation, and testing.

Next, we will convert the PyTorch code from the previous part into PyTorch Lightning code.

Model Definition

Instead of defining the model as a torch.nn.Module, we define it as a pl.LightningModule. pl.LightningModule is a subclass of torch.nn.Module and provides more convenient interfaces for training.

In addition, remove any code that moves tensors between devices; pl.LightningModule handles device placement automatically.
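For example, if a step needs a freshly created helper tensor, derive its device from an existing tensor or from self.device rather than hard-coding one. A fragment sketch (x stands for a batch tensor that Lightning has already placed on the correct device):

# inside any LightningModule method, e.g. a training step
noise = torch.randn_like(x)                          # inherits x's device and dtype
mask = torch.ones(x.shape[0], device=self.device)    # self.device is managed by Lightning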

First, create the __init__ and forward methods of the model. They are mostly the same as in the previous part.

import torch
import torch.nn as nn
import pytorch_lightning as pl


class MnistModel(pl.LightningModule):

    def __init__(
        self,
        x_len: int = 28,
    ):
        # Convolution + Attention
        super(MnistModel, self).__init__()
        # [batch, channel, x_len, x_len] -> [batch, channel, x_len, x_len]
        self.conv = nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.relu = nn.ReLU()
        # [batch, channel, x_len, x_len] -> [batch, channel, x_len/2, x_len/2]
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        # [batch, channel, x_len/2, x_len/2] -> [batch, channel, x_len/4, x_len/4]
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # [batch, channel, l, l] -> [batch, channel, l * l]
        self.flatten_before_attn = nn.Flatten(start_dim=-2)
        # [batch, channel, l * l] -> [batch, channel, l * l]
        self.attn = nn.MultiheadAttention(
            embed_dim=(x_len // 4) ** 2, num_heads=1, batch_first=True
        )
        # [batch, channel, l * l] -> [batch, channel * l * l]
        self.flatten_after_attn = nn.Flatten(start_dim=-2)
        flattened_last_dim = 64 * (x_len // 4) ** 2
        self.fc1 = nn.Linear(flattened_last_dim, flattened_last_dim * 2)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(flattened_last_dim * 2, flattened_last_dim)
        self.activation2 = nn.ReLU()
        self.fc3 = nn.Linear(flattened_last_dim, 10)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.flatten_before_attn(x)
        x, _ = self.attn(x, x, x)
        x = self.flatten_after_attn(x)
        x = self.fc1(x)
        x = self.activation1(x)
        x = self.fc2(x)
        x = self.activation2(x)
        x = self.fc3(x)
        return x
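
As a quick sanity check (not part of the training pipeline), you can push a dummy MNIST-shaped batch through the model and confirm it yields one logit per class; the batch size of 2 here is arbitrary:

model = MnistModel()
dummy = torch.randn(2, 1, 28, 28)   # [batch, channel, height, width]
logits = model(dummy)
print(logits.shape)                 # expected: torch.Size([2, 10])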

Data Preparation

Define a pl.LightningDataModule to prepare the dataset and dataloaders; it keeps all data-related code in one convenient place.

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def prepare_data(self):
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

The stage parameter of the setup method distinguishes the training, validation, and testing stages. It is a string such as 'fit' or 'test', or None.
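
To see the stages in action, you can also drive the DataModule by hand outside the Trainer (the Trainer normally calls these hooks for you); a quick inspection sketch:

dm = MnistDataModule()
dm.prepare_data()      # downloads MNIST if needed
dm.setup('fit')        # builds the train/val splits
dm.setup('test')       # builds the test set
print(len(dm.mnist_train), len(dm.mnist_val), len(dm.mnist_test))  # 55000 5000 10000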

Define Optimizer

Use configure_optimizers to define the optimizer.

class MnistModel(pl.LightningModule):

    # ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

If you use multiple optimizers and schedulers, you can return a list of optimizers and a list of schedulers.

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    return [optimizer], [scheduler]

If there is only one optimizer, the list brackets can be omitted, as in the first snippet above.
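
PyTorch Lightning also accepts a dictionary return value, which is convenient when a scheduler should monitor a logged metric. A sketch of this form, assuming the val_loss metric logged in the steps below:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    return {
        'optimizer': optimizer,
        'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss'},
    }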

Define Steps

Define the training, validation, and testing steps.

class MnistModel(pl.LightningModule):

    # ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('test_loss', loss)
        return loss
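
The steps may log more than the loss. As a small variation, the sketch below also logs accuracy computed directly from the logits and shows both metrics on the progress bar; the metric names are arbitrary:

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss_fn(y_hat, y)
    acc = (y_hat.argmax(dim=-1) == y).float().mean()
    self.log('val_loss', loss, prog_bar=True)
    self.log('val_acc', acc, prog_bar=True)
    return loss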

Set Up the Logger with Wandb

PyTorch Lightning uses a logger to record the training process. Here we use wandb, introduced in the last part, as the logger.

run = wandb.init(project='demo', name='lightning-mnist-demo', config={
    #...
})
logger = pl.loggers.WandbLogger(experiment=run)

Inside the LightningModule, calling self.log forwards the metric to this logger.

In addition, specifically for wandb, you can use the global wandb.run instead of keeping a reference to run.
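
Put together, a sketch of the logger setup that relies on the global wandb.run might look like this (experiment is the WandbLogger argument for reusing an existing run; depending on your version, you could also let WandbLogger call wandb.init itself by passing project and name directly):

import wandb
import pytorch_lightning as pl

wandb.init(project='demo', name='lightning-mnist-demo')
logger = pl.loggers.WandbLogger(experiment=wandb.run)   # reuse the active run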

Train the Model

Create a trainer:

trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    log_every_n_steps=10,
)
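
If a GPU (or TPU) is available, recent Lightning versions let you select the hardware directly on the Trainer; a sketch with automatic device selection:

trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    log_every_n_steps=10,
    accelerator='auto',   # pick GPU/TPU/CPU depending on what is available
    devices='auto',
)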

Train the model; validation is performed automatically during training.

model = MnistModel()
data_module = MnistDataModule()
trainer.fit(model, data_module)

Test the Model

Call trainer.test to test the model.

trainer.test(model, data_module)

Don't forget to call wandb.finish to finish the wandb run.
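
wandb.finish()   # closes the run and flushes any remaining logs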

Conclusion

After running the code, you will find the training process more organized and convenient. PyTorch Lightning provides a more systematic way to train the model.

In addition, you may notice that training is already faster than before. This is because PyTorch Lightning integrates many accelerators, such as GPUs and TPUs, out of the box, and these usually outperform a plain PyTorch training loop.

Callbacks

Having presented the basic pipeline of PyTorch Lightning, we now introduce one of its most important concepts: callbacks.

To use callbacks, fill in the callbacks parameter of the Trainer.

trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
)

Callbacks are used to customize the training process. This part will familiarize you with the most commonly used built-in callbacks as well as custom callbacks.

Built-in Callbacks

ModelCheckpoint

The ModelCheckpoint callback is used to save the model during training. It can save the best model, the last model, or the model at a specific epoch.

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)
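
After training, the callback records where the best checkpoints were written, so you can reload the best model later. A sketch (this assumes the default __init__ arguments, since the model above does not call save_hyperparameters):

best_path = checkpoint_callback.best_model_path            # populated after trainer.fit(...)
best_model = MnistModel.load_from_checkpoint(best_path)
best_model.eval()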

EarlyStopping

The EarlyStopping callback is used to stop the training when the monitored metric stops improving.

early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min',
)

LearningRateMonitor

The LearningRateMonitor callback is used to log the learning rate during training.

lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')

Custom Callbacks

You can also define custom callbacks by subclassing pl.Callback.

class CustomCallback(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        print('Training has started!')

    def on_train_end(self, trainer, pl_module):
        print('Training has ended!')
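
Hooks exist for many other points in the loop. For instance, a sketch that prints the logged validation loss at the end of each validation epoch (trainer.callback_metrics holds the values passed to self.log; the class name here is just illustrative):

class PrintValLossCallback(pl.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss is not None:
            print(f'epoch {trainer.current_epoch}: val_loss={val_loss.item():.4f}')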

Conclusion

This part has introduced PyTorch Lightning, a comprehensive solution built upon PyTorch that provides a more convenient and efficient way to train deep learning models.

The following parts of the series will use PyTorch Lightning as the main framework and introduce multiple libraries for accelerating the training process, starting with DeepSpeed in the next part.