Training Monitor with Wandb

Introduction

Before introducing any specific acceleration tools, we first present a tool for monitoring the training process.

The most naive way to do so is to print values to the console, which is inefficient, slow, hard on the eyes, and not very informative. TensorBoard, the monitoring tool for TensorFlow, is a good choice, but it is less convenient to use with PyTorch, although a PyTorch interface to it exists.

Currently, wandb (short for Weights & Biases) works best with PyTorch. It is easy to use, flexible, informative, and has a good UI. It also offers a free tier, which is a big plus. The downside is that it requires an account, which is the cost of being a centralized platform.

Construct a Model

This project uses a model for MNIST classification.

import torch
from torch import nn


class MnistModel(nn.Module):

    def __init__(
        self,
        x_len: int = 28,
    ):
        # Convolution + Attention
        super(MnistModel, self).__init__()
        # [batch, channel, x_len, x_len] -> [batch, channel, x_len, x_len]
        self.conv = nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.relu = nn.ReLU()
        # [batch, channel, x_len, x_len] -> [batch, channel, x_len/2, x_len/2]
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        # [batch, channel, x_len/2, x_len/2] -> [batch, channel, x_len/4, x_len/4]
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # [batch, channel, l, l] -> [batch, channel, l * l]
        self.flatten_before_attn = nn.Flatten(start_dim=-2)
        # [batch, channel, l * l] -> [batch, channel, l * l]
        self.attn = nn.MultiheadAttention(
            embed_dim=(x_len // 4) ** 2, num_heads=1, batch_first=True
        )
        # [batch, channel, l * l] -> [batch, channel * l * l]
        self.flatten_after_attn = nn.Flatten(start_dim=-2)
        flattened_last_dim = 64 * (x_len // 4) ** 2
        self.fc1 = nn.Linear(flattened_last_dim, flattened_last_dim * 2)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(flattened_last_dim * 2, flattened_last_dim)
        self.activation2 = nn.ReLU()
        # fc2 outputs flattened_last_dim features, so fc3 must accept the same width
        self.fc3 = nn.Linear(flattened_last_dim, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.flatten_before_attn(x)
        x, _ = self.attn(x, x, x)
        x = self.flatten_after_attn(x)
        x = self.fc1(x)
        x = self.activation1(x)
        x = self.fc2(x)
        x = self.activation2(x)
        x = self.fc3(x)
        return x
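The shape comments in the model can be checked by hand. The following is a minimal sketch that traces the tensor shape through the layers in plain Python, without PyTorch. The helper name trace_shapes and the batch size of 64 are assumptions made for illustration; the arithmetic assumes the kernel, stride, and padding values used in the model above.

```python
def trace_shapes(x_len: int = 28, batch: int = 64):
    """Return (layer name, output shape) pairs for the MNIST model above."""
    shapes = [("input", (batch, 1, x_len, x_len))]
    # A 3x3 convolution with stride 1 and padding 1 keeps the spatial size.
    shapes.append(("conv", (batch, 32, x_len, x_len)))
    # A 2x2 max pool with stride 2 halves each spatial dimension.
    half = x_len // 2
    shapes.append(("maxpool1", (batch, 32, half, half)))
    shapes.append(("conv2", (batch, 64, half, half)))
    quarter = half // 2
    shapes.append(("maxpool2", (batch, 64, quarter, quarter)))
    # Flattening the two spatial dimensions gives the attention embed_dim.
    embed = quarter * quarter
    shapes.append(("flatten_before_attn", (batch, 64, embed)))
    shapes.append(("attn", (batch, 64, embed)))
    flat = 64 * embed
    shapes.append(("flatten_after_attn", (batch, flat)))
    shapes.append(("fc1", (batch, flat * 2)))
    shapes.append(("fc2", (batch, flat)))
    shapes.append(("fc3", (batch, 10)))
    return shapes


for name, shape in trace_shapes():
    print(f"{name:>20}: {shape}")
```

A trace like this makes it easy to confirm that the input width of each linear layer equals the output width of the layer before it.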

Working with Wandb

Before Training

Before training, install wandb and log in, either by running the command wandb login or by calling wandb.login() in the code. This only needs to be done once per computer.

First, use the init function to create a run. In wandb, you can create multiple projects, and each project contains several runs. You can also attach tags to a run to make it easier to search for.

It is recommended to put the hyperparameters in the wandb configuration so that they can be viewed on the dashboard.

wandb.init(
    project="demo",
    name="mnist-demo",
    tags=["demo"],
    config={
        "lr": 1e-4,
        "epoch": 3,
        "batch_size": 64,
        "weight_decay": 1e-5
    }
)

In Training

Casual Logging

During training, log anything you want to track with wandb.log, the loss in particular.

loss_fn = nn.CrossEntropyLoss()
for ep in range(wandb.config.epoch):
    for idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optim.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optim.step()
        wandb.log({"loss": loss.item(), "epoch": ep, "iter": idx})
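Logging on every iteration can produce a noisy curve. One possible variation, sketched below under stated assumptions, averages the loss over a fixed window and logs once per window. The class name WindowedLogger and its log_fn parameter are hypothetical and not part of wandb; in a real run log_fn would be wandb.log, while here a list's append method stands in for it so that the sketch runs without an account.

```python
class WindowedLogger:
    """Average the loss over `window` steps and log once per window."""

    def __init__(self, log_fn, window: int = 10):
        self.log_fn = log_fn  # e.g. wandb.log in a real run
        self.window = window
        self._values = []

    def add(self, loss: float, epoch: int, iteration: int) -> None:
        self._values.append(loss)
        if len(self._values) == self.window:
            mean_loss = sum(self._values) / len(self._values)
            self.log_fn({"loss": mean_loss, "epoch": epoch, "iter": iteration})
            self._values.clear()


# Stand-in for wandb.log: collect the logged dictionaries in a list.
records = []
logger = WindowedLogger(log_fn=records.append, window=2)
for idx, loss in enumerate([4.0, 2.0, 3.0, 1.0, 5.0]):
    logger.add(loss, epoch=0, iteration=idx)
print(records)
```

In the training loop above, the call to wandb.log would be replaced by logger.add(loss.item(), ep, idx). Values left in an incomplete window at the end are not logged in this sketch.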

You can then view the logged values in real time on the wandb dashboard. Click the project and then the run, and a neatly arranged board appears, such as the following:

Wandb Dashboard

The following sections introduce some commonly used logging techniques. The list is not exhaustive.

Logging Plot and Table

Scalars logged as above are drawn as line charts, but you can also specify the plot type yourself.

To create a custom plot, you must first create a wandb table.

data = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=data, columns=["x", "y"])

You can also build the table from a pandas DataFrame instead of a list.

In addition, a table can contain other types of data, such as images. Use wandb.Image to create an image entry and view it later on the dashboard.

A plot can then be created from the table, for example:

wandb.log(
    {
        "custom_plot_id": wandb.plot.line(
            table, "x", "y", title="Custom Y vs X Line Plot"
        )
    }
)

You can also log a table directly:

wandb.log(
    {
        "custom_table_id": wandb.Table(data=data, columns=["x", "y"])
    }
)

Logging Metrics

If you evaluate during training, you might be interested in summary statistics of a metric, such as its mean or minimum value, rather than only a plot.

You can define a metric as follows:

wandb.define_metric("loss", summary="min")

Then log as usual, and the run summary will also hold the corresponding statistic. Note that the key passed to wandb.log must match the metric name given to define_metric.
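To make the effect of the summary argument concrete, the sketch below mimics in plain Python which single number would be kept for a logged metric under a few of the options. It is only an illustration of the idea and not wandb's implementation; the function summarize and the sample losses are hypothetical.

```python
def summarize(values, summary="last"):
    """Mimic which single number is kept for a logged metric (illustration only)."""
    if not values:
        raise ValueError("no values were logged")
    if summary == "min":
        return min(values)
    if summary == "max":
        return max(values)
    if summary == "mean":
        return sum(values) / len(values)
    if summary == "last":
        return values[-1]
    raise ValueError(f"unsupported summary option: {summary}")


# Hypothetical loss values, as if logged under the key "loss".
logged_losses = [0.9, 0.4, 0.6, 0.5]
for option in ("min", "max", "mean", "last"):
    print(option, summarize(logged_losses, option))
```

For a loss, "min" is usually the informative choice, while for an accuracy "max" would play the same role.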

Logging Model

You can also log the model as a checkpoint:

torch.save(model.state_dict(), f"model_{ep}.pth")
wandb.log_model(path=f"model_{ep}.pth", name=f"model_{ep}")

Later, you can retrieve it as follows:

run = wandb.init(project="project_name")
downloaded_model_path = run.use_model(name="model_name")

Logging Files

Wandb can also log files, datasets in particular, through artifact logging. Use it as follows:

iris_table_artifact = wandb.Artifact("iris_artifact", type="dataset")
iris_table_artifact.add(iris_table, "iris_table")
iris_table_artifact.add_file("iris.csv")
run = wandb.init(project="tables-walkthrough")
run.log_artifact(iris_table_artifact)

Although you can log a table directly, logging it as part of an artifact allows more rows.

After Training

After training, you will usually want to run a test. If you want to store the test metrics, put them in the wandb summary, for example:

wandb.summary["acc"] = correct / total
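The values correct and total are not defined above. A minimal sketch of how they might be computed is shown here in plain Python, with nested lists standing in for the model outputs; the function count_correct and the toy data are hypothetical. With PyTorch tensors, the same count is commonly obtained with (y_pred.argmax(dim=1) == y).sum().item().

```python
def count_correct(logits, labels):
    """Count predictions whose highest-scoring class matches the label."""
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest score in this row, i.e. the predicted class.
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct, len(labels)


# Toy outputs for three samples and three classes.
logits = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05], [0.2, 0.3, 0.5]]
labels = [1, 0, 1]
correct, total = count_correct(logits, labels)
print(correct / total)
```

In a real test loop, correct and total would be accumulated over all batches of the test loader before the accuracy is written to the summary.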

Once everything is done, call wandb.finish() to end the run.

The summary can be accessed at any time after the wandb run is created. You can also store values such as the running time or the number of steps per second in it.
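As a minimal sketch of the timing idea, the following measures elapsed time with time.perf_counter and derives steps per second. The function throughput_stats and the key names runtime_seconds and steps_per_second are assumptions chosen for illustration, and the loop body is a stand-in for a real training step.

```python
import time


def throughput_stats(start: float, end: float, steps: int) -> dict:
    """Return the running time and the steps per second for a training run."""
    elapsed = end - start
    rate = steps / elapsed if elapsed > 0 else 0.0
    return {"runtime_seconds": elapsed, "steps_per_second": rate}


start = time.perf_counter()
steps = 0
for _ in range(1000):
    steps += 1  # stand-in for one training step
stats = throughput_stats(start, time.perf_counter(), steps)
print(stats)
```

Each entry can then be stored with wandb.summary[key] = value, which has to happen before wandb.finish() is called.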

Finished Dashboard

Overview

Artifacts

Charts