{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training a Classifier\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borchero/pyblaze/blob/master/docs/examples/classifier.ipynb)\n", "[![Download Jupyter Notebook](https://img.shields.io/badge/Github-Download-brightgreen)](https://github.com/borchero/pyblaze/blob/master/docs/examples/classifier.ipynb)\n", "\n", "Training a classifier on images is one of the most prominent examples for showcasing what a library can do. Inspired by PyTorch’s tutorials, we want to train a convolutional neural network on the CIFAR10 dataset using the features of PyBlaze.\n", "\n", "First, we import all libraries that we use throughout this example:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "import pyblaze.nn as xnn\n", "import pyblaze.nn.functional as X\n", "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading the Data\n", "\n", "We begin by loading the data with torchvision. PyBlaze does not come into play yet." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "train_val_dataset = torchvision.datasets.CIFAR10(\n", "    root=\"~/Downloads/\", train=True, download=True, transform=transforms.Compose([\n", "        transforms.RandomHorizontalFlip(),\n", "        transforms.ToTensor()\n", "    ])\n", ")\n", "test_dataset = torchvision.datasets.CIFAR10(\n", "    root=\"~/Downloads/\", train=False, download=True, transform=transforms.ToTensor()\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initializing Data Loaders\n", "\n", "First, we set aside 20% of the training data for validation. Usually, you would need to compute the sizes of the resulting subsets and then split the dataset randomly. With PyBlaze, you can do this more conveniently.\n", "\n", "Simply importing `pyblaze.nn` gives datasets an additional method `random_split`. It accepts an arbitrary number of floating point numbers, each indicating the fraction of the dataset that is randomly sampled into the corresponding subset. Note that these fractions must sum to 1." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "train_dataset, val_dataset = train_val_dataset.random_split(0.8, 0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can initialize the data loaders. Normally, you would create a `DataLoader` and pass the dataset to its initializer. Again, importing `pyblaze.nn` extends PyTorch’s native `Dataset` with a `loader` method. This method creates a data loader and accepts the same parameters as PyTorch’s `DataLoader` initializer." 
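] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, the following cell sketches the plain PyTorch code that `random_split` and `loader` replace. It is an illustration under assumptions rather than PyBlaze’s implementation. It assumes that `random_split(0.8, 0.2)` amounts to computing the subset sizes by hand and calling `torch.utils.data.random_split`. It also assumes that `loader` passes its arguments on to `torch.utils.data.DataLoader`. A small random `TensorDataset` (`toy_dataset`, a hypothetical stand-in for CIFAR10) keeps the cell self-contained." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader, TensorDataset, random_split\n", "\n", "# Hypothetical stand-in for CIFAR10: 1,000 random 3x32x32 images with labels in [0, 10)\n", "toy_dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))\n", "\n", "# Compute the subset sizes by hand; the validation subset takes the remainder so the sizes add up\n", "n_train = int(0.8 * len(toy_dataset))\n", "n_val = len(toy_dataset) - n_train\n", "plain_train, plain_val = random_split(toy_dataset, [n_train, n_val])\n", "\n", "# Create the data loaders explicitly via the DataLoader initializer\n", "plain_train_loader = DataLoader(plain_train, batch_size=256, shuffle=True)\n", "plain_val_loader = DataLoader(plain_val, batch_size=2048)\n", "\n", "print(len(plain_train), len(plain_val), len(plain_train_loader))  # 800 200 4" 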
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train_loader = train_dataset.loader(batch_size=256, num_workers=4, shuffle=True)\n", "val_loader = val_dataset.loader(batch_size=2048)\n", "test_loader = test_dataset.loader(batch_size=2048)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we set the batch size for validation and testing significantly higher than for training. Evaluation only requires forward passes, so no gradients need to be stored and much less memory is required." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining the Model\n", "\n", "As a model, we define a common convolutional neural network. Most importantly, defining a model for PyBlaze is no different from defining a model natively in PyTorch. However, PyBlaze provides additional functionality that makes working with models easier.\n", "\n", "A very useful layer introduced by PyBlaze is `xnn.View`. When convolutional layers are followed by linear layers, you usually have to reshape the intermediate output in the `forward` method. `xnn.View` performs this reshaping as a layer, which enables wrapping convolutional and linear layers into a single `nn.Sequential` module." 
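] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The size passed to `xnn.View` in the model below follows from the architecture. Every `ConvNet` block ends with `nn.MaxPool2d(2)`, which halves the spatial resolution, so a 32x32 CIFAR10 image is reduced to 4x4 after three blocks. With 128 output channels, this gives 128 × 4 × 4 = 2048 features per image. The following cell illustrates this with a hypothetical `ViewSketch` layer that simply calls `Tensor.view`. It is a minimal stand-in, not PyBlaze’s implementation of `xnn.View`, and it simplifies each convolutional block to a single convolution. In plain PyTorch, `nn.Flatten()` plays a similar role." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "\n", "class ViewSketch(nn.Module):\n", "    # Hypothetical view layer: reshapes its input to a fixed shape inside nn.Sequential\n", "    def __init__(self, *shape):\n", "        super().__init__()\n", "        self.shape = shape\n", "\n", "    def forward(self, x):\n", "        return x.view(*self.shape)\n", "\n", "\n", "# Simplified network with one convolution per block; each max pooling halves height and width\n", "sketch = nn.Sequential(\n", "    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),    # 32x32 -> 16x16\n", "    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 16x16 -> 8x8\n", "    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 8x8 -> 4x4\n", "    ViewSketch(-1, 128 * 4 * 4),  # flatten to 2048 features per image\n", "    nn.Linear(2048, 10),\n", ")\n", "\n", "logits = sketch(torch.randn(8, 3, 32, 32))\n", "print(logits.shape)  # torch.Size([8, 10])" 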
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "\n", "        self.seq = nn.Sequential(\n", "            ConvNet(3, 32),\n", "            ConvNet(32, 64),\n", "            ConvNet(64, 128),\n", "            xnn.View(-1, 2048),\n", "            nn.Linear(2048, 10)\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.seq(x)\n", "\n", "\n", "class ConvNet(nn.Module):\n", "\n", "    def __init__(self, in_channels, out_channels):\n", "        super().__init__()\n", "\n", "        self.conv = nn.Sequential(\n", "            nn.Conv2d(in_channels, out_channels, 3, padding=1),\n", "            nn.BatchNorm2d(out_channels),\n", "            nn.ReLU(),\n", "            nn.Conv2d(out_channels, out_channels, 3, padding=1),\n", "            nn.BatchNorm2d(out_channels),\n", "            nn.ReLU(),\n", "            nn.MaxPool2d(2),\n", "            nn.Dropout(0.25),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.conv(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initializing a model is as simple as for pure PyTorch modules:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model = Model()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model has 308,394 parameters.\n" ] } ], "source": [ "print(f\"Model has {sum(p.numel() for p in model.parameters()):,} parameters.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model training and evaluation are the core features of PyBlaze. The code below trains the model according to the following constraints:\n", "\n", "* We train with the data from `train_loader` and evaluate the performance with the data from `val_loader`.\n", "* We train for at most 150 epochs, evaluate every 5 epochs and use early stopping with a patience of 5 evaluation steps. 
For early stopping, we keep the default setting, which tracks the validation loss.\n", "* We use `Adam` with a learning rate of 1e-3 and a weight decay of 1e-4 as optimizer and minimize the cross-entropy loss.\n", "* We log the progress of each batch to the command line.\n", "* At every evaluation step, we compute the accuracy of the predictions on the validation data.\n", "\n", "The result of this call is a history object which aggregates information about the training. This includes the training loss after every batch and after every epoch, as well as the validation loss and validation metrics at every evaluation step." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "engine = xnn.MLEEngine(model)\n", "optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n", "loss = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.36 it/s] loss: 1.45220, val_accuracy: 0.50680, val_loss: 1.48204\n", "Epoch 2/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.47 it/s] loss: 0.99612 \n", "Epoch 3/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.81 it/s] loss: 0.83032 \n", "Epoch 4/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.64 it/s] loss: 0.74560 \n", "Epoch 5/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.70 it/s] loss: 0.68434 \n", "Epoch 6/150:\n", "\u001b[2K [Elapsed 0:00:04 | 32.64 it/s] loss: 0.64454, val_accuracy: 0.77590, val_loss: 0.64798\n", "Epoch 7/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.65 it/s] loss: 0.60356 \n", "Epoch 8/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.30 it/s] loss: 0.56721 \n", "Epoch 9/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.20 it/s] loss: 0.54744 \n", "Epoch 10/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.16 it/s] loss: 0.52634 \n", "Epoch 11/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.72 it/s] loss: 0.50366, val_accuracy: 0.81740, val_loss: 0.53041\n", "Epoch 12/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.98 it/s] loss: 0.47945 
\n", "Epoch 13/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.96 it/s] loss: 0.46435 \n", "Epoch 14/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.49 it/s] loss: 0.45264 \n", "Epoch 15/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.38 it/s] loss: 0.43764 \n", "Epoch 16/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.38 it/s] loss: 0.42418, val_accuracy: 0.83790, val_loss: 0.48021\n", "Epoch 17/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.63 it/s] loss: 0.41001 \n", "Epoch 18/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.33 it/s] loss: 0.39783 \n", "Epoch 19/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.42 it/s] loss: 0.38518 \n", "Epoch 20/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.70 it/s] loss: 0.37354 \n", "Epoch 21/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.07 it/s] loss: 0.36279, val_accuracy: 0.82470, val_loss: 0.51583\n", "Epoch 22/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.64 it/s] loss: 0.35568 \n", "Epoch 23/150:\n", "\u001b[2K [Elapsed 0:00:02 | 52.38 it/s] loss: 0.34712 \n", "Epoch 24/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.18 it/s] loss: 0.33781 \n", "Epoch 25/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.23 it/s] loss: 0.33789 \n", "Epoch 26/150:\n", "\u001b[2K [Elapsed 0:00:04 | 35.12 it/s] loss: 0.32599, val_accuracy: 0.85860, val_loss: 0.40353\n", "Epoch 27/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.04 it/s] loss: 0.31310 \n", "Epoch 28/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.22 it/s] loss: 0.31169 \n", "Epoch 29/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.33 it/s] loss: 0.30332 \n", "Epoch 30/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.02 it/s] loss: 0.29768 \n", "Epoch 31/150:\n", "\u001b[2K [Elapsed 0:00:04 | 32.99 it/s] loss: 0.29457, val_accuracy: 0.86270, val_loss: 0.40734\n", "Epoch 32/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.05 it/s] loss: 0.28917 \n", "Epoch 33/150:\n", "\u001b[2K [Elapsed 0:00:03 | 48.90 it/s] loss: 0.28505 \n", "Epoch 34/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.97 it/s] loss: 0.27621 \n", "Epoch 35/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.19 it/s] loss: 
0.26834 \n", "Epoch 36/150:\n", "\u001b[2K [Elapsed 0:00:04 | 33.15 it/s] loss: 0.26322, val_accuracy: 0.85380, val_loss: 0.44287\n", "Epoch 37/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.53 it/s] loss: 0.25915 \n", "Epoch 38/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.88 it/s] loss: 0.25473 \n", "Epoch 39/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.75 it/s] loss: 0.26057 \n", "Epoch 40/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.71 it/s] loss: 0.24920 \n", "Epoch 41/150:\n", "\u001b[2K [Elapsed 0:00:04 | 32.96 it/s] loss: 0.24499, val_accuracy: 0.85930, val_loss: 0.43536\n", "Epoch 42/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.17 it/s] loss: 0.24057 \n", "Epoch 43/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.87 it/s] loss: 0.23604 \n", "Epoch 44/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.20 it/s] loss: 0.23447 \n", "Epoch 45/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.96 it/s] loss: 0.23088 \n", "Epoch 46/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.57 it/s] loss: 0.22833, val_accuracy: 0.86980, val_loss: 0.40272\n", "Epoch 47/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.16 it/s] loss: 0.22665 \n", "Epoch 48/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.92 it/s] loss: 0.22133 \n", "Epoch 49/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.30 it/s] loss: 0.22232 \n", "Epoch 50/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.13 it/s] loss: 0.22184 \n", "Epoch 51/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.80 it/s] loss: 0.21760, val_accuracy: 0.86400, val_loss: 0.42838\n", "Epoch 52/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.69 it/s] loss: 0.21549 \n", "Epoch 53/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.43 it/s] loss: 0.21098 \n", "Epoch 54/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.40 it/s] loss: 0.20790 \n", "Epoch 55/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.09 it/s] loss: 0.20619 \n", "Epoch 56/150:\n", "\u001b[2K [Elapsed 0:00:04 | 32.51 it/s] loss: 0.20051, val_accuracy: 0.87510, val_loss: 0.40508\n", "Epoch 57/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.84 it/s] loss: 0.19795 \n", "Epoch 58/150:\n", 
"\u001b[2K [Elapsed 0:00:03 | 50.18 it/s] loss: 0.19924 \n", "Epoch 59/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.01 it/s] loss: 0.19628 \n", "Epoch 60/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.69 it/s] loss: 0.19570 \n", "Epoch 61/150:\n", "\u001b[2K [Elapsed 0:00:04 | 33.88 it/s] loss: 0.19398, val_accuracy: 0.88240, val_loss: 0.38243\n", "Epoch 62/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.92 it/s] loss: 0.19323 \n", "Epoch 63/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.32 it/s] loss: 0.19755 \n", "Epoch 64/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.32 it/s] loss: 0.19354 \n", "Epoch 65/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.20 it/s] loss: 0.18593 \n", "Epoch 66/150:\n", "\u001b[2K [Elapsed 0:00:04 | 33.16 it/s] loss: 0.18705, val_accuracy: 0.86990, val_loss: 0.40923\n", "Epoch 67/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.13 it/s] loss: 0.17992 \n", "Epoch 68/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.34 it/s] loss: 0.18269 \n", "Epoch 69/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.79 it/s] loss: 0.18332 \n", "Epoch 70/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.06 it/s] loss: 0.18073 \n", "Epoch 71/150:\n", "\u001b[2K [Elapsed 0:00:04 | 34.35 it/s] loss: 0.18015, val_accuracy: 0.87790, val_loss: 0.39438\n", "Epoch 72/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.55 it/s] loss: 0.17447 \n", "Epoch 73/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.53 it/s] loss: 0.17418 \n", "Epoch 74/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.42 it/s] loss: 0.17548 \n", "Epoch 75/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.86 it/s] loss: 0.17303 \n", "Epoch 76/150:\n", "\u001b[2K [Elapsed 0:00:04 | 33.72 it/s] loss: 0.17062, val_accuracy: 0.85620, val_loss: 0.49162\n", "Epoch 77/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.73 it/s] loss: 0.17179 \n", "Epoch 78/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.45 it/s] loss: 0.17274 \n", "Epoch 79/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.71 it/s] loss: 0.17069 \n", "Epoch 80/150:\n", "\u001b[2K [Elapsed 0:00:03 | 51.87 it/s] loss: 0.16697 \n", "Epoch 
81/150:\n", "\u001b[2K [Elapsed 0:00:04 | 33.46 it/s] loss: 0.16554, val_accuracy: 0.87900, val_loss: 0.40666\n", "Epoch 82/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.45 it/s] loss: 0.16777 \n", "Epoch 83/150:\n", "\u001b[2K [Elapsed 0:00:03 | 49.40 it/s] loss: 0.16606 \n", "Epoch 84/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.07 it/s] loss: 0.16360 \n", "Epoch 85/150:\n", "\u001b[2K [Elapsed 0:00:03 | 50.44 it/s] loss: 0.16609 \n", "Epoch 86/150:\n", "\u001b[2K [Elapsed 0:00:04 | 33.82 it/s] loss: 0.16356, val_accuracy: 0.86780, val_loss: 0.44691\n", "Early stopping after epoch 86 (patience 5).\n" ] } ], "source": [ "history = engine.train(\n", "    train_loader,\n", "    val_data=val_loader,\n", "    epochs=150,\n", "    eval_every=5,\n", "    optimizer=optimizer,\n", "    loss=loss,\n", "    callbacks=[\n", "        xnn.BatchProgressLogger(),\n", "        xnn.EarlyStopping(patience=5)\n", "    ],\n", "    metrics={\n", "        'accuracy': X.accuracy\n", "    }\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plotting the Losses\n", "\n", "With the information from the history object, we can plot the progress of our training. The history object always provides `batch_loss`, the training losses after each batch, as well as `loss`, the training losses after each epoch. Depending on the additional parameters passed to `train`, further keys are available." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "all_losses = history.batch_loss\n", "plt.figure(dpi=150)\n", "plt.plot(range(len(all_losses)), all_losses)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we passed validation data to `train`, the history object also provides a `val_loss` property. We could plot `val_accuracy` as well, as all metrics are recorded in the history object." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "plt.figure(dpi=150)\n", "plt.plot(range(len(history.loss)), history.loss, label='Training Loss')\n", "# Validation runs every 5 epochs starting with the first, so entry i of val_loss belongs to entry 5 * i of loss\n", "plt.plot(np.arange(len(history.val_loss)) * 5, history.val_loss, label='Validation Loss')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the training and validation losses start to diverge, and the generalization gap grows from around epoch 25 onwards. Nonetheless, the lowest validation loss was obtained after epoch 61." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating the Model\n", "\n", "Lastly, we want to evaluate the performance of our model. For this, we call `evaluate` on the engine with our test data. We are only interested in a single metric, the accuracy.\n", "\n", "The returned value is a dictionary that contains the recorded metrics. In our case, the only metric is `accuracy`." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2K [Elapsed 0:00:01 | 3.16 it/s] \n" ] } ], "source": [ "evaluation = engine.evaluate(\n", "    test_loader,\n", "    callbacks=[\n", "        xnn.PredictionProgressLogger()\n", "    ],\n", "    metrics={\n", "        'accuracy': X.accuracy\n", "    }\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our model achieves an accuracy of 86.46%.\n" ] } ], "source": [ "print(f\"Our model achieves an accuracy of {evaluation['accuracy']:.2%}.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using GPUs\n", "\n", "As PyBlaze is a framework dedicated to large-scale machine learning, it has first-class support for GPUs. In order to run training and evaluation on a GPU, **you do not have to do anything**. 
If you had a GPU at your disposal when running this tutorial, you have already used it. Generally, PyBlaze uses all available GPUs automatically, speeding up training as much as possible.\n", "\n", "If you have multiple GPUs and want to use a specific one, just pass its index via the `gpu` parameter to any of the functions above. Likewise, you can select a subset of GPUs by passing a list of indices, or use no GPU at all by setting `gpu=False`.\n", "\n", "A special case is `gpu=True`, which chooses a single GPU: the one with the most free memory." ] } ], "metadata": { "kernelspec": { "display_name": "epn", "language": "python", "name": "epn" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }