This page was generated from /builds/synsense/rockpool/docs/tutorials/easter/easter-snn-images.ipynb. Interactive online version: Binder badge

๐Ÿฐ Easter with Rockpool ๐Ÿฅš๏ƒ

# - Switch off warnings
import warnings

# - Imports
from rockpool import TSEvent, TSContinuous

import numpy as np
from tqdm.autonotebook import trange

    from rich import print

import jax
from jax.config import config

config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = [12, 4]
plt.rcParams['figure.dpi'] = 300

We’ll define a dataset class that encompasses the task of mapping frozen Poisson inputs to RGB images. We’ll define a number of input channels to receive the Poisson noise; the output of the network will be several spiking neurons per pixel column, encoding the red, green and blue components of an image pixel. The images will be encoded by stretching them out across time, so that some number of time steps corresponds to a pixel row.

from IPython.display import Image

The dataset class will specify the task, and return samples for training and testing

# - Dataset
class PoissonToImages:
    """
    Task mapping frozen Poisson input rasters to images stretched out over time.

    Every image file defines one class. For each class the dataset holds one
    random boolean input raster, drawn once at construction, and one binary
    target raster produced by :meth:`encode`. Indexing the dataset returns the
    pair ``(input_raster, target_raster)`` for that class.
    """

    def __init__(
        self,
        filenames,
        num_input_channels: int,
        num_output_channels_per_pixel: int,
        time_steps_per_pixel: int = 20,
        poisson_rate=0.2,
    ):
        """
        Load the images and generate the input and target rasters.

        Args:
            filenames: Path of one image file, or a sequence of such paths.
            num_input_channels: Number of input channels in each input raster.
            num_output_channels_per_pixel: Number of output channels used to
                encode one pixel. For RGB images these are split evenly over
                the three colour channels.
            time_steps_per_pixel: Number of time steps that each image row is
                repeated for in the target raster.
            poisson_rate: Probability of an input event per channel and time step.

        Raises:
            ValueError: If ``filenames`` is empty.
        """
        # - Accept a single path as well as a sequence of paths
        if isinstance(filenames, str):
            filenames = [filenames]

        self.Nin = num_input_channels
        self.time_steps_per_pixel = time_steps_per_pixel
        self.num_output_channels_per_pixel = num_output_channels_per_pixel

        # - Read filenames in turn
        # NOTE(review): `quantize` assumes pixel values in [0, 1], as matplotlib
        #   returns for PNG files — confirm for other file formats
        self.images = [plt.imread(this_file) for this_file in filenames]
        self.Nclasses = len(self.images)

        if not self.images:
            raise ValueError("`filenames` must contain at least one image file.")

        # - Generate poisson input and targets for classes
        self.input_raster = []
        self.target_raster = []
        for image in self.images:
            # - Generate output target raster for this class
            this_target = self.encode(image)
            self.target_raster.append(this_target)

            # - Generate poisson input raster for this class, as long as the target
            self.input_raster.append(
                np.random.rand(this_target.shape[0], num_input_channels) < poisson_rate
            )

        self.Nout = self.target_raster[-1].shape[1]

    def __getitem__(self, index):
        """Return the ``(input_raster, target_raster)`` pair for class ``index``."""
        return self.input_raster[index], self.target_raster[index]

    def __len__(self):
        """Return the number of classes, i.e. the number of loaded images."""
        return self.Nclasses

    def quantize(self, image: np.ndarray) -> np.ndarray:
        """
        Convert an image into binary bit planes, least-significant bit first.

        Args:
            image: Greyscale image ``(rows, columns)`` or colour image
                ``(rows, columns, channels)``. Only the first three channels of
                a colour image are used.

        Returns:
            ``uint8`` array of zeros and ones with shape ``(rows, columns, bits)``.
            ``bits`` is ``num_output_channels_per_pixel`` for greyscale images,
            and three times ``num_output_channels_per_pixel // 3`` for colour
            images (the bits of R, then G, then B).

        Raises:
            ValueError: If ``image`` is neither 2- nor 3-dimensional.
        """
        if image.ndim == 2:
            # - Quantise the single grey channel down to all available bits
            bits_per_plane = self.num_output_channels_per_pixel
            planes = [image[:, :, np.newaxis]]

        elif image.ndim == 3:
            # - Quantise each RGB channel with an equal share of the bits
            bits_per_plane = self.num_output_channels_per_pixel // 3
            planes = [image[:, :, ch:ch + 1] for ch in range(3)]

        else:
            raise ValueError(
                f"Expected a 2D or 3D image, got an array with {image.ndim} dimensions."
            )

        # - The largest pixel value must still fit into `bits_per_plane` bits
        num_levels = 2 ** bits_per_plane - 1

        im_quant = []
        for plane in planes:
            levels = np.round(plane * num_levels).astype('uint8')
            bits = np.unpackbits(levels, axis=2, bitorder='little')
            im_quant.append(bits[:, :, :bits_per_plane])

        return np.concatenate(im_quant, axis=2)

    def encode(self, image):
        """
        Turn an image into a target raster with image rows laid out along time.

        Each image row is quantised and repeated for ``time_steps_per_pixel``
        consecutive time steps.

        Returns:
            Array of shape ``(rows * time_steps_per_pixel, columns * bits)``.
        """
        target = np.expand_dims(self.quantize(image), axis=1)
        target = np.repeat(target, self.time_steps_per_pixel, axis=1)
        target = np.reshape(target, (image.shape[0] * self.time_steps_per_pixel, -1))
        return target

    def decode(self, output, class_index):
        """
        Fold a raster back into image layout, averaging over the time steps of each row.

        Args:
            output: Raster with the layout produced by :meth:`encode`.
            class_index: Index of the class whose image dimensions are used.

        Returns:
            Array of shape ``(rows, columns, channels)`` holding the mean
            activity of every output channel over the time steps of each row.
        """
        num_rows, num_columns = self.images[class_index].shape[:2]
        output = np.reshape(output, (num_rows, self.time_steps_per_pixel, num_columns, -1))
        output = np.mean(output, axis=1)
        return output

