Tutorial 3: Asynchronous Training and Callbacks
Beyond the basics of training and prediction, our API introduces two major features: asynchronous training/prediction and callbacks.
We will train a pretrained resnet18 model from our global models and chain multiple training pipelines together.
import npu
import npu.vision.models as models
import npu.vision.datasets as dset

npu.api(API_TOKEN)

# Samples used for training and validation
SAMPLES = 5000
# Validation samples starting point
VAL = 30000

for i in range(5):
    min_range = i * SAMPLES
    max_range = (i + 1) * SAMPLES
    trained_model = npu.train(models.resnet18(pretrained=True),
                              train_data=dset.CIFAR10[min_range:max_range],
                              val_data=dset.CIFAR10[min_range+VAL:max_range+VAL],
                              loss=npu.loss.SparseCrossEntropyLoss,
                              optim=npu.optim.SGD(lr=0.01),
                              batch_size=128,
                              epochs=3,
                              asynchronous=True)
Token successfully authenticated
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982236bbaecaba3d6a10
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982236bbaecaba3d6a11
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982336bbaecaba3d6a12
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982336bbaecaba3d6a13
Started training. View status at https://dashboard.neuro-ai.co.uk/tasks?task_id=5ee7982336bbaecaba3d6a14
On the dashboard we can see that all of our training tasks are running concurrently. This means we can test a variety of datasets and hyperparameters without having to wait for each run to finish before trying something different. The dashboard shows each of the five training tasks as running.
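You don't have to watch the dashboard to collect the results. Below is a minimal sketch that keeps each returned handle instead of overwriting it, assuming the handle returned by npu.train(..., asynchronous=True) exposes the same get_result() method we later use on npu.predict handles; that is our assumption, not something this tutorial documents.

# Re-run the loop above, but keep each returned handle instead of overwriting it
handles = []
for i in range(5):
    min_range = i * SAMPLES
    max_range = (i + 1) * SAMPLES
    handles.append(npu.train(models.resnet18(pretrained=True),
                             train_data=dset.CIFAR10[min_range:max_range],
                             val_data=dset.CIFAR10[min_range+VAL:max_range+VAL],
                             loss=npu.loss.SparseCrossEntropyLoss,
                             optim=npu.optim.SGD(lr=0.01),
                             batch_size=128,
                             epochs=3,
                             asynchronous=True))

# Assumption: as with npu.predict later in this tutorial, each handle
# exposes get_result(), which blocks until the corresponding task finishes
trained_models = [handle.get_result() for handle in handles]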
We can also chain training runs, feeding the previously trained model into the next run, in a similarly asynchronous manner.
model = models.resnet18(pretrained=True)

for i in range(5):
    model = npu.train(model,
                      train_data=dset.CIFAR10.train,
                      val_data=dset.CIFAR10.val,
                      batch_size=128,
                      epochs=3,
                      asynchronous=True)
This workflow is similar to running your own epoch cycle. You can likewise view the progress of each task online, and any run that depends on a previous model will start automatically as soon as that model has finished training. You can train other models while the sequential runs are in progress, or use the intermediate models for inference until your latest model is ready.
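For instance, nothing stops us from submitting an unrelated run while the chain above is still in flight. This sketch simply reuses the same npu.train call from the previous cell on a fresh pretrained model (side_model is just an illustrative variable name):

# Submitted while the chained runs are still executing; it is scheduled
# concurrently rather than queued behind them
side_model = npu.train(models.resnet18(pretrained=True),
                       train_data=dset.CIFAR10.train,
                       val_data=dset.CIFAR10.val,
                       batch_size=128,
                       epochs=3,
                       asynchronous=True)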
current_model = model = models.resnet18(pretrained=True)

# Called when an asynchronous training task completes;
# swaps the model in use for the newly trained one
def updateModel(model):
    global current_model
    current_model = model

for i in range(2):
    model_trained = npu.train(model,
                              train_data=dset.CIFAR10.train,
                              val_data=dset.CIFAR10.val,
                              loss=npu.loss.SparseCrossEntropyLoss,
                              optim=npu.optim.SGD(lr=0.01),
                              batch_size=128,
                              epochs=3,
                              asynchronous=True,
                              callback=updateModel)
Upon completion, the callback replaces the model currently in use with the newly trained one, while still letting you use the previously trained model until the new one is ready.
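Concretely, inference code can keep reading current_model and will pick up each newly trained model as soon as the callback fires. A minimal sketch, where sample_input is a placeholder for data shaped as in the prediction section below:

# current_model still points at the previous model until updateModel swaps it
pred = npu.predict(model=current_model, data=sample_input).get_result()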
We will now predict and visualise the results with the latest trained model. Let’s first get some samples from the test dataset.
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision.datasets as tv_dset  # renamed so it doesn't shadow npu.vision.datasets

# Create folder to save datasets if it doesn't exist already
CWD = os.getcwd()
DATA_PATH = CWD + '/datasets'
Path(DATA_PATH).mkdir(parents=True, exist_ok=True)

test_ds = tv_dset.CIFAR10(DATA_PATH, train=False, download=True)

samples = 10
# Take 10 samples, convert [N,H,W,C] to [N,C,H,W], then add a batch axis: [N,1,C,H,W]
x = test_ds.data[0:samples].transpose((0, 3, 1, 2))[:, np.newaxis, :]
Now we can run prediction with the NPU API.
preds = []
for i in range(samples):
    pred = npu.predict(model=model_trained, data=x[i]).get_result()
    preds.append(np.argmax(pred))
And visualise the results with matplotlib as usual:
images = test_ds.data[0:samples]

num_row = 2
num_col = 5

# Plot the images in a grid, titled with the predicted class names
fig, axes = plt.subplots(num_row, num_col, figsize=(1.5*num_col, 2*num_row))
for i in range(samples):
    ax = axes[i//num_col, i%num_col]
    ax.imshow(images[i])
    ax.set_title(test_ds.classes[preds[i]])
plt.tight_layout()
plt.show()