Tutorial 3: Asynchronous Training and Callbacks

Beyond the basics of training and prediction, our API introduces two major features: asynchronous training/prediction and callbacks.

Asynchronous functionality

We will fine-tune a pretrained resnet18 model from our global models and chain multiple training pipelines together.

import npu
import npu.vision.models as models
import npu.vision.datasets as dset

npu.api(API_TOKEN)

# Number of samples used for each training task
SAMPLES = 5000
# Offset at which the validation samples start
VAL = 30000
for i in range(5):
    min_range = i * SAMPLES
    max_range = (i + 1) * SAMPLES
    trained_model = npu.train(models.resnet18(pretrained=True),
                              train_data=dset.CIFAR10[min_range:max_range],
                              val_data=dset.CIFAR10[min_range+VAL:max_range+VAL],
                              loss=npu.loss.SparseCrossEntropyLoss,
                              optim=npu.optim.SGD(lr=0.01),
                              batch_size=128,
                              epochs=3,
                              asynchronous=True)

Out:

Token successfully authenticated
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982236bbaecaba3d6a10
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982236bbaecaba3d6a11
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982336bbaecaba3d6a12
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982336bbaecaba3d6a13
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982336bbaecaba3d6a14

On the dashboard we can see that all of our training tasks are running concurrently. This means we can try a variety of datasets and hyperparameters without having to wait for each run to finish before starting the next. Below, each of the 5 training tasks is shown running.

../_images/multi.png
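
Because npu.train returns immediately when asynchronous=True, you will usually want to keep the returned task handles and collect the trained models once the tasks finish. The sketch below assumes the handle returned by an asynchronous npu.train exposes the same blocking get_result() call as the npu.predict results used later in this tutorial; check the API reference for the exact method.

# Sketch: keep each async task handle, then block for the trained models.
# ASSUMPTION: the handle returned by npu.train(asynchronous=True) has a
# blocking get_result(), mirroring npu.predict(...).get_result() below.
tasks = []
for i in range(5):
    min_range = i * SAMPLES
    max_range = (i + 1) * SAMPLES
    tasks.append(npu.train(models.resnet18(pretrained=True),
                           train_data=dset.CIFAR10[min_range:max_range],
                           val_data=dset.CIFAR10[min_range+VAL:max_range+VAL],
                           loss=npu.loss.SparseCrossEntropyLoss,
                           optim=npu.optim.SGD(lr=0.01),
                           batch_size=128,
                           epochs=3,
                           asynchronous=True))

# Wait for all five tasks and gather the resulting models
trained_models = [task.get_result() for task in tasks]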

We can also chain previously trained models in a similarly asynchronous manner.

model = models.resnet18(pretrained=True)
for i in range(5):
    model = npu.train(model,
                      train_data=dset.CIFAR10.train,
                      val_data=dset.CIFAR10.val,
                      batch_size=128,
                      epochs=3,
                      asynchronous=True)

This workflow is similar to running your own epoch cycle. You can view the progress of each task online, and any task that depends on a previous model will simply start as soon as its predecessor has finished. In the meantime you can train other models, or use your intermediate models for inference until the latest one is ready.

Callbacks

A callback lets you run a function of your own the moment an asynchronous task completes. Here we use one to swap in each newly trained model:

current_model = model = models.resnet18(pretrained=True)

def update_model(new_model):
    global current_model
    current_model = new_model

for i in range(2):
    model_trained = npu.train(model,
                              train_data=dset.CIFAR10.train,
                              val_data=dset.CIFAR10.val,
                              loss=npu.loss.SparseCrossEntropyLoss,
                              optim=npu.optim.SGD(lr=0.01),
                              batch_size=128,
                              epochs=3,
                              asynchronous=True,
                              callback=update_model)
    model = model_trained  # the next iteration trains on top of this result

Upon completion, the callback swaps the model currently in use for the newly trained one, while still allowing you to use the previously trained models until the new one is ready.
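
While those tasks are still training, whatever current_model points at keeps serving predictions. The snippet below is a minimal sketch of that pattern: the zero-filled sample is purely illustrative, and it assumes npu.predict accepts a [1, C, H, W] array exactly as in the prediction step that follows.

import numpy as np

# Illustrative input only: one CIFAR10-shaped image in [1, C, H, W] layout
sample = np.zeros((1, 3, 32, 32), dtype=np.uint8)

# current_model always refers to the most recently completed model, so
# inference keeps working while the next training task is still running
pred = npu.predict(model=current_model, data=sample).get_result()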

We will now run prediction with the latest trained model and visualise the results. First, let's get some samples from the test dataset.

import torchvision.datasets as tv_datasets  # renamed so it does not shadow npu.vision.datasets (dset)
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Create folder to save Datasets if it doesn't exist already
CWD = os.getcwd()
DATA_PATH = CWD + '/datasets'
Path(DATA_PATH).mkdir(parents=True, exist_ok=True)

test_ds = tv_datasets.CIFAR10(DATA_PATH, train=False, download=True)

samples = 10
# Take 10 samples and reshape them from [N,H,W,C] to [N,1,C,H,W]
x = test_ds.data[0:samples].transpose((0, 3, 1, 2))[:, np.newaxis, :]
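
Before running prediction, it is worth sanity-checking the shapes:

# test_ds.data is the raw CIFAR10 test array in [N, H, W, C] layout
print(test_ds.data.shape)  # (10000, 32, 32, 3)
print(x.shape)             # (10, 1, 3, 32, 32), i.e. [N, 1, C, H, W]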

Now we can run prediction with the NPU API.

preds = []
for i in range(samples):
    pred = npu.predict(model=model_trained, data=x[i]).get_result()
    preds.append(np.argmax(pred))

And visualise the results with matplotlib as usual:

images = test_ds.data[0:samples]

num_row = 2
num_col = 5
# plot images
fig, axes = plt.subplots(num_row, num_col, figsize=(1.5*num_col,2*num_row))
for i in range(samples):
    ax = axes[i//num_col, i%num_col]
    ax.imshow(images[i])
    ax.set_title(test_ds.classes[preds[i]])
plt.tight_layout()
plt.show()

Out:

../_images/tut3_grid.png