This page was generated from docs/tutorials/jax_lif_sgd.ipynb. An interactive online version is available on Binder.

⚡️ Training a spiking network with Jax

This tutorial demonstrates using Rockpool and a Jax-accelerated LIF feed-forward neuron layer to perform gradient descent training of all network parameters. The result is a trained spiking layer which can generate a pre-defined signal from a noisy spiking input.

Requirements and housekeeping

This example requires the Rockpool package from SynSense, as well as jax and its dependencies.

[1]:
# - Switch off warnings
import warnings

warnings.filterwarnings("ignore")

# - Rockpool imports
from rockpool import TSEvent, TSContinuous
from rockpool.nn.modules import LIFJax, LinearJax, ExpSynJax
from rockpool.nn.modules.jax.jax_lif_ode import LIFODEJax
from rockpool.nn.combinators import Sequential
from rockpool.parameters import Constant

# - Typing
from typing import Callable, Dict, Tuple
import types

# - Numpy
import numpy as np
import copy

# - Pretty printing
try:
    from rich import print
except ImportError:
    pass

# TQDM
from tqdm.autonotebook import tqdm

# - Plotting imports and config
import sys
!{sys.executable} -m pip install --quiet matplotlib
import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams["figure.figsize"] = [12, 4]
plt.rcParams["figure.dpi"] = 300

Signal generation from frozen noise task

We will use a single feed-forward layer of spiking neurons to convert a chosen pattern of random input spikes over time into a pre-defined temporal signal with complex dynamics.

The network architecture is strictly feedforward, but the spiking neurons nevertheless contain temporal dynamics in their synaptic and membrane signals, with explicit time constants.

Each of the Nin input channels contains independent random spikes. In each time step of length dt, a channel emits a spike with probability spiking_prob, which approximates a Poisson process with rate spiking_prob/dt. A single output channel should generate a chirp signal whose frequency increases linearly over time, at a rate set by chirp_freq_factor. You can play with these parameters below.
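As a quick check of these quantities, the sketch below uses plain NumPy and the parameter values from the next cell. It computes the expected input firing rate and the instantaneous frequency of the target. The target is \(\sin(2\pi f t^2)\) with \(f\) = chirp_freq_factor, so its instantaneous frequency is \(\frac{1}{2\pi}\frac{d}{dt}\left(2\pi f t^2\right) = 2 f t\). The variable names are illustrative only.

```python
import numpy as np

# - Parameters copied from the tutorial cell below
Nin = 200
dt = 1e-3
chirp_freq_factor = 10
dur_input = 1000e-3
spiking_prob = 0.01

# - Time base, constructed as in the tutorial
T = int(np.round(dur_input / dt))
timebase = np.linspace(0, (T - 1) * dt, T)

# - Per-step spike probability divided by the step length gives the rate
expected_rate_hz = spiking_prob / dt

# - The chirp phase is 2*pi*f*t^2, so the instantaneous frequency is 2*f*t
inst_freq_hz = 2 * chirp_freq_factor * timebase

# - Numerical check: differentiate the phase directly
phase = 2 * np.pi * chirp_freq_factor * timebase**2
numeric_freq_hz = np.gradient(phase, timebase) / (2 * np.pi)

# - Empirical rate of a raster drawn the same way as in the tutorial
rng = np.random.default_rng(0)
raster = rng.random((T, Nin)) < spiking_prob
empirical_rate_hz = raster.mean() / dt

print(f"Expected input rate:  {expected_rate_hz:.1f} Hz")
print(f"Empirical input rate: {empirical_rate_hz:.1f} Hz")
print(f"Chirp frequency: {inst_freq_hz[0]:.1f} Hz -> {inst_freq_hz[-1]:.2f} Hz")
```

With these values the chirp sweeps from 0 Hz up to roughly 20 Hz over the one-second trial. Each input channel fires at about 10 Hz on average.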

[2]:
# - Define input and target
Nin = 200
dt = 1e-3
chirp_freq_factor = 10
dur_input = 1000e-3

# - Generate a time base
T = int(np.round(dur_input / dt))
timebase = np.linspace(0, (T - 1) * dt, T)

# - Generate a chirp signal as a target
chirp = np.atleast_2d(np.sin(timebase * 2 * np.pi * (timebase * chirp_freq_factor))).T
target_ts = TSContinuous(timebase, chirp, periodic=True, name="Target chirp")

# - Generate a Poisson frozen random spike train
spiking_prob = 0.01
input_sp_raster = np.random.rand(T, Nin) < spiking_prob
input_sp_ts = TSEvent.from_raster(
    input_sp_raster, name="Input spikes", periodic=True, dt=dt
)

# - Plot the input and target signals
plt.figure()
input_sp_ts.plot(s=4)
(target_ts * Nin / 2 + Nin / 2).plot(color="orange", lw=2)
plt.legend()
plt.title("Input and target");
[Figure: Input and target. Input spike raster with the scaled target chirp overlaid.]

LIF neuron

The spiking neuron we will use is a leaky integrate-and-fire spiking neuron (“LIF” neuron). This neuron receives input spike trains \(S_{in}(t) = \sum_j\delta(t-t_j)\), which are integrated via weighted exponential synapses. Synaptic currents are then integrated into a neuron state (“membrane potential”) \(V_{mem}\).

The neuron obeys the dynamics

\[\tau_{mem}\cdot\dot{V}_{mem} + V_{mem} = {I}_{syn} + I_{bias} + \sigma\zeta(t)\]
\[\tau_{syn}\cdot\dot{I}_{syn} + I_{syn} = 0\]
\[I_{syn} += W_{in} \cdot S_{in}(t)\]

where \(\tau_{mem}\) and \(\tau_{syn}\) are the membrane and synaptic time constants; \(W_{in}\) is the matrix of input weights; \(I_{bias}\) is a constant bias current for each neuron; and \(\sigma\zeta(t)\) is a white noise process with standard deviation \(\sigma\).

Output spikes are generated when \(V_{mem}\) crosses the firing threshold \(V_{th} = 0\). This process generates a spike train \(S(t)\) as a series of delta functions and causes a subtractive reset of \(V_{mem}\), where \(H\) denotes the Heaviside step function:

\[V_{mem} > V_{th} \rightarrow S(t) = H(V_{mem}(t)), V_{mem} = V_{mem} - 1\]

Because \(H\) has zero gradient almost everywhere, a smooth analog surrogate of the spiking output is used in its place when computing gradients:

\[U(t) = \tanh(V_{mem}(t) + 1) / 2 + 0.5\]
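The surrogate idea can be sketched in plain NumPy. The forward pass uses the Heaviside step \(H\). The backward pass substitutes the derivative of \(U\) for the (zero) derivative of \(H\). This sketch assumes the argument of \(U\) is the membrane potential \(V_{mem}\), so that \(dU/dV_{mem} = (1 - \tanh^2(V_{mem} + 1))/2\). It illustrates the formula above; it is not Rockpool's internal implementation, and the function names are hypothetical.

```python
import numpy as np


def spike_forward(v_mem):
    """Forward pass: Heaviside step, 1 where v_mem > V_th = 0, else 0."""
    return (v_mem > 0).astype(float)


def surrogate(v_mem):
    """Smooth surrogate U = tanh(v_mem + 1) / 2 + 0.5."""
    return np.tanh(v_mem + 1) / 2 + 0.5


def surrogate_grad(v_mem):
    """dU/dv_mem, used in place of dH/dv_mem in the backward pass."""
    return (1 - np.tanh(v_mem + 1) ** 2) / 2


v = np.linspace(-3, 2, 11)
print(np.column_stack([v, spike_forward(v), surrogate_grad(v)]))

# - Check the analytic surrogate gradient against central finite differences
h = 1e-6
fd_grad = (surrogate(v + h) - surrogate(v - h)) / (2 * h)
print("Max finite-difference error:", np.max(np.abs(fd_grad - surrogate_grad(v))))
```

The surrogate gradient is positive everywhere, so an error signal can reach every neuron. The true step function would pass back no gradient at all.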

The output of the network \(o(t)\) is then given by a weighted sum of the output spike trains

\[o(t) = W_{out} \cdot S(t)\]

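For more detail, see the documentation for the Jax module LIFJax. The network below uses the related module LIFODEJax, imported from rockpool.nn.modules.jax.jax_lif_ode.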
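To make the dynamics concrete, here is a minimal forward-Euler discretisation of the equations above for a small population, written in NumPy. It makes several assumptions:

- a fixed time step dt;
- no noise term (\(\sigma = 0\));
- \(V_{th} = 0\) with a subtractive reset of 1, as stated above;
- weighted input spikes applied as instantaneous jumps in \(I_{syn}\).

It is an illustrative sketch only; the Rockpool modules may discretise the dynamics differently.

```python
import numpy as np


def simulate_lif(spikes_in, w_in, tau_mem, tau_syn, bias, dt=1e-3, v_th=0.0):
    """Forward-Euler simulation of the LIF equations, with noise omitted.

    spikes_in: (T, Nin) array of input spike counts per time step
    w_in:      (Nin, N) input weight matrix
    Returns membrane potentials, synaptic currents and output spikes, each (T, N).
    """
    T = spikes_in.shape[0]
    N = w_in.shape[1]
    v_mem = np.zeros(N)
    i_syn = np.zeros(N)
    vmem_rec, isyn_rec, spikes_rec = (np.zeros((T, N)) for _ in range(3))

    for t in range(T):
        # - I_syn += W_in . S_in(t): input spikes cause instantaneous jumps
        i_syn = i_syn + spikes_in[t] @ w_in
        # - tau_syn * dI_syn/dt = -I_syn
        i_syn = i_syn - dt / tau_syn * i_syn
        # - tau_mem * dV_mem/dt = -V_mem + I_syn + I_bias
        v_mem = v_mem + dt / tau_mem * (-v_mem + i_syn + bias)
        # - Spike where V_mem exceeds the threshold, then subtract 1
        spikes = (v_mem > v_th).astype(float)
        v_mem = v_mem - spikes

        vmem_rec[t], isyn_rec[t], spikes_rec[t] = v_mem, i_syn, spikes

    return vmem_rec, isyn_rec, spikes_rec


# - Usage: 20 random input channels driving 5 neurons for 500 time steps
rng = np.random.default_rng(1)
spikes_in = (rng.random((500, 20)) < 0.02).astype(float)
w_in = rng.normal(0.0, 1.0, (20, 5))
vmem, isyn, spikes = simulate_lif(spikes_in, w_in, tau_mem=20e-3, tau_syn=10e-3, bias=-0.1)
print("Output spike counts per neuron:", spikes.sum(axis=0))
```

This explicit update multiplies each state by \((1 - dt/\tau)\) at every step. It therefore oscillates once a time constant falls below dt and diverges below dt/2. That is one reason to keep time constants bounded away from dt during training, as is done later in this tutorial.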

Build a network

The network architecture is a single feedforward layer, with weighted spiking inputs and outputs. Spiking is generated via a function that provides a surrogate gradient in the backwards pass. This permits propagation of an error gradient through the layer, making gradient-descent training possible.

For this regression task we will also use an exponential synapse layer to perform temporal smoothing of the output. Regressing to a smooth target is much easier with a continuous output signal than with the spike deltas alone.
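The smoothing done by an exponential synapse layer can be sketched as a first-order low-pass filter applied to each spike train. Each spike adds 1 to a trace, which then decays with time constant \(\tau\). The sketch below assumes the exact per-step decay factor \(e^{-dt/\tau}\). The function name and discretisation are illustrative assumptions, not necessarily how ExpSynJax is implemented.

```python
import numpy as np


def exp_syn_filter(spikes, tau, dt=1e-3):
    """Low-pass filter spike trains with an exponential kernel exp(-t / tau).

    spikes: (T, N) array of spike counts per time step
    Returns an array of the same shape containing the filtered traces.
    """
    decay = np.exp(-dt / tau)
    out = np.zeros(spikes.shape)
    trace = np.zeros(spikes.shape[1])
    for t in range(spikes.shape[0]):
        # - Decay the previous value, then add this step's spikes
        trace = trace * decay + spikes[t]
        out[t] = trace
    return out


# - A single spike produces an exponentially decaying trace
impulse = np.zeros((50, 1))
impulse[0] = 1.0
trace = exp_syn_filter(impulse, tau=10e-3)
print(trace[:5, 0])  # values 1, exp(-0.1), exp(-0.2), exp(-0.3), exp(-0.4)
```

A weighted sum of continuous traces like these can approximate a smooth target far more easily than a weighted sum of raw 0/1 spikes.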

[3]:
# - Network size
N = 50
Nout = 1
input_scale = 20.0
[4]:
# - Generate a network using the sequential combinator
modFFwd = Sequential(
    LinearJax((Nin, N), has_bias=False),
    LIFODEJax(N, dt=dt),
    ExpSynJax(N),
    LinearJax((N, Nout)),
)

print(modFFwd)
JaxSequential  with shape (200, 1) {
    LinearJax '0_LinearJax' with shape (200, 50)
    LIFODEJax '1_LIFODEJax' with shape (50, 50)
    ExpSynJax '2_ExpSynJax' with shape (50,)
    LinearJax '3_LinearJax' with shape (50, 1)
}

Simulate initial state of network

If we simulate the untrained network with our random input spikes, we don’t expect anything sensible to come out. Let’s do this, and take a look at how the network behaves.

[5]:
# - Reset the network state
modFFwd = modFFwd.reset_state()

# - Evolve with the frozen noise spiking input
tsOutput, new_state, record_dict = modFFwd(input_sp_raster * input_scale, record=True)

# - Plot the analog output
plt.figure()
plt.plot(tsOutput[0])
[Figure: analog output of the untrained network]

We can also examine the internal state of the network, by interrogating record_dict:

[6]:
# - Make a function that converts ``record_dict`` entries to time series and plots them
def plot_record_dict(rd):
    Isyn_ts = TSContinuous.from_clocked(
        rd["1_LIFODEJax"]["isyn"][0, :, :, 0], dt, name="Synaptic currents $I_{syn}$"
    )
    Vmem_ts = TSContinuous.from_clocked(
        rd["1_LIFODEJax"]["vmem"][0], dt, name="Membrane potential $V_{mem}$"
    )
    spikes_ts = TSEvent.from_raster(
        rd["1_LIFODEJax"]["spikes"][0], dt, name="LIF layer spikes"
    )

    # - Plot the internal activity of selected neurons
    plt.figure()
    Isyn_ts.plot(stagger=1.1, skip=5)

    plt.figure()
    Vmem_ts.plot(stagger=1.1, skip=5)

    plt.figure()
    spikes_ts.plot(s=4)


plot_record_dict(record_dict)
[Figures: synaptic currents \(I_{syn}\), membrane potentials \(V_{mem}\) and LIF layer spikes of the untrained network]

Training the network

In order to train the network we need to define a loss function to optimise. This function accepts a set of parameters, the network, and the input and target for a trial, and computes an error (“loss”) for the trial. The loss is computed by comparing the network output to the target using mean-squared error.

Usually you would add regularisation terms to the loss function, to make sure parameters don’t grow too large; to encourage low firing rates; etc. Generally you would want to also place bounds on the time constants, to prevent them becoming too small and causing numerical instability. See 🏃🏽‍♀️ Training a Rockpool network with Jax for more information.
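Conceptually, the loss used below combines a mean-squared error with a bounds penalty. The penalty is zero while every parameter lies inside its bounds and grows as a parameter leaves them. The sketch below implements one such penalty, a squared hinge, on nested dictionaries shaped like those returned by modFFwd.parameters(). The helper names are hypothetical, and Rockpool's bounds_cost may use a different functional form.

```python
import numpy as np


def mse(output, target):
    """Mean-squared error between two arrays of the same shape."""
    return np.mean((np.asarray(output) - np.asarray(target)) ** 2)


def bounds_penalty(params, lower, upper):
    """Sum of squared violations of lower and upper bounds (squared hinge).

    params, lower and upper are nested dicts with the same keys. Bounds may be
    scalars (broadcast over the parameter array) or arrays.
    """
    if isinstance(params, dict):
        return sum(bounds_penalty(params[k], lower[k], upper[k]) for k in params)
    p = np.asarray(params, dtype=float)
    below = np.clip(lower - p, 0.0, None)  # > 0 only where p < lower
    above = np.clip(p - upper, 0.0, None)  # > 0 only where p > upper
    return float(np.sum(below**2) + np.sum(above**2))


# - Toy example with the same nesting as the tutorial's parameter dictionaries
params = {"lif": {"tau_mem": np.array([0.005, 0.02]), "bias": np.array([0.0, 0.1])}}
lower = {"lif": {"tau_mem": 0.011, "bias": -np.inf}}
upper = {"lif": {"tau_mem": np.inf, "bias": np.inf}}

penalty = bounds_penalty(params, lower, upper)
loss = mse(np.zeros(3), np.ones(3)) + 100.0 * penalty
print(penalty, loss)
```

Only the first membrane time constant, which lies below its lower bound of 0.011, contributes to the penalty. The weighting factor 100.0 mirrors the one used in loss_mse below.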

[7]:
# - Import the convenience functions
from rockpool.training.jax_loss import bounds_cost, make_bounds

# - Generate a set of pre-configured bounds
lower_bounds, upper_bounds = make_bounds(modFFwd.parameters())
print("lower_bounds: ", lower_bounds, "upper_bounds: ", upper_bounds)
lower_bounds:
{
    '0_LinearJax': {'weight': -inf},
    '1_LIFODEJax': {'bias': -inf, 'tau_mem': -inf, 'tau_syn': -inf, 'threshold': -inf},
    '2_ExpSynJax': {'tau': -inf},
    '3_LinearJax': {'weight': -inf}
}
upper_bounds:
{
    '0_LinearJax': {'weight': inf},
    '1_LIFODEJax': {'bias': inf, 'tau_mem': inf, 'tau_syn': inf, 'threshold': inf},
    '2_ExpSynJax': {'tau': inf},
    '3_LinearJax': {'weight': inf}
}
[8]:
# - Impose a lower bound for the time constants
lower_bounds["1_LIFODEJax"]["tau_syn"] = 11 * dt
lower_bounds["1_LIFODEJax"]["tau_mem"] = 11 * dt
# lower_bounds['1_LIFODEJax']['threshold'] = 0.1
lower_bounds["2_ExpSynJax"]["tau"] = 11 * dt
[9]:
print("lower_bounds:", lower_bounds)
lower_bounds:
{
    '0_LinearJax': {'weight': -inf},
    '1_LIFODEJax': {'bias': -inf, 'tau_mem': 0.011, 'tau_syn': 0.011, 'threshold': -inf},
    '2_ExpSynJax': {'tau': 0.011},
    '3_LinearJax': {'weight': -inf}
}
[10]:
import rockpool.training.jax_loss as l
import jax.numpy as jnp

# - Define a loss function
def loss_mse(params, net, input, target):
    # - Reset the network state
    net = net.reset_state()

    # - Apply the parameters
    net = net.set_attributes(params)

    # - Evolve the network
    output, _, states = net(input, record=True)

    # - Impose the bounds
    bounds = bounds_cost(params, lower_bounds, upper_bounds)

    # - Return the loss
    return l.mse(output, target) + 100.0 * bounds

Below we define a training loop that uses a gradient-descent optimisation algorithm (“Adam”, provided by Jax) to iteratively optimise the network parameters. We keep track of the loss value for each iteration for later visualisation.
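For reference, one step of Adam does three things:

1. It keeps exponential moving averages of the gradient and of its square.
2. It corrects both averages for their zero initialisation.
3. It takes a step scaled by the ratio of the corrected averages.

The NumPy sketch below assumes the standard hyperparameters \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\), which match the defaults of jax.example_libraries.optimizers.adam. For a quick toy demonstration it uses a larger step size than the 1e-4 used below. It illustrates the algorithm; it is not the Jax implementation, and adam_step is a hypothetical name.

```python
import numpy as np


def adam_step(step, params, grads, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update. `step` counts from 0, like the counter used below."""
    m = b1 * m + (1 - b1) * grads       # first-moment estimate
    v = b2 * v + (1 - b2) * grads**2    # second-moment estimate
    m_hat = m / (1 - b1 ** (step + 1))  # bias-corrected moments
    v_hat = v / (1 - b2 ** (step + 1))
    params = params - lr * m_hat / (np.sqrt(v_hat) + eps)
    return params, m, v


# - Toy problem: minimise ||x - target||^2, whose gradient is 2 (x - target)
target = np.array([1.0, -2.0, 3.0])
x = np.zeros(3)
m = np.zeros(3)
v = np.zeros(3)
for step in range(5000):
    grads = 2 * (x - target)
    x, m, v = adam_step(step, x, grads, m, v, lr=1e-2)

print(np.round(x, 3))
```

In the tutorial, init_fun, update_fun and get_params wrap this same bookkeeping for nested parameter dictionaries.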

[11]:
# - Useful imports
from tqdm.autonotebook import tqdm
from copy import deepcopy
from itertools import count

# -- Import an optimiser to use and initialise it
import jax
from jax.example_libraries.optimizers import adam

# - Get the optimiser functions
init_fun, update_fun, get_params = adam(1e-4)

# - Initialise the optimiser with the initial parameters
params0 = copy.deepcopy(modFFwd.parameters())
opt_state = init_fun(modFFwd.parameters())

# - Get a compiled value-and-gradient function
loss_vgf = jax.jit(jax.value_and_grad(loss_mse))

# - Compile the optimiser update function
update_fun = jax.jit(update_fun)

# - Record the loss values over training iterations
loss_t = []
grad_t = []

num_epochs = 2000
[12]:
# - Loop over iterations
i_trial = count()
for _ in tqdm(range(num_epochs)):
    # - Get parameters for this iteration
    params = get_params(opt_state)

    # - Get the loss value and gradients for this iteration
    loss_val, grads = loss_vgf(params, modFFwd, input_sp_raster * input_scale, chirp)

    # - Update the optimiser
    opt_state = update_fun(next(i_trial), grads, opt_state)

    # - Keep track of the loss
    loss_t.append(loss_val)
100%|██████████| 2000/2000 [00:16<00:00, 118.07it/s]
[13]:
# - Plot the loss
plt.figure()
plt.plot(loss_t)
plt.yscale("log")
plt.ylabel("Loss")
plt.xlabel("Training iteration")
plt.title("Training progress");
[Figure: Training progress. Loss (log scale) vs training iteration.]

Plot the output of the trained network

The MSE loss decreased; so far, so good. But what has the network learned?

[14]:
# - Simulate with trained parameters
modFFwd = modFFwd.set_attributes(get_params(opt_state))
modFFwd = modFFwd.reset_state()
output_ts, _, record_dict = modFFwd(input_sp_raster * input_scale, record=True)

# - Compare the output to the target
plt.plot(output_ts[0])
plt.plot(chirp, lw=3)
plt.title("Output vs target")

# - Plot the internal state of selected neurons
plot_record_dict(record_dict)
[Figures: trained network output vs target; synaptic currents \(I_{syn}\), membrane potentials \(V_{mem}\) and LIF layer spikes of the trained network]

Plot the network parameters

Let’s see how much the network parameters changed. Since the initial parameter set was random, we’ll plot the difference between the trained and initial parameters \(\theta^* - \theta\).
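The difference \(\theta^* - \theta\) can be computed for every parameter at once by walking the nested parameter dictionaries. The sketch below uses plain Python and NumPy, and its helper names are hypothetical. It subtracts two dictionaries with the same structure and reports the largest absolute change of each parameter. Assuming both are nested dictionaries of arrays, one could pass get_params(opt_state) and params0; here toy values are used instead.

```python
import numpy as np


def tree_diff(trained, initial):
    """Recursively compute trained - initial for two nested dicts of arrays."""
    if isinstance(trained, dict):
        return {k: tree_diff(trained[k], initial[k]) for k in trained}
    return np.asarray(trained) - np.asarray(initial)


def max_abs_change(diff, prefix=""):
    """Flatten a nested diff into {"module.param": max |change|}."""
    if isinstance(diff, dict):
        out = {}
        for k, v in diff.items():
            out.update(max_abs_change(v, f"{prefix}{k}."))
        return out
    return {prefix[:-1]: float(np.max(np.abs(diff)))}


# - Toy values with the same nesting as the tutorial's parameters
initial = {
    "0_LinearJax": {"weight": np.zeros((2, 2))},
    "1_LIFODEJax": {"bias": np.array([0.0, 0.5])},
}
trained = {
    "0_LinearJax": {"weight": np.array([[0.1, -0.3], [0.0, 0.2]])},
    "1_LIFODEJax": {"bias": np.array([0.1, 0.4])},
}

changes = max_abs_change(tree_diff(trained, initial))
print(changes)
```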

[15]:
modIn = modFFwd[0]
modLIF = modFFwd[1]
modOut = modFFwd[3]
[16]:
# - Plot the change in input weights
plt.figure()
w_diff = modIn.weight - params0["0_LinearJax"]["weight"]
lim = np.max(np.abs(w_diff))
plt.imshow(w_diff, aspect="auto")
plt.title("Input weight change $w^*_{in}-w_{in}$")
plt.clim([-lim, lim])
plt.set_cmap("PuOr")

# - Plot the change in output weights
plt.figure()
plt.stem(modOut.weight - params0["3_LinearJax"]["weight"])
plt.title("Output weight change $w^*_{out}-w_{out}$")

# - Plot the distribution of final time constants
plt.figure()
plt.hist(modLIF.tau_mem * 1e3, 20)
plt.xlabel("$\\tau_{mem}$ (ms)")
plt.title("Histogram of membrane time constants $\\tau_{mem}$")

plt.figure()
plt.hist(modLIF.tau_syn.flatten() * 1e3, 20)
plt.xlabel("$\\tau_{syn}$ (ms)")
plt.title("Histogram of synaptic time constants $\\tau_{syn}$")

# - Plot the distribution of final biases
plt.figure()
plt.hist(modLIF.bias, 20)
plt.xlabel("Bias value $I_{bias}$")
plt.title("Histogram of neuron biases $I_{bias}$");
[Figures: input weight change; output weight change; histograms of membrane time constants, synaptic time constants and neuron biases]

The power of automatic differentiation is that, almost for free, we can optimise not just the weights but also all time constants and biases simultaneously. And we didn’t have to compute the gradients by hand!

As a sanity check, let’s see how the trained network responds if we give it a different random noise input.

[17]:
spiking_prob = 0.01
sp_rand_ts = np.random.rand(T, Nin) < spiking_prob
[18]:
# - Simulate with trained parameters
modFFwd = modFFwd.set_attributes(get_params(opt_state))
modFFwd = modFFwd.reset_state()
output_ts, _, record_dict = modFFwd(sp_rand_ts * input_scale, record=True)

# - Compare the output to the target
plt.plot(output_ts[0])
plt.plot(chirp, lw=3)
plt.title("Output vs target")

# - Plot the internal state of selected neurons
plot_record_dict(record_dict)
[Figures: network output vs target for a new random input; synaptic currents \(I_{syn}\), membrane potentials \(V_{mem}\) and LIF layer spikes]

As expected, the network doesn’t do anything sensible with input it has never seen: it has learned to map one specific frozen noise pattern onto the target chirp.

Summary

This approach can be used identically to train recurrent spiking networks, as well as multi-layer (i.e. deep) networks.