Now we can load some image files and generate a task and dataset. We’ll use 1000 input channels, and a single R, G and B output channel per pixel column.

# - Task configuration
Nin = 1000  # number of input channels
Nout_per_pix = 3  # output channels per pixel column (a single R, G and B channel)
input_scale = 1.  # gain applied to the input rasters before they are passed to the network

# - Load a data set and visualise
# NOTE(review): the constructor call below is cut off in this export (the remaining
#   file names, the other arguments and the closing brackets are missing) —
#   restore it from the original notebook before running this cell
ds = PoissonToImages(['egg-1.png',
print(f'Dataset: Loaded {len(ds)} images.')

# - Show the frozen input raster of every class next to the image it has to produce
for class_index in range(len(ds)):
    # - Start a new figure per class, otherwise each class draws over the previous one
    plt.figure()

    # - Left panel (four fifths of the width): input raster, channels over time
    plt.subplot(1, 5, (1, 4))
    plt.imshow(ds.input_raster[class_index].T, aspect='auto')
    plt.xlabel('Time (step)')
    plt.ylabel('Input channel')

    # - Right panel: the image belonging to this class
    plt.subplot(1, 5, 5)
    plt.imshow(ds.images[class_index])
    plt.title(f'Class {class_index}')

    plt.show()
Dataset: Loaded 3 images.

We can also see exactly what the desired output of the network is, with image rows stretched along time, and RGB channels interleaved.

# - Show the target raster of every class next to the image it encodes
for class_index in range(len(ds)):
    # - Start a new figure per class, otherwise each class draws over the previous one
    plt.figure()

    # - Left panel (four fifths of the width): target raster, channels over time
    plt.subplot(1, 5, (1, 4))
    plt.imshow(ds.target_raster[class_index].T, aspect='auto')
    plt.xlabel('Time (step)')
    plt.ylabel('Output channel')

    # - Right panel: the image belonging to this class
    plt.subplot(1, 5, 5)
    plt.imshow(ds.images[class_index])
    plt.title(f'Class {class_index}')

    plt.show()

Now we need a network to train for this task. We’ll use a simple feed-forward architecture with one hidden layer.

# - Display the network diagram stored in `network-01.png`, 200 pixels wide
Image('network-01.png', width=200)

We’ll use spiking leaky integrate-and-fire (LIF) neurons for the hidden and output layers, using the Jax backend of Rockpool and the module LIFJax (here in its ODE-based variant, LIFODEJax). This module provides a gradient-descent trainable layer of neurons, using Jax to optimise the neuron parameters. The weights in between the layers will use the LinearJax module, which is likewise trainable.

Let’s pick 100 neurons for the hidden layer in this network.

# - Network configuration
Nin = ds.Nin
Nhidden = 100
Nout = ds.Nout

# - Rockpool modules
from rockpool.nn.modules import LIFJax, LinearJax
from rockpool.nn.modules.jax.jax_lif_ode import LIFODEJax
from rockpool.nn.combinators import Sequential
from rockpool.training import jax_loss as jl
from rockpool.parameters import Constant

# - Initial biases of the spiking layers
bias_hidden = np.ones(Nhidden)
bias_out = np.ones(Nout)

# - Feed-forward network: input weights -> hidden LIF layer -> output weights -> output LIF layer.
#   Thresholds are wrapped in `Constant`, which keeps them out of the trainable parameters
net = Sequential(
    LinearJax((Nin, Nhidden), has_bias = False),
    LIFODEJax((Nhidden), bias = bias_hidden, threshold = Constant(1.)),
    LinearJax((Nhidden, Nout), has_bias = False),
    LIFODEJax((Nout), bias = bias_out, threshold = Constant(1.)),
)
print(net)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Could not import package: No module named 'iaf_nest'
JaxSequential  with shape (1000, 45) {
    LinearJax '0_LinearJax' with shape (1000, 100)
    LIFODEJax '1_LIFODEJax' with shape (100, 100)
    LinearJax '2_LinearJax' with shape (100, 45)
    LIFODEJax '3_LIFODEJax' with shape (45, 45)

Now we can import Jax and build a simple loss function, which will be the MSE between the network output and the desired target. We use jax.jit to compile the optimiser update function as well as the loss function — these will be automatically targeted to CPU, GPU or TPU by Jax, depending on what’s available. We use jax.value_and_grad to automatically differentiate the loss function over the network parameters.

# - Import the bounds convenience functions
from rockpool.training.jax_loss import bounds_cost, make_bounds

# - Generate a set of pre-configured bounds, with the same structure as the network parameters
lower_bounds, upper_bounds = make_bounds(net.parameters())
print('lower_bounds: ', lower_bounds,
      'upper_bounds: ', upper_bounds)
    '0_LinearJax': {'weight': -inf},
    '1_LIFODEJax': {'bias': -inf, 'tau_mem': -inf, 'tau_syn': -inf},
    '2_LinearJax': {'weight': -inf},
    '3_LIFODEJax': {'bias': -inf, 'tau_mem': -inf, 'tau_syn': -inf}
    '0_LinearJax': {'weight': inf},
    '1_LIFODEJax': {'bias': inf, 'tau_mem': inf, 'tau_syn': inf},
    '2_LinearJax': {'weight': inf},
    '3_LIFODEJax': {'bias': inf, 'tau_mem': inf, 'tau_syn': inf}
# - Time step that the lower bounds on the time constants are expressed in
dt = 1e-3

# - Keep the membrane and synaptic time constants of both spiking layers above 11 time steps
min_tau = 11 * dt
for layer_name in ('1_LIFODEJax', '3_LIFODEJax'):
    for tau_name in ('tau_mem', 'tau_syn'):
        lower_bounds[layer_name][tau_name] = min_tau

# - Only bound the thresholds when the hidden layer exposes its threshold as a bounded parameter
if 'threshold' in lower_bounds['1_LIFODEJax']:
    for layer_name in ('1_LIFODEJax', '3_LIFODEJax'):
        lower_bounds[layer_name]['threshold'] = 0.1
# - Jax imports
import jax
from itertools import count

# - The optimizers module was moved out of `jax.experimental` in newer JAX releases;
#   prefer the new location and fall back to the old one
try:
    from jax.example_libraries.optimizers import adam
except ImportError:
    from jax.experimental.optimizers import adam

# - Initialise optimiser
learning_rate = 1e-4
bounds_penalty = 10.  # weight of the parameter bounds cost within the loss
init_fun, update_fun, get_params = adam(learning_rate)
opt_state = init_fun(net.parameters())

# - Compile the optimiser update step
update_fun = jax.jit(update_fun)

# - Loss function
@jax.jit
@jax.value_and_grad
def loss_vgf(params, net, input, target):
    """
    Compute the training loss and its gradient with respect to ``params``.

    The undecorated function returns the MSE between the network output and
    ``target``, plus ``bounds_penalty`` times the cost of violating the
    parameter bounds. ``jax.value_and_grad`` turns this into a function
    returning ``(loss, gradient)``, which is how the training loop uses it;
    ``jax.jit`` compiles the result.

    Args:
        params: Network parameters to evaluate the loss at; the gradient is
            taken with respect to this argument.
        net: Network to evaluate.
        input: Input passed to the network.
        target: Desired network output.

    Returns:
        Tuple ``(loss, gradient)``, where ``gradient`` has the same structure
        as ``params``.
    """
    # - Apply the candidate parameters and evolve the network from a reset state
    net = net.set_attributes(params)
    net = net.reset_state()
    output, _, _ = net(input)

    # - Impose the bounds
    bounds = bounds_cost(params, lower_bounds, upper_bounds)

    return jl.mse(output, target) + bounds_penalty * bounds

Now we can specify the number of epochs to train for, and use a simple training loop to perform gradient descent learning. We’ll keep track of the loss values along the way to visualise.

# - Compose samples into batches
batch_input = []
batch_target = []
for this_input, this_target in ds:
    batch_input.append(this_input)
    batch_target.append(this_target)

# - Stack along a new leading batch axis; this requires all samples to have the same shape
batch_input = np.stack(batch_input)
batch_target = np.stack(batch_target)
# - Configure learning
num_epochs = 35000

# - Training loop
loss_t = []
sample_index = count()
t = trange(num_epochs, desc='Training', unit='Epoch')

for epoch in t:
    # - Get parameters
    opt_parameters = get_params(opt_state)

    # - Compute loss and gradient
    loss_value, grads = loss_vgf(opt_parameters, net, batch_input * input_scale, batch_target)

    # - NOTE(review): presumably this strips the bounds penalty from the reported
    #   loss, so that the displayed value reflects the task error — confirm
    if loss_value > bounds_penalty:
        loss_value = loss_value % bounds_penalty

    # - Keep track of the loss and show it on the progress bar
    loss_t.append(loss_value.item())
    t.set_postfix({'loss': loss_value.item()}, refresh = False)

    # - Update optimiser
    opt_state = update_fun(next(sample_index), grads, opt_state)

# - Plot the loss over training
plt.plot(loss_t)
plt.xlabel('Epoch')
plt.ylabel('Loss value')
plt.title('Training loss');

The loss decreased, indicating that the network is learning to produce the desired output. Let’s visualise the output of the network over the various classes to check.

# - Test training: apply the optimised parameters to the network
net = net.set_attributes(get_params(opt_state))

plt.figure(figsize=(12, 10))

num_classes = len(ds)
for class_index in range(num_classes):
    inp, target = ds[class_index]
    out, _, rec_state = net(inp * input_scale)

    # - Top row (panels 1-3): decoded network output; bottom row (panels 4-6): decoded target
    for first_panel, raster in ((1, out), (4, target)):
        plt.subplot(2, 3, first_panel + class_index)
        plt.imshow(ds.decode(raster, class_index))
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

But what is the network doing internally? Let’s take a look at the spiking activity of the hidden and output layers.

# - Hidden layer activity of the first batch entry, drawn as an image (channels over time)
hidden_events = rec_state['1_LIFODEJax_output'][0]
plt.imshow(hidden_events.T, aspect='auto', origin='lower')
plt.title('Hidden layer events')

# - Readout layer activity of the first batch entry, drawn as an event time series
readout_events = rec_state['3_LIFODEJax_output'][0]
TSEvent.from_raster(readout_events, dt = 1e-3).plot();
plt.title('Readout events');

This toolchain lets you train all neuron parameters, not only weights and biases. Let’s take a look at the final distribution of time constants, which were initially identical.

# - Histogram of the hidden layer membrane time constants, scaled by 1e3 for the axis in ms
hidden_tau_ms = net[1].tau_mem * 1e3
plt.hist(hidden_tau_ms, bins=20)
plt.xlabel('$\\tau$ (ms)')
plt.title('Time constant distribution');

And of course the weights have also been learned to perform the task correctly.

# - Histogram over all weights of the first linear layer
input_weights = net[0].weight
plt.hist(np.ravel(input_weights), bins=20)
plt.xlabel('Weight value')
plt.title('Weight distribution');

# - The same weights as an image, transposed so that input channels run along the x axis
plt.imshow(input_weights.T, aspect='auto');
plt.xlabel('Input channel')
plt.ylabel('Output channel');