This page was generated from docs/basics/hello_MNIST.ipynb.

👋 Hello MNIST with Rockpool

Rockpool is a neural network training library, with a focus on spiking neural networks (SNNs) and other neurons that have internal state and temporal dynamics.

The goal is to make training and deploying SNNs as simple as training a standard DNN. Here we show the steps to train an SNN on the MNIST digit classification task, a common starting point when getting to know a neural network library.

[9]:
# - Install required packages
import sys
!{sys.executable} -m pip install --quiet rockpool tonic tqdm torch torchvision matplotlib
Note: you may need to restart the kernel to use updated packages.
[10]:
# - Basic imports
import torch
import torchvision

import numpy as np

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = [9.6, 3.6]
plt.rcParams['font.size'] = 12

from tqdm.autonotebook import tqdm, trange

from IPython.display import Image

Spiking neurons

The main difference between a spiking neuron (SN) and a standard artificial neuron (AN) is how they treat time. Standard ANs operate instantaneously: they sum their weighted inputs and apply a transfer function such as a ReLU. Spiking neurons, on the other hand, have an internal state that evolves over time in response to their input.

As a result, SNNs are well suited to processing temporal signals. This also means we need to consider how to provide input data to an SNN over time, and treat quantities such as the time constants \(\tau\) as additional network parameters.
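
To make the idea of an internal state concrete, here is a minimal NumPy sketch of a leaky integrate-and-fire neuron. It is an illustration only: the function name `simulate_lif`, the default time constants, the exponential-decay update and the subtractive reset are assumptions made for this example, and may differ from the details of Rockpool's LIF module.

```python
import numpy as np

def simulate_lif(input_z, dt=1e-3, tau_syn=20e-3, tau_mem=20e-3, threshold=1.0):
    """Simulate a single LIF neuron over a 1D input series (illustrative sketch)."""
    beta = np.exp(-dt / tau_syn)   # per-step decay of the synaptic current
    alpha = np.exp(-dt / tau_mem)  # per-step decay of the membrane potential
    isyn, vmem = 0.0, 0.0
    isyn_rec = np.zeros(len(input_z))
    vmem_rec = np.zeros(len(input_z))
    spikes = np.zeros(len(input_z))

    for t, z in enumerate(input_z):
        isyn = beta * isyn + z        # input events add to the decaying synaptic current
        vmem = alpha * vmem + isyn    # the membrane integrates the synaptic current
        if vmem >= threshold:         # emit an event and reset by subtracting the threshold
            spikes[t] = 1.0
            vmem -= threshold
        isyn_rec[t], vmem_rec[t] = isyn, vmem

    return spikes, isyn_rec, vmem_rec

# A single weak input event: the state rises and decays, but stays below threshold
impulse = np.zeros(200)
impulse[0] = 0.01
spikes, isyn_rec, vmem_rec = simulate_lif(impulse)
print(int(spikes.sum()), int(vmem_rec.argmax()))
```

The membrane potential peaks some time after the input event arrives, which is the temporal behaviour a standard AN does not have.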

In Rockpool, we have a simple Leaky Integrate-and-Fire (LIF) spiking neuron, defined by the module LIF. The cell below shows you how to build a single spiking neuron, provide input events, simulate the neuron and examine its internal state.

You can read more in depth about spiking neurons in A simple introduction to spiking neurons.

[11]:
# - Import the SNN modules from rockpool
from rockpool.nn.modules import LIF

# - Generate a spiking neuron
neuron = LIF(1)

# - Simulate this neuron for 1 sec with poisson spiking input z(t)
num_timesteps = int(1/neuron.dt)
input_z = 0.06 * (np.random.rand(num_timesteps) < 0.0125)
output, _, rec_dict = neuron(input_z, record = True)

# - Display the input, internal state and output events
plt.figure()
plt.plot(rec_dict['isyn'].squeeze(), label = '$I_s(t)$')
plt.plot(rec_dict['vmem'].squeeze(), label = '$V_m(t)$')
b, t, n = np.nonzero(output)
plt.scatter(t, n, marker='|', c = 'k', label = '$o(t)$')
plt.plot([0, num_timesteps], [neuron.threshold] * 2, 'k:', label=r'$\Theta$')
plt.xlabel('Time (dt)')
plt.ylabel('$V_m$, $I_s$ (a.u.)')
plt.legend();

../_images/basics_hello_MNIST_5_0.png

Data and encoding

Now we can load the MNIST dataset, and decide how to encode the data for processing by an SNN. We make use of the torchvision package to obtain the MNIST dataset and manage the dataset and data loader classes.

The data will be accessed as torch.Tensor objects.

[46]:
# - Number of samples per batch
batch_size = 256

# - Download and access the MNIST training dataset
train_data = torchvision.datasets.MNIST(
    root=".",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# - Create a data loader for the training dataset
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True
)

# - Create a test dataset
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        root=".",
        train=False,
        transform=torchvision.transforms.ToTensor(),
    ),
    batch_size=batch_size,
)

Now we need to decide how to encode the data samples for the SNN. The MNIST samples are \(28 \times 28\) resolution images, with no temporal component. We will encode the images by arranging the \(28 \times 28\) pixels into a 784-element vector, and creating 784 spiking input channels. We’ll use a Poisson process to generate events on each channel, with an average rate set by the intensity of the corresponding pixel. Pixels with value 0 will generate no input events; pixels with value 1 will generate the highest rate of events. To do so we need to define how many time-steps each sample will take, and what the duration of a single time-step will be.

[13]:
# - Define the temporal aspects of a data sample
num_timesteps = 100
dt = 10e-3

# - Extract the number of classes and input channels
num_classes = len(torchvision.datasets.MNIST.classes)
input_channels = train_data[0][0].numel()

# - Define a function to encode an input into a poisson event series
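def encode_poisson(data: torch.Tensor, num_timesteps: int, scale: float = 0.1) -> torch.Tensor:
  # - `data` has shape (batch, x, y); flatten each frame and repeat it over time
  num_batches, frame_x, frame_y = data.shape
  data = scale * data.view((num_batches, 1, -1)).repeat((1, num_timesteps, 1))
  # - Note: `scale` is applied a second time here, so the event probability
  #   per time-step is `scale ** 2` times the pixel intensity
  return (torch.rand(data.shape) < (data * scale)).float()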
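
As a sanity check on this kind of encoding, the NumPy sketch below draws independent events per time-step, which approximates a Poisson process when the per-step probability is small. The function name `poisson_events`, the seed and the example probabilities are assumptions for illustration; the sketch mirrors the idea behind `encode_poisson` rather than its exact scaling.

```python
import numpy as np

def poisson_events(prob, num_timesteps, rng):
    """Draw a (num_timesteps, num_channels) binary event raster.

    `prob` holds the per-time-step event probability for each channel.
    """
    prob = np.asarray(prob, dtype=float)
    return (rng.random((num_timesteps, prob.size)) < prob).astype(float)

rng = np.random.default_rng(0)
prob = np.array([0.0, 0.05, 0.5, 1.0])   # one channel per example pixel intensity
raster = poisson_events(prob, num_timesteps=10_000, rng=rng)

# The empirical event rate per channel should approach the requested probability
print(raster.mean(axis=0))
```

With a time-step of `dt` seconds, a per-step probability `p` corresponds to an average rate of `p / dt` events per second.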

We also need to determine what the target output of the network will look like. The MNIST dataset has 10 classes. We’ll build a network with 10 output neurons, one for each class, and train the network to produce a high event rate for the target class and no events for the non-target classes. Our target will also be a time series of events, with the same duration as the input time series, and with 10 channels.

[14]:
# - Define a function to encode the network target
def encode_class(class_idx: torch.Tensor, num_classes: int, num_timesteps: int) -> torch.Tensor:
  num_batches = class_idx.numel()
  target = torch.nn.functional.one_hot(class_idx, num_classes = num_classes)
  return target.view((num_batches, 1, -1)).repeat((1, num_timesteps, 1)).float()
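
The shape and scale of this target can be checked with a small NumPy sketch. The helper name `encode_class_np` and the example labels are assumptions for illustration. The sketch also gives a useful reference value: an output with no events at all has a mean-squared error of 1 / num_classes against this target, since exactly one channel is active at every time-step.

```python
import numpy as np

def encode_class_np(class_idx, num_classes, num_timesteps):
    """Return a (batch, time, class) target that is 1 on the target channel."""
    one_hot = np.eye(num_classes)[np.asarray(class_idx)]          # (batch, class)
    return np.repeat(one_hot[:, None, :], num_timesteps, axis=1)  # (batch, time, class)

target = encode_class_np([3, 7], num_classes=10, num_timesteps=100)
print(target.shape)

# MSE of an all-silent output against the target
silent = np.zeros_like(target)
silent_mse = np.mean((silent - target) ** 2)
print(silent_mse)
```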
[15]:
# - Get one sample
frame, class_idx = train_data[0]

# - Encode the input and targets
data = encode_poisson(frame, num_timesteps)
target = encode_class(torch.tensor(class_idx), num_classes, num_timesteps)

# - Plot the poisson input for this sample
plt.figure()
b, t, n = torch.nonzero(data, as_tuple = True)
plt.scatter(t * dt, n, marker = '|')
plt.xlabel('Time (s)')
plt.ylabel('Channel')
plt.title('Input')

# - Plot the target event series for this sample
b, t, n = torch.nonzero(target, as_tuple = True)
plt.figure()
plt.scatter(t * dt, n, marker = '|')
plt.ylim([-1, num_classes+1])
plt.xlabel('Time (s)')
plt.ylabel('Channel')
plt.title('Target');
../_images/basics_hello_MNIST_12_0.png
../_images/basics_hello_MNIST_12_1.png

Building a network

Rockpool allows you to build SNNs with a simple syntax, similar to PyTorch. We support several computational back-ends in Rockpool, one of which is PyTorch, which we will use here.

LIFTorch is a Rockpool and torch module which provides a trainable simulation of LIF spiking neurons. LinearTorch is a linear weight matrix, comparable to the torch nn.Linear module.

To build simple feed-forward networks, we use the Sequential combinator from Rockpool, which functions like the torch nn.Sequential combinator.

By default, Rockpool allows you to train the time constants \(\tau\) and other parameters of an SNN. For simplicity, here we’ll define them as constant (i.e. non-trainable), using the Constant parameter decorator.

The network architecture we will use is shown below; a simple two-layer SNN.

[16]:
Image('mnist-architecture.png')
[16]:
../_images/basics_hello_MNIST_14_0.png
[17]:
# - Import network packages
from rockpool.nn.modules import LIFTorch, LinearTorch
from rockpool.nn.combinators import Sequential
from rockpool.parameters import Constant

# - Define a simple network
num_hidden = 64
tau_mem = Constant(100e-3)
tau_syn = Constant(50e-3)
threshold = Constant(1.)
bias = Constant(0.)

# - Define a two-layer feed-forward SNN
snn = Sequential(
    LinearTorch((input_channels, num_hidden)),
    LIFTorch(num_hidden, tau_syn = tau_syn, tau_mem = tau_mem, threshold = threshold, bias = bias, dt = dt),

    LinearTorch((num_hidden, num_classes)),
    LIFTorch(num_classes, tau_syn = tau_syn, tau_mem = tau_mem, threshold = threshold, bias = bias, dt = dt)
)
print(snn)
TorchSequential  with shape (784, 10) {
    LinearTorch '0_LinearTorch' with shape (784, 64)
    LIFTorch '1_LIFTorch' with shape (64, 64)
    LinearTorch '2_LinearTorch' with shape (64, 10)
    LIFTorch '3_LIFTorch' with shape (10, 10)
}
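
As a rough size check, the number of trainable weights follows directly from the layer shapes printed above. The sketch below counts only the weight matrices of the two linear layers. It assumes the linear layers carry no bias terms, and it leaves out the neuron parameters, which were declared constant here.

```python
# Layer sizes taken from the printed network: 784 -> 64 -> 10
layer_shapes = [(784, 64), (64, 10)]

# Each linear layer contributes in_features * out_features weights
weights_per_layer = [n_in * n_out for n_in, n_out in layer_shapes]
total_weights = sum(weights_per_layer)

print(weights_per_layer, total_weights)
```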

Let’s examine the untrained output of this network. We simulate the network in the same way as the single LIF neuron above, simply by passing input data to the module. We’ll use the single data sample we encoded above. The record = True argument tells Rockpool to record all the internal state of the SNN during evolution.

[18]:
# - Simulate the untrained network, record internal state
output, _, rec_dict = snn(data, record = True)

# - Display the internal state and output
plt.plot(rec_dict['1_LIFTorch']['vmem'][0].detach());
plt.plot([0, num_timesteps], [threshold] * 2, 'k:')
plt.ylim([-2.1, 1.1])
plt.xlabel('Time (dt)')
plt.ylabel('$V_m$ (a.u.)')
plt.title('Hidden $V_m$')

plt.figure()
plt.plot(rec_dict['3_LIFTorch']['vmem'][0].detach());
plt.plot([0, num_timesteps], [threshold] * 2, 'k:')
plt.ylim([-2.1, 1.1])
plt.xlabel('Time (dt)')
plt.ylabel('$V_m$ (a.u.)')
plt.title('Output $V_m$')

plt.figure()
t, n = torch.nonzero(rec_dict['3_LIFTorch_output'][0].detach(), as_tuple = True)
plt.scatter(t, n, marker = '|', label='Output events')
plt.plot(0, class_idx, 'g>', markersize=20, label='Target class')
plt.ylim([-1, num_classes+1])
plt.xlim([-1, num_timesteps+1])
plt.xlabel('Time (dt)')
plt.ylabel('Channel')
plt.title('Output (events)')
plt.legend();
../_images/basics_hello_MNIST_17_0.png
../_images/basics_hello_MNIST_17_1.png
../_images/basics_hello_MNIST_17_2.png

Optimization

Rockpool links into the torch automatic differentiation pipeline to provide end-to-end gradient-based training of SNNs. We can natively use torch loss functions, and optimizers from torch.optim such as Adam. The .astorch() method converts the Rockpool parameter dictionary into a form identical to that of torch.

[19]:
from torch.optim import Adam

# - Initialise the optimizer with the network parameters
optimizer = Adam(snn.parameters().astorch())

Thanks to the close integration with torch, we can use CUDA-based GPU acceleration for training, the same as with a standard torch network.

[20]:
# - Determine which device to use for training
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Now we will define a simple training loop, and a test function. This boilerplate code is identical to that required for any other torch model.

[21]:
def train(model, device, train_loader, optimizer):
  # - Prepare model for training
  model.train()
  losses = []

  # - Loop over the dataset for this epoch
  for data, class_idx in tqdm(train_loader, leave=False, desc='Training', unit='batch'):
    # - Encode input and target
    data = encode_poisson(data.squeeze(), num_timesteps)
    target = encode_class(class_idx, num_classes, num_timesteps)

    # - Zero gradients, simulate model
    optimizer.zero_grad()
    output, _, _ = model(data.to(device))

    # - Compute MSE loss and perform backward pass
    loss = torch.nn.functional.mse_loss(output, target.to(device))
    loss.backward()
    optimizer.step()

    # - Keep track of the losses
    losses.append(loss.item())

  return losses
[22]:
def test(model, device, test_loader):
    # - Prepare model for evaluation
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
      # - Loop over the dataset
      for data, class_idx in tqdm(test_loader, desc='Testing', leave=False, unit='batch'):
        # - Encode the input and target
        inputs = encode_poisson(data.squeeze(), num_timesteps)
        target = encode_class(class_idx, num_classes, num_timesteps)

        # - Evaluate the model
        output, _, _ = model(inputs.to(device))

        # - Compute loss and prediction
        test_loss += torch.nn.functional.mse_loss(output, target.to(device)).item()
        pred = output.sum(axis = 1).argmax(axis = 1, keepdim=True).cpu()
        correct += pred.eq(class_idx.view_as(pred)).sum().item()

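    # - Note: each batch loss is already a mean, and the sum is divided by the number
    #   of samples, so the reported value is roughly the mean batch loss / batch size
    test_loss /= len(test_loader.dataset)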

    accuracy = 100.0 * correct / len(test_loader.dataset)

    return test_loss, accuracy
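
The prediction rule used in `test` can be illustrated on a small hand-made output: events are summed over time for each output channel, and the channel with the most events is taken as the predicted class. The arrays below are made-up data for illustration, written with NumPy rather than torch.

```python
import numpy as np

# Made-up output raster with shape (batch, time, class): 2 samples, 4 time-steps, 3 classes
output = np.array([
    [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]],   # class 1 fires most often
    [[0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 0]],   # class 2 fires most often
])
labels = np.array([1, 0])                 # made-up true classes

event_counts = output.sum(axis=1)         # (batch, class): events per class
pred = event_counts.argmax(axis=1)        # (batch,): most active class
accuracy = 100.0 * np.mean(pred == labels)

print(event_counts, pred, accuracy)
```

In this NumPy sketch, ties go to the lowest class index, so an output with no events at all is counted as a prediction of class 0.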

Let’s see how well the untrained network performs.

[23]:
# - Initial loss and accuracy
test_loss, test_acc = test(snn.to(device), device, test_loader)
print(f'Initial test loss {test_loss}, accuracy {test_acc}%')
Initial test loss 0.000445678124576807, accuracy 9.5%
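
To relate this number to the per-batch MSE values returned by `train`, note that `test` sums one mean loss per batch and then divides by the number of samples. The sketch below undoes that scaling. It assumes the standard MNIST test split of 10,000 samples; the variable names are chosen for illustration.

```python
import math

reported_loss = 0.000445678124576807   # value printed above
num_test_samples = 10_000              # assumption: standard MNIST test split
batch_size = 256

# Number of batches the test loader yields (the last batch is smaller)
num_batches = math.ceil(num_test_samples / batch_size)

# Undo the division by sample count, then average over batches instead
mean_batch_loss = reported_loss * num_test_samples / num_batches

print(num_batches, round(mean_batch_loss, 4))
```

For reference, an output with no events at all would give an MSE of exactly 0.1 against a target with one active channel out of ten.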

Now we can train for several epochs, and see if we can improve performance.

[24]:
# - Train some epochs
num_epochs = 3

all_losses = []
ep_loop = trange(num_epochs, desc='Training', unit='epoch')
for epoch in ep_loop:
  ep_loop.set_postfix(test_loss = test_loss, test_acc = f'{test_acc}%')
  losses = train(snn.to(device), device, train_loader, optimizer)
  all_losses.extend(losses)

  # - Get test loss and accuracy
  test_loss, test_acc = test(snn.to(device), device, test_loader)

plt.plot(all_losses)
plt.xlabel('Batch')
plt.ylabel('MSE loss')
plt.title('Training loss');
../_images/basics_hello_MNIST_28_7.png
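
The x-axis of this plot counts batches, not epochs. A quick calculation gives the number of points to expect, assuming the standard MNIST training split of 60,000 samples and a data loader that keeps the final, smaller batch.

```python
import math

num_train_samples = 60_000   # assumption: standard MNIST training split
batch_size = 256
num_epochs = 3

batches_per_epoch = math.ceil(num_train_samples / batch_size)
last_batch_size = num_train_samples - (batches_per_epoch - 1) * batch_size
total_batches = batches_per_epoch * num_epochs

print(batches_per_epoch, last_batch_size, total_batches)
```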
[25]:
# - Trained loss and accuracy
print(f'Final test loss {test_loss}, accuracy {test_acc}%')
Final test loss 0.00033636679723858835, accuracy 86.23%

Looks good! We definitely learned something. Let’s take a look at the output on a single sample.

[41]:
# - Pull a single sample (here from the training set) and display it
frame, class_idx = train_data[1]
data = encode_poisson(frame, num_timesteps)
target = encode_class(torch.tensor(class_idx), num_classes, num_timesteps)

plt.figure(figsize=(1, 1))
plt.imshow(frame[0])
plt.xticks([])
plt.yticks([])

plt.figure()
b, t, n = torch.nonzero(data, as_tuple = True)
plt.scatter(t * dt, n, marker='|')
plt.xlabel('Time (s)')
plt.ylabel('Channel')
plt.title('Input')

b, t, n = torch.nonzero(target, as_tuple = True)
plt.figure()
plt.scatter(t * dt, n, marker='|')
plt.ylim([-1, num_classes+1])
plt.xlabel('Time (s)')
plt.ylabel('Channel')
plt.title('Target');
../_images/basics_hello_MNIST_31_0.png
../_images/basics_hello_MNIST_31_1.png
../_images/basics_hello_MNIST_31_2.png
[45]:
# - Simulate the network with this sample
output, _, _ = snn.cpu()(data)
pred = output.sum(axis = 1).argmax(axis = 1, keepdim=True).cpu()

# - Display the network output
b, t, n = torch.nonzero(output, as_tuple=True)

plt.scatter(t, n, marker='|', label='Output events')
plt.plot(0, class_idx, 'g>', markersize=20, label='Target class')
plt.plot(num_timesteps, pred.item(), '<', c='orange', markersize=20, label='Predicted class')
plt.ylim([-1, num_classes+1])
plt.xlim([-1, num_timesteps+1])
plt.xlabel('Time (dt)')
plt.ylabel('Channel')
plt.title('Output (events)')
plt.legend();
../_images/basics_hello_MNIST_32_0.png

Next steps

The performance of the network could of course be improved in various ways:

  • Deeper or more complex network architecture

  • More hidden neurons

  • Training time constants \(\tau\) and other parameters

  • Different loss function or output encoding

Now you’ve seen how to solve a simple vision task with Rockpool, check out other tutorials for more difficult temporal tasks, advanced training approaches, and hardware-aware training and deployment.