{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# 🔊 Training an audio classification task using Torch 🔥\n", "\n", "The examples so far are predominantly regression-based. Rockpool, however, is designed for time-series tasks such as audio processing. This tutorial gives a basic overview of how to train a Rockpool network on a standard dataset that is more representative of real-world tasks. For this example we use the [Spiking Heidelberg Datasets](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/), specifically the Spiking Heidelberg Digits (SHD) dataset. SHD consists of recordings of the spoken digits 0 to 9 in English and German by 12 different speakers, giving 20 possible classes; the training set used here contains 8156 samples. Each sample has 700 input channels and lasts up to around one second." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Importing the SHD Dataset Using Tonic\n", "\n", "The creators of the SHD dataset provide a [tutorial implementation](https://github.com/fzenke/spytorch/blob/main/notebooks/SpyTorchTutorial4.ipynb) which details how to download the data and demonstrates a custom class to import the dataset. As downloaded, the SHD dataset is stored as HDF5 files. For simplicity, we use [Tonic](https://tonic.readthedocs.io/en/latest/index.html), a Python package that provides event-based vision and audio datasets and transformations. With Tonic, the SHD dataset can be loaded in just a few lines of code, and the input can easily be transformed into a lower-dimensional form. 
Make sure Tonic is installed, following the instructions on the Tonic website, before proceeding with this tutorial.\n" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "We can use Rockpool's :py:class:`TSEvent` class to visualise the data:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# - Imports for loading data\n", "import tonic\n", "from tonic import datasets, transforms\n", "from torch.utils.data import DataLoader\n", "import numpy as np\n", "import torch\n", "\n", "from rockpool.timeseries import TSEvent\n", "\n", "# - Use rich for nicer printing, if available\n", "try:\n", "    from rich import print\n", "except ModuleNotFoundError:\n", "    pass\n", "\n", "import sys\n", "!{sys.executable} -m pip install --quiet matplotlib\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [12, 6]\n", "\n", "download_dir = './data'\n", "\n", "# - Download and import the training data. The transform converts the events to a floating-point array\n", "train_data = datasets.SHD(download_dir, train=True, transform=transforms.NumpyAsType(float))\n", "\n", "train_dl = iter(DataLoader(train_data, drop_last=True, shuffle=False))" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Visualise Data\n", "events, label = next(train_dl)\n", "\n", "# - Extract spike times (in µs) and channel indices for the first sample from the dataloader\n", "times = events[0, :, 0]\n", "channels = events[0, :, 1]\n", "\n", "# - Create a TSEvent object corresponding to the first sample (times converted to seconds)\n", "spikes_ts = TSEvent(\n", "    times=times.numpy() * 1e-6,\n", "    channels=channels.numpy(),\n", "    t_stop=(times.max() + 1) * 1e-6\n", ")\n", "spikes_ts.plot()\n", "plt.title(f'Raw SHD Sample #1 (Class={label.item()})')\n", "plt.ylim((0, 700))\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, each sample consists of a variable number of events spread across 700 channels over a variable duration. With 700 input channels the dataset is high-dimensional, which means larger networks and longer training times. For low-power applications, a lower-dimensional input is desirable. Using Tonic, we can downsample the input in time and across channels, and then rasterise it into fixed time bins to facilitate training. Since the duration of each sample differs, it is also useful to pad samples to a uniform length. We define the following parameters; initially, we use an encoding dimension of 20, i.e., we reduce the 700 input channels to 20."
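, "

A minimal NumPy sketch of what the downsampling and rasterisation steps below do to a handful of events. It assumes that Tonic's `Downsample` multiplies timestamps (in microseconds) by `time_factor` and channel indices by `spatial_factor`, then truncates the results to integers; please check the Tonic documentation for its exact rounding behaviour. The event values are made up for illustration.

```python
import numpy as np

# - Settings mirroring the parameters defined below
shd_timestep, net_dt = 1e-6, 10e-3       # SHD timestamps are in microseconds
shd_channels, net_channels = 700, 20
sample_T = 100

time_factor = shd_timestep / net_dt       # about 1e-4: microseconds -> 10 ms bins
spatial_factor = net_channels / shd_channels

# - A few made-up events: timestamps in microseconds, channel indices in 0..699
t_us = np.array([1_000, 4_999, 15_000, 25_500, 1_234_567])
x = np.array([0, 34, 36, 699, 351])

# - Assumed downsampling: scale, then truncate to integer bin and channel indices
t_bin = np.floor(t_us * time_factor).astype(int)
x_bin = np.floor(x * spatial_factor).astype(int)

# - Rasterise as ToRaster does: count events per (time bin, channel), then truncate
raster = np.zeros((t_bin.max() + 1, net_channels), dtype=int)
np.add.at(raster, (t_bin, x_bin), 1)
raster = raster[:sample_T, :]

print(t_bin.tolist())   # [0, 0, 1, 2, 123]
print(x_bin.tolist())   # [0, 0, 1, 19, 10]
print(raster.shape)     # (100, 20)
print(raster[0, 0])     # 2
print(raster.sum())     # 4
```

The first two events land in the same time bin and channel, so `np.add.at` counts them together, while the event in bin 123 is discarded by truncating to `sample_T` time steps.
"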
 ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "shd_timestep = 1e-6   # SHD timestamps are given in microseconds\n", "shd_channels = 700    # number of input channels in SHD\n", "net_channels = 20     # number of input channels after downsampling\n", "net_dt = 10e-3        # network time step (10 ms)\n", "sample_T = 100        # maximum number of time steps per sample\n", "batch_size = 256\n", "num_workers = 6" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We now create a class to rasterise the input:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class ToRaster():\n", "    def __init__(self, encoding_dim, sample_T=100):\n", "        self.encoding_dim = encoding_dim\n", "        self.sample_T = sample_T\n", "\n", "    def __call__(self, events):\n", "        # - tensor has dimensions (time_steps, encoding_dim)\n", "        tensor = np.zeros((events[\"t\"].max()+1, self.encoding_dim), dtype=int)\n", "        # - Count the events falling into each (time step, channel) bin\n", "        np.add.at(tensor, (events[\"t\"], events[\"x\"]), 1)\n", "        # - Truncate the raster to at most sample_T time steps\n", "        return tensor[:self.sample_T, :]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we can define the transforms that we apply to the dataset to prepare it for training:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "transform = transforms.Compose([\n", "    # - Rescale timestamps to net_dt bins and map the 700 channels onto net_channels\n", "    transforms.Downsample(\n", "        time_factor=shd_timestep / net_dt,\n", "        spatial_factor=net_channels / shd_channels\n", "    ),\n", "    # - Bin the events into a (time_steps, net_channels) raster\n", "    ToRaster(net_channels, sample_T=sample_T),\n", "    # - Convert the raster to a floating-point torch tensor\n", "    torch.Tensor,\n", "])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Reload the dataset with the transforms applied:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# - Use a GPU if available for faster training\n", "dev = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "device = torch.device(dev)\n", "\n", "dataloader_kwargs = dict(\n", "    batch_size=128,\n", "    shuffle=True,\n", "    drop_last=True,\n", "    pin_memory=True,\n", "    collate_fn=tonic.collation.PadTensors(batch_first=True),\n", "    num_workers=8,\n", ")\n", "\n", "train_data = datasets.SHD(download_dir, train=True, transform=transform)\n", "\n", "# - Cache the transformed samples on disk so the transforms are not recomputed every epoch\n", "disk_train_dataset = tonic.DiskCachedDataset(\n", "    dataset=train_data,\n", "    cache_path=f\"cache/{train_data.__class__.__name__}/train/{net_channels}/{net_dt}\",\n", "    reset_cache=True,\n", ")\n", "\n", "train_dl = DataLoader(disk_train_dataset, **dataloader_kwargs)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "`tonic.collation.PadTensors(batch_first=True)` pads every sample in a batch to the length of the longest sample in that batch, so that the samples can be stacked into a single tensor of shape `[batch, time, channel]`. Since `ToRaster` truncates every sample to at most `sample_T = 100` time steps, no padded sample is longer than 100 time steps.\n", "\n", "We can now visualise the first sample in its encoded form in the same way as before:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "events, label = disk_train_dataset[0]\n", "\n", "spikes_ts = TSEvent.from_raster(events.squeeze(), dt=net_dt, name=f'Encoded SHD Sample #1 (Class={label.item()})')\n", "spikes_ts.plot()\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now that our data is in the correct format, we can define our network. To begin, we use a simple network consisting of a linear layer, a leaky integrate-and-fire (LIF) layer, another linear layer, and an exponential synapse output layer:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from rockpool.nn.modules import LIFTorch, LinearTorch, ExpSynTorch, LIFExodus\n", "from rockpool.nn.combinators import Sequential\n", "from rockpool.parameters import Constant\n", "\n", "# - Select a neuron model to use: the GPU-accelerated LIFExodus if available, otherwise LIFTorch\n", "from rockpool.utilities.backend_management import backend_available\n", "NeuronModel = LIFExodus if backend_available('sinabs-exodus') and torch.cuda.is_available() else LIFTorch\n", "\n", "# - Network Definition\n", "def SimpleNet(Nin, Nhidden, Nout):\n", "    return Sequential(\n", "        LinearTorch((Nin, Nhidden), has_bias=False),\n", "        NeuronModel(Nhidden,\n", "                    tau_mem=Constant(100e-3),\n", "                    tau_syn=Constant(100e-3),\n", "                    threshold=Constant(1.),\n", "                    bias=Constant(0.),\n", "                    dt=net_dt,\n", "                    has_rec=False),\n", "        LinearTorch((Nhidden, Nout), has_bias=False),\n", "        ExpSynTorch(Nout, dt=net_dt, tau=Constant(5e-3))\n", "    )" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Note that we wrap the time constants, threshold and bias in `Constant` because, by default, all parameters in Rockpool are trainable. Keeping everything except the weights fixed keeps training fast for this example. 
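
With `dt = net_dt = 10 ms`, each of these fixed time constants corresponds to a per-time-step decay factor. A minimal sketch, assuming the common exponential discretisation `exp(-dt / tau)`; please check the `LIFTorch` and `ExpSynTorch` documentation for the exact update equations Rockpool uses. The helper name `decay_factor` is hypothetical.

```python
import math

net_dt = 10e-3  # network time step in seconds, as defined above

def decay_factor(tau, dt=net_dt):
    # Fraction of a state variable retained after one time step,
    # assuming exponential decay exp(-dt / tau)
    return math.exp(-dt / tau)

for name, tau in [('LIF tau_mem / tau_syn', 100e-3),
                  ('ExpSyn output tau', 5e-3),
                  ('tau = 20 ms', 20e-3)]:
    print(f'{name}: {decay_factor(tau):.3f} retained per step')
```

Under this assumption, the LIF states keep about 90% of their value from one 10 ms step to the next, whereas the 5 ms output synapse keeps only about 14%, so the output layer on its own retains little memory of past inputs.
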
We can now define our network shape:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
TorchSequential  with shape (20, 20) {\n",
                            "    LinearTorch '0_LinearTorch' with shape (20, 20)\n",
                            "    LIFTorch '1_LIFTorch' with shape (20, 20)\n",
                            "    LinearTorch '2_LinearTorch' with shape (20, 20)\n",
                            "    ExpSynTorch '3_ExpSynTorch' with shape (20,)\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "TorchSequential with shape \u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " LinearTorch \u001b[32m'0_LinearTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\n", " LIFTorch \u001b[32m'1_LIFTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\n", " LinearTorch \u001b[32m'2_LinearTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m, \u001b[1;36m20\u001b[0m\u001b[1m)\u001b[0m\n", " ExpSynTorch \u001b[32m'3_ExpSynTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m20\u001b[0m,\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Build a network\n", "Nin = net_channels\n", "Nhidden = 20\n", "Nout = 20\n", "\n", "torch.manual_seed(1234)  # A manual seed ensures repeatability\n", "\n", "net = SimpleNet(Nin, Nhidden, Nout).to(dev)\n", "print(net)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can check that the network output makes sense. Passing `record=True` when calling the network fills the returned record dictionary with the internal state of every layer at each time step. Let's pass the first sample through the network and see what happens:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "events, labels = next(iter(train_dl))  # Get a batch from the train dataloader\n", "events = events.to(device)\n", "sample = events[0,:,:]  # Get the first sample from the first batch\n", "\n", "output, state, rec = net(sample, record=True)  # Pass the first sample through the network" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Here, we see that the network returns three things: the network output, the network state, and the record dictionary. 
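
Tensors in this tutorial are indexed as `[batch_item, timestep, channel]`, with indices starting at 0, so the eighth sample in a batch is index 7 and the fifth channel is index 4. A minimal NumPy sketch on a made-up array with the shape of one training batch (128 samples, 100 time steps and 20 channels, following the dataloader settings above); the name `dummy_batch` is illustrative, and torch tensors support the same indexing:

```python
import numpy as np

# - Made-up batch laid out as [batch_item, timestep, channel]
dummy_batch = np.zeros((128, 100, 20))

# - Mark the fifth channel at the second time step of the eighth sample
dummy_batch[7, 1, 4] = 1.0

first_sample = dummy_batch[0, :, :]    # all time steps and channels of the first sample
channel_trace = dummy_batch[7, :, 4]   # fifth channel of the eighth sample, over time

print(first_sample.shape)                  # (100, 20)
print(channel_trace.shape)                 # (100,)
print(np.argwhere(dummy_batch).tolist())   # [[7, 1, 4]]
```
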
The output is a 3D tensor of shape `[batch_size, time_steps, channels]`. For a full batch from `train_dl` the first dimension is the batch size (128 here); for the single sample passed above it is 1. The second dimension is the number of time steps (100 in this case), and the last dimension holds one value for each of the 20 neurons in the ExpSynTorch layer. We have already indexed tensors in this tutorial, e.g., in the `sample = events[0,:,:]` line above. Tensor indexing works as follows: `tensor_name[batch_item, timestep, channel]`, with indices starting at 0. For example, to access the fifth neuron at the second time step in the eighth sample of the batch, use `events[7,1,4]`. To select all channels, time steps, or batch items, use `:`; e.g., to access all channels and time steps of the first sample, use `events[0,:,:]`.\n", "\n", "The network state holds the internal state of each layer at the end of the evolution, while the record dictionary stores how those states evolved over every time step. To reduce training time and memory overheads, you may wish to omit the `record=True` argument; however, the record dictionary can be invaluable for debugging if the network doesn't behave as expected. For example, if the network does not appear to produce any output, you can use the record dictionary to determine which layer is at fault.\n", "\n", "Now that we understand what our network outputs, we can define our training process. As SHD is a classification task, we use [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss) as our loss function, and Adam as our optimiser."
 ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from torch.optim import Adam\n", "from torch.nn import CrossEntropyLoss\n", "\n", "# - Get the optimiser functions\n", "optimizer = Adam(net.parameters().astorch(), lr=1e-3)\n", "\n", "# - Loss function\n", "loss_fun = CrossEntropyLoss()\n", "\n", "# - Record the accuracy and loss values over training epochs\n", "accuracy = []\n", "loss_t = []\n", "num_epochs = 500" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to build our training loop:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ea28b64ba574a8d93f4412a2e9f9d5f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0%| | 0/500 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Plot the Training Loss\n", "fig, ax = plt.subplots()\n", "ax.plot(loss_t, color='blue')\n", "ax.set_ylabel('CrossEntropy loss', color='blue')\n", "ax.set_yscale('log')\n", "ax.set_xlabel('Epochs')\n", "\n", "# - Plot the training accuracy, with chance level (1 in 20 classes) as a dashed line\n", "ax2 = ax.twinx()\n", "ax2.plot(accuracy, color='orange')\n", "ax2.plot(ax2.get_xlim(), [100/20, 100/20], '--', color='orange')\n", "ax2.set_ylabel('Accuracy (%)', color='orange')\n", "ax2.set_yscale('linear')\n", "\n", "plt.title('Rockpool SHD training loss and accuracy')\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As the plot shows, our network successfully learns from the training data! Feel free to experiment with different values for the learning rate, the number of epochs and the network parameters to see how they affect network performance. Training generally requires several hundred epochs to converge, and you can re-run the training cell above to train for longer.\n", "\n", "The best values for the parameters will depend on the dataset. 
For example, here, if `tau_mem` or `tau_syn` is too small (below about 0.02 s, i.e., 20 ms, for `threshold = 1.`), the LIF layer will not fire and the network will not train. Depending on your application, you may need to try several different values. As noted above, the record dictionary can be very useful for figuring out what is happening under the hood.\n", "\n", "It is worth discussing how the training loop turns the network output into a loss. For a batch of input samples, the network output is, as described above, a tensor holding the ExpSyn layer output for each sample at each time step. We integrate the output of each neuron (the synaptic current, `isyn`) over time and pass the result to the CrossEntropyLoss function. We perform this integration by taking the cumulative sum of the synaptic currents over the time dimension, using `sum = torch.cumsum(output, dim=1)`. The loss function expects a tensor of shape `[batch_size, channels]`, so we take the value at the last time step by indexing `sum[:,-1,:]`; this is the total integrated current, equivalent to `output.sum(dim=1)`. As noted above, the SHD dataset has 20 possible classes. Each of our network's 20 output neurons corresponds to one class, so our prediction is the neuron with the largest integrated synaptic current, which we find by taking the argmax over the output neurons at the last time step. Comparing these predictions with the labels gives the accuracy plotted above.\n", "\n", "Let's see how the network fares on data it hasn't seen before. 
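
Before doing so, here is a minimal NumPy sketch of the readout and loss computation described above, using a made-up random array in place of the network output. The hand-written `cross_entropy` helper is an illustrative stand-in for `torch.nn.CrossEntropyLoss` with its default mean reduction (the mean over the batch of the negative log-softmax at each target class); it is not the tutorial's actual training code.

```python
import numpy as np

rng = np.random.default_rng(0)

# - Made-up network output laid out as [batch_size, time_steps, n_classes]
batch_size, n_steps, n_classes = 4, 100, 20
output = rng.random((batch_size, n_steps, n_classes))
labels = np.array([3, 7, 0, 19])

# - Integrate each output neuron over time; the last cumulative value is the total
integrated = np.cumsum(output, axis=1)[:, -1, :]   # shape [batch_size, n_classes]

def cross_entropy(logits, targets):
    # Mean over the batch of -log softmax(logits)[target]
    shifted = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

loss = cross_entropy(integrated, labels)

# - Predicted class: the output neuron with the largest integrated current
predicted = np.argmax(integrated, axis=1)
accuracy = (predicted == labels).mean() * 100

print(f'loss = {loss:.3f}, accuracy = {accuracy:.1f}%')
```

The `cumsum` followed by `[:, -1, :]` mirrors `torch.cumsum(output, dim=1)` followed by `sum[:,-1,:]`; the same values can be obtained more directly with `output.sum(axis=1)`.
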
First, we load the test data and transform it in the same way as the training data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_data = datasets.SHD(download_dir, train=False, transform=transform)\n", "\n", "# - Keep the last, incomplete batch so that every test sample is evaluated\n", "test_dl = DataLoader(test_data, num_workers=num_workers, batch_size=batch_size,\n", "                     collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=False, shuffle=False)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can now build our test loop to see how the network performs on the test set:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Test Accuracy: 66.748%\n",
                            "
\n" ], "text/plain": [ "Test Accuracy: \u001b[1;36m66.748\u001b[0m%\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Test loop\n", "net.eval()\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", "    for events, labels in test_dl:\n", "        events, labels = events.to(dev).float(), labels.to(dev)\n", "        output, _, _ = net(events)\n", "\n", "        # - Integrate the output over time and predict the neuron with the largest value\n", "        sum = torch.cumsum(output, dim=1)\n", "        predicted = torch.argmax(sum[:,-1,:], 1)\n", "\n", "        total += labels.size(0)\n", "        correct += (predicted == labels).sum().item()\n", "\n", "accuracy_test = (correct / total) * 100\n", "print(f\"Test Accuracy: {accuracy_test:.3f}%\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the network does not perform as well on the test data as on the training data. Tuning the parameters and introducing techniques such as dropout or regularisation should help to improve performance. The authors provide a [leaderboard](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/) for the best-performing networks, which, at the time of writing, has 48.1% in 6th place and 91.1% in first place." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('py38')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "cb88bc135d05a2341e3cb126d78f83330a5c316ea17e1399798ae369290a3c17" } } }, "nbformat": 4, "nbformat_minor": 2 }