PyTorch Lightning as a Framework
Introduction
This part presents PyTorch Lightning, a comprehensive framework built on top of PyTorch that provides a more convenient and efficient way to train deep learning models.
PyTorch Lightning is a systematic framework for machine learning, whereas PyTorch itself is only a library.
This distinction matters here because Lightning conveniently integrates many accelerators, which is key to our series of tutorials.
If you are familiar with Keras in TensorFlow, you will find that PyTorch Lightning plays a similar role for PyTorch.
Basic Pipeline of PyTorch Lightning
PyTorch Lightning segments the training process into several steps to better organize the code. The basic pipeline is as follows:
- Model Definition: Define the model.
- Data Preparation: Prepare the dataset and dataloader.
- Training: Train the model.
- Validation: Validate the model during training.
- Testing: Test the trained model.
A model in PyTorch Lightning is self-sufficient: it contains all the logic needed for training, validation, and testing.
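To make this concrete, here is a minimal sketch of how these pieces fit together; it assumes the MnistModel and MnistDataModule classes that we define step by step below.
import pytorch_lightning as pl

model = MnistModel()              # LightningModule: network, loss, steps, optimizer
data_module = MnistDataModule()   # LightningDataModule: datasets and dataloaders
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, data_module)   # training loop with validation
trainer.test(model, data_module)  # final testing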
Next, we will convert the PyTorch code from the previous part into PyTorch Lightning code.
Model Definition
Instead of defining the model as a torch.nn.Module, we define it as a pl.LightningModule. The pl.LightningModule class is a subclass of torch.nn.Module and provides more convenient interfaces for training.
In addition, if there is any code that moves tensors between devices, simply remove it: pl.LightningModule handles device placement automatically.
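For example, the manual device transfers from the previous part can simply be deleted (a sketch; device refers to the variable used in the plain PyTorch version):
# Plain PyTorch (previous part): tensors and the model were moved by hand
#     x, y = x.to(device), y.to(device)
#     model = model.to(device)
# With pl.LightningModule, the Trainer places the model and every batch on the
# selected device before each step runs, so these lines are removed.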
First, create the __init__ and forward functions in the model. They are mostly the same as in the previous part.
import pytorch_lightning as pl
import torch
from torch import nn

class MnistModel(pl.LightningModule):
    def __init__(
        self,
        x_len: int = 28,
    ):
        # Convolution + Attention
        super(MnistModel, self).__init__()
        # [batch, channel, x_len, x_len] -> [batch, channel, x_len, x_len]
        self.conv = nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.relu = nn.ReLU()
        # [batch, channel, x_len, x_len] -> [batch, channel, x_len/2, x_len/2]
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        # [batch, channel, x_len/2, x_len/2] -> [batch, channel, x_len/4, x_len/4]
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # [batch, channel, l, l] -> [batch, channel, l * l]
        self.flatten_before_attn = nn.Flatten(start_dim=-2)
        # [batch, channel, l * l] -> [batch, channel, l * l]
        self.attn = nn.MultiheadAttention(
            embed_dim=(x_len // 4) ** 2, num_heads=1, batch_first=True
        )
        # [batch, channel, l * l] -> [batch, channel * l * l]
        self.flatten_after_attn = nn.Flatten(start_dim=-2)
        flattened_last_dim = 64 * (x_len // 4) ** 2
        self.fc1 = nn.Linear(flattened_last_dim, flattened_last_dim * 2)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(flattened_last_dim * 2, flattened_last_dim)
        self.activation2 = nn.ReLU()
        self.fc3 = nn.Linear(flattened_last_dim, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.flatten_before_attn(x)
        x, _ = self.attn(x, x, x)
        x = self.flatten_after_attn(x)
        x = self.fc1(x)
        x = self.activation1(x)
        x = self.fc2(x)
        x = self.activation2(x)
        x = self.fc3(x)
        return x
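As a quick sanity check (a hypothetical snippet, not part of the original tutorial), you can run a dummy batch through the model to confirm the output shape:
model = MnistModel()
dummy = torch.randn(8, 1, 28, 28)   # [batch, channel, height, width]
logits = model(dummy)
print(logits.shape)                 # expected: torch.Size([8, 10])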
Data Preparation
Define a pl.LightningDataModule to prepare the dataset and dataloaders. It bundles all data-related code into one reusable class.
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def prepare_data(self):
        # Download the data (called once, not on every process)
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
The stage parameter of the setup function distinguishes between the training/validation and testing stages: it is a string of 'fit', 'test', or None.
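As a quick illustration of these hooks (a hypothetical check, not required for training), you can drive the data module by hand and inspect the resulting splits:
data_module = MnistDataModule()
data_module.prepare_data()   # downloads MNIST into ./data if needed
data_module.setup('fit')     # creates the train/validation splits
print(len(data_module.mnist_train), len(data_module.mnist_val))  # 55000 5000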
Define Optimizer
Use the configure_optimizers hook to define the optimizer.
class MnistModel(pl.LightningModule):
    # ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
If you are using multiple optimizers and schedulers, you can return a list of optimizers and a list of schedulers.
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    return [optimizer], [scheduler]
If there is only one of each, the list brackets can be omitted. Lightning also accepts a dictionary form, sketched below, which is useful when a scheduler should follow a monitored metric.
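A hedged sketch of the dictionary return value; the 'monitor' key assumes the val_loss metric that we log in the validation step later in this part:
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    return {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'monitor': 'val_loss',  # must match a metric name recorded via self.log
        },
    }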
Define Steps
Define the training, validation, and testing steps.
class MnistModel(pl.LightningModule):
    # ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('test_loss', loss)
        return loss
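You are not limited to logging the loss; any scalar computed in a step can be logged. For example (a hypothetical extension of validation_step, not part of the original code), a simple argmax accuracy:
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss_fn(y_hat, y)
    acc = (y_hat.argmax(dim=-1) == y).float().mean()
    self.log('val_loss', loss)
    self.log('val_acc', acc, prog_bar=True)  # also display accuracy in the progress bar
    return loss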
Setup Logger with Wandb
PyTorch Lightning uses a logger to record the training process. Here we use wandb, introduced in the last part, as the logger.
import wandb

run = wandb.init(project='demo', name='lightning-mnist-demo', config={
    #...
})
logger = pl.loggers.WandbLogger(experiment=run)  # attach the existing wandb run
Inside the model, use self.log to forward metrics to the logger. In addition, specifically for wandb, you can use wandb.run instead of keeping a reference to run.
Train the Model
Create a trainer:
trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    log_every_n_steps=10,
)
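If you want to be explicit about the hardware, recent versions of PyTorch Lightning let you select it through the accelerator and devices arguments; a hedged variant of the trainer above:
trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    log_every_n_steps=10,
    accelerator='auto',  # use GPU/TPU/MPS if available, otherwise CPU
    devices='auto',      # use all devices exposed by the chosen accelerator
)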
Train the model; validation will be performed automatically during training.
model = MnistModel()
data_module = MnistDataModule()
trainer.fit(model, data_module)
Test the Model
Call trainer.test to test the trained model.
trainer.test(model, data_module)
Don't forget to call wandb.finish to finish the wandb run.
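For completeness, the final call is simply:
wandb.finish()  # closes the wandb run and flushes any remaining logs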
Conclusion
After running the code, you will find the training process more organized and convenient. PyTorch Lightning provides a more systematic way to train the model.
In addition, you may notice that training is already faster than before. This is because PyTorch Lightning detects and uses available hardware accelerators, such as GPUs and TPUs, with sensible defaults, whereas plain PyTorch only benefits from them if you configure this yourself.
Callbacks
Having presented the basic pipeline of PyTorch Lightning, we will now introduce one of its most important concepts: callbacks.
To use callbacks, fill in the callbacks parameter of the Trainer.
trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
)
Callbacks are used to customize the training process. This part will familiarize you with the most commonly used built-in callbacks and show how to write custom ones.
Built-in Callbacks
ModelCheckpoint
The ModelCheckpoint callback saves the model during training. It can keep the best models, the last model, or the model at a specific epoch.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)
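After trainer.fit finishes, the callback records where the best checkpoint was written, and it can be restored with load_from_checkpoint (a hypothetical usage sketch):
best_path = checkpoint_callback.best_model_path           # populated after training
best_model = MnistModel.load_from_checkpoint(best_path)   # rebuilds the model and loads its weights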
EarlyStopping
The EarlyStopping callback stops training when the monitored metric stops improving.
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min',
)
LearningRateMonitor
The LearningRateMonitor callback logs the learning rate during training.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
Custom Callbacks
You can also define custom callbacks by subclassing pl.Callback.
class CustomCallback(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        print('Training has started!')

    def on_train_end(self, trainer, pl_module):
        print('Training has ended!')
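Callbacks can hook into many other points of the loop. For instance (a hedged sketch using standard hooks), you can read the metrics recorded with self.log at the end of each validation epoch via trainer.callback_metrics:
class ValLossPrinter(pl.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # callback_metrics holds the latest values recorded with self.log(...)
        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss is not None:
            print(f'epoch {trainer.current_epoch}: val_loss={val_loss:.4f}')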
Conclusion
This part has introduced PyTorch Lightning, a comprehensive framework built upon PyTorch that provides a more convenient and efficient way to train deep learning models.
The following parts of the series will use PyTorch Lightning as the main framework and then introduce multiple libraries for accelerating the training process, starting with DeepSpeed in the next part.