Source code for nn.modules.jax.lif_jax

"""
Implements a leaky integrate-and-fire neuron module with a Jax backend
"""

import jax

from rockpool.nn.modules.jax.jax_module import JaxModule
from rockpool.nn.modules.native.linear import kaiming
from rockpool.parameters import Parameter, State, SimulationParameter
from rockpool import TSContinuous, TSEvent
from rockpool.graph import (
    GraphModuleBase,
    as_GraphHolder,
    LIFNeuronWithSynsRealValue,
    LinearWeights,
)

import numpy as onp

from jax import numpy as np
from jax.tree_util import Partial
from jax.lax import scan
import jax.random as rand

from typing import Optional, Tuple, Union, Callable
from rockpool.typehints import FloatVector, P_ndarray, JaxRNGKey, P_float, P_int

__all__ = ["LIFJax"]

GRADIENT_LIMIT = np.inf


# - Surrogate functions to use in learning
def sigmoid(x: FloatVector, threshold: FloatVector) -> FloatVector:
    """
    Sigmoid function

    :param FloatVector x:         Input value
    :param FloatVector threshold: Firing threshold

    :return FloatVector: Output value
    """
    return np.tanh(x + 1 - threshold) / 2 + 0.5
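
# - Illustrative sketch, not part of the library source: `_demo_sigmoid` is a
#   hypothetical helper showing the shape of the surrogate above. Because of the
#   `+ 1` shift, the curve crosses 0.5 one unit *below* the threshold and is
#   already close to saturation at the threshold itself.
def _demo_sigmoid():
    v = np.array([-1.0, 0.0, 1.0, 2.0])
    return sigmoid(v, threshold=1.0)  # ~ [0.12, 0.50, 0.88, 0.98]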


@jax.custom_jvp
def step_pwl(
    x: FloatVector,
    threshold: FloatVector,
    window: FloatVector = 0.5,
    max_spikes_per_dt: float = 2.0**16,
) -> FloatVector:
    """
    Heaviside step function with piece-wise linear derivative to use as spike-generation surrogate

    Args:
        x (float):          Input value
        threshold (float):  Firing threshold
        window (float): Learning window around threshold. Default: 0.5
        max_spikes_per_dt (float): Maximum number of spikes that may be produced each dt. Default: ``2**16``
    Returns:
        float: Number of output events for each input value
    """
    spikes = (x >= threshold) * np.floor(x / threshold)
    return np.clip(spikes, 0.0, max_spikes_per_dt)


@step_pwl.defjvp
def step_pwl_jvp(primals, tangents):
    x, threshold, window, max_spikes_per_dt = primals
    x_dot, threshold_dot, window_dot, max_spikes_per_dt_dot = tangents
    primal_out = step_pwl(*primals)
    tangent_out = (x >= (threshold - window)) * (
        x_dot / threshold - threshold_dot * x / (threshold**2)
    )
    return primal_out, tangent_out
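
# - Illustrative sketch, not part of the library source: `_demo_step_pwl_grad`
#   is a hypothetical helper showing the surrogate in action. The forward pass
#   emits `floor(x / threshold)` events above threshold, while `jax.grad` flows
#   through the piece-wise linear JVP above, so sub-threshold inputs inside the
#   learning window still receive a non-zero gradient.
def _demo_step_pwl_grad():
    v = np.array([0.2, 0.8, 1.3, 2.6])
    spikes = step_pwl(v, 1.0)  # -> [0., 0., 1., 2.]
    grads = jax.grad(lambda x: np.sum(step_pwl(x, 1.0)))(v)  # -> [0., 1., 1., 1.]
    return spikes, grads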


@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity function


def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # save bounds as residuals


def clip_gradient_bwd(res, g):
    lo, hi = res
    return (
        None,
        None,
        np.clip(g, lo, hi),
    )  # use None to indicate zero cotangents for lo and hi


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
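
# - Illustrative sketch, not part of the library source: `_demo_clip_gradient`
#   is a hypothetical helper. `clip_gradient` is the identity in the forward
#   pass, but clips the cotangent flowing back through it; the commented-out
#   calls in `evolve` below use it to bound state gradients.
def _demo_clip_gradient():
    f = lambda x: 3.0 * clip_gradient(-1.0, 1.0, x)
    y = f(2.0)            # forward pass unchanged: 6.0
    g = jax.grad(f)(2.0)  # cotangent 3.0 from the multiplication is clipped to 1.0
    return y, g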


class LIFJax(JaxModule):
    """
    A leaky integrate-and-fire spiking neuron model, with a Jax backend

    This module implements the update equations:

    .. math ::

        I_{syn} += S_{in}(t) + S_{rec} \\cdot W_{rec}

        I_{syn} *= \\exp(-dt / \\tau_{syn})

        V_{mem} *= \\exp(-dt / \\tau_{mem})

        V_{mem} += I_{syn} + b + \\sigma \\zeta(t)

    where :math:`S_{in}(t)` is a vector containing ``1`` (or a weighted spike) for each input channel that emits a spike at time :math:`t`; :math:`b` is a :math:`N` vector of bias currents for each neuron; :math:`\\sigma\\zeta(t)` is a Wiener noise process with standard deviation :math:`\\sigma` after 1s; and :math:`\\tau_{mem}` and :math:`\\tau_{syn}` are the membrane and synaptic time constants, respectively. :math:`S_{rec}(t)` is a vector containing ``1`` for each neuron that emitted a spike in the last time-step. :math:`W_{rec}` is a recurrent weight matrix, if recurrent weights are used. :math:`b` is an optional bias current per neuron (default 0.).

    :On spiking:

    When the membrane potential for neuron :math:`j`, :math:`V_{mem, j}`, exceeds the threshold voltage :math:`V_{thr}`, the neuron emits a spike. The spiking neuron subtracts its own threshold on reset.

    .. math ::

        V_{mem, j} > V_{thr} \\rightarrow S_{rec, j} = 1

        V_{mem, j} = V_{mem, j} - V_{thr}

    Neurons therefore share a common resting potential of ``0``, have individual firing thresholds, and perform subtractive reset of :math:`-V_{thr}`.
    """

    def __init__(
        self,
        shape: Union[Tuple, int],
        tau_mem: Optional[FloatVector] = None,
        tau_syn: Optional[FloatVector] = None,
        bias: Optional[FloatVector] = None,
        w_rec: Optional[FloatVector] = None,
        has_rec: bool = False,
        weight_init_func: Optional[Callable[[Tuple], np.ndarray]] = kaiming,
        threshold: Optional[FloatVector] = None,
        noise_std: float = 0.0,
        max_spikes_per_dt: P_float = 2.0**16,
        dt: float = 1e-3,
        rng_key: Optional[JaxRNGKey] = None,
        spiking_input: bool = False,
        spiking_output: bool = True,
        *args,
        **kwargs,
    ):
        """
        Instantiate an LIF module

        Args:
            shape (tuple): Either a single dimension ``(Nout,)``, which defines a feed-forward layer of LIF modules with equal numbers of synapses and neurons, or two dimensions ``(Nin, Nout)``, which defines a layer of ``Nin`` synapses and ``Nout`` LIF neurons.
            tau_mem (Optional[np.ndarray]): An optional array with concrete initialisation data for the membrane time constants. If not provided, 20ms will be used by default.
            tau_syn (Optional[np.ndarray]): An optional array with concrete initialisation data for the synaptic time constants. If not provided, 20ms will be used by default.
            bias (Optional[np.ndarray]): An optional array with concrete initialisation data for the neuron bias currents. If not provided, 0.0 will be used by default.
            w_rec (Optional[np.ndarray]): If the module is initialised in recurrent mode, you can provide a concrete initialisation for the recurrent weights, which must be a matrix with shape ``(Nout, Nin)``.
            has_rec (bool): If ``True``, the module provides a recurrent weight matrix. Default: ``False``, no recurrent connectivity.
            weight_init_func (Optional[Callable[[Tuple], np.ndarray]]): The initialisation function to use when generating weights. Default: Kaiming initialisation.
            threshold (FloatVector): An optional array specifying the firing threshold of each neuron. If not provided, ``1.`` will be used by default.
            noise_std (float): The std. dev. after 1s of the noise added to membrane state variables. Default: ``0.0`` (no noise).
            max_spikes_per_dt (float): The maximum number of events that will be produced in a single time-step. Default: ``2**16``.
            dt (float): The time step for the forward-Euler ODE solver. Default: 1ms.
            rng_key (Optional[Any]): The Jax RNG seed to use on initialisation. By default, a new seed is generated.
        """
        # - Check shape argument
        if np.size(shape) == 1:
            shape = (np.array(shape).item(), np.array(shape).item())

        if np.size(shape) > 2:
            raise ValueError(
                "`shape` must be a one- or two-element tuple `(Nin, Nout)`."
            )

        # - Call the superclass initialiser
        super().__init__(
            shape=shape,
            spiking_input=spiking_input,
            spiking_output=spiking_output,
            *args,
            **kwargs,
        )

        # - Seed RNG
        if rng_key is None:
            rng_key = rand.PRNGKey(onp.random.randint(0, 2**63))
        _, rng_key = rand.split(np.array(rng_key, dtype=np.uint32))

        # - Initialise state
        self.rng_key: Union[np.ndarray, State] = State(
            rng_key, init_func=lambda _: rng_key
        )

        self.n_synapses = shape[0] // shape[1]
        """ (int) Number of input synapses per neuron """

        if self.n_synapses * shape[1] != self.size_in:
            raise ValueError(
                "You must specify an integer number of synapses per neuron."
            )

        # - Should we be recurrent or FFwd?
        self._has_rec: bool = SimulationParameter(has_rec)
        if isinstance(has_rec, jax.core.Tracer) or has_rec:
            self.w_rec: P_ndarray = Parameter(
                w_rec,
                shape=(self.size_out, self.size_in),
                init_func=weight_init_func,
                family="weights",
                cast_fn=np.array,
            )
            """ (Tensor) Recurrent weights `(Nout, Nin)` """
        else:
            self.w_rec: P_ndarray = SimulationParameter(
                np.zeros((self.size_out, self.size_in))
            )

        # - Set parameters
        self.tau_mem: P_ndarray = Parameter(
            tau_mem,
            shape=[(self.size_out,), ()],
            init_func=lambda s: np.ones(s) * 20e-3,
            family="taus",
            cast_fn=np.array,
        )
        """ (np.ndarray) Membrane time constants `(Nout,)` or `()` """

        self.tau_syn: P_ndarray = Parameter(
            tau_syn,
            "taus",
            init_func=lambda s: np.ones(s) * 20e-3,
            shape=[
                (self.size_out, self.n_synapses),
                (1, self.n_synapses),
                (),
            ],
            cast_fn=np.array,
        )
        """ (np.ndarray) Synaptic time constants `(Nout, Nsyn)`, `(1, Nsyn)` or `()` """

        self.bias: P_ndarray = Parameter(
            bias,
            "bias",
            init_func=lambda s: np.zeros(s),
            shape=[(self.size_out,), ()],
            cast_fn=np.array,
        )
        """ (np.ndarray) Neuron bias currents `(Nout,)` or `()` """

        self.threshold: P_ndarray = Parameter(
            threshold,
            "threshold",
            shape=[(self.size_out,), ()],
            init_func=np.ones,
            cast_fn=np.array,
        )
        """ (np.ndarray) Firing threshold for each neuron `(Nout,)` or `()` """

        self.dt: P_float = SimulationParameter(dt)
        """ (float) Simulation time-step in seconds """

        self.noise_std: P_float = SimulationParameter(noise_std)
        """ (float) Std. dev. after 1s of the noise injected on each neuron membrane """

        # - Specify state
        self.spikes: P_ndarray = State(shape=(self.size_out,), init_func=np.zeros)
        """ (np.ndarray) Spiking state of each neuron `(Nout,)` """

        self.isyn: P_ndarray = State(
            shape=(self.size_out, self.n_synapses), init_func=np.zeros
        )
        """ (np.ndarray) Synaptic current of each neuron `(Nout, Nsyn)` """

        self.vmem: P_ndarray = State(shape=(self.size_out,), init_func=np.zeros)
        """ (np.ndarray) Membrane voltage of each neuron `(Nout,)` """

        self.max_spikes_per_dt: P_float = SimulationParameter(max_spikes_per_dt)
        """ (float) Maximum number of events that can be produced in each time-step """

        # - Define additional arguments required during initialisation
        self._init_args = {
            "has_rec": has_rec,
            "weight_init_func": Partial(weight_init_func),
        }

    def evolve(
        self,
        input_data: np.ndarray,
        record: bool = False,
    ) -> Tuple[np.ndarray, dict, dict]:
        """
        Evolve the state of this module over input data

        Args:
            input_data (np.ndarray): Input array of shape ``(T, Nin)`` to evolve over
            record (bool): If ``True``, a dictionary of recorded internal state variables is returned as ``record_state``. Default: ``False``, do not record internal state.

        Returns:
            (np.ndarray, dict, dict): output, new_state, record_state
                ``output`` is an array with shape ``(T, Nout)`` containing the output data produced by this module. ``new_state`` is a dictionary containing the updated module state following evolution. ``record_state`` will be a dictionary containing the recorded state variables for this evolution, if the ``record`` argument is ``True``.
        """
        # - Get input shapes, add batch dimension if necessary
        input_data, (vmem, spikes, isyn) = self._auto_batch(
            input_data,
            (self.vmem, self.spikes, self.isyn),
            (
                (self.size_out,),
                (self.size_out,),
                (self.size_out, self.n_synapses),
            ),
        )
        n_batches, n_timesteps, _ = input_data.shape

        # - Reshape data over separate input synapses
        input_data = input_data.reshape(
            n_batches, n_timesteps, self.size_out, self.n_synapses
        )

        # - Get evolution constants
        alpha = np.exp(-self.dt / self.tau_mem)
        beta = np.exp(-self.dt / self.tau_syn)
        noise_zeta = self.noise_std * np.sqrt(self.dt)

        # - Generate membrane noise trace
        key1, subkey = rand.split(self.rng_key)
        noise_ts = noise_zeta * rand.normal(
            subkey, shape=(n_batches, n_timesteps, self.size_out)
        )

        # - Single-step LIF dynamics
        def forward(
            state: Tuple[np.ndarray, np.ndarray, np.ndarray],
            inputs_t: Tuple[np.ndarray, np.ndarray],
        ) -> Tuple[
            Tuple[np.ndarray, np.ndarray, np.ndarray],
            Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        ]:
            """
            Single-step LIF dynamics for a recurrent LIF layer

            :param Tuple[np.ndarray, np.ndarray, np.ndarray] state: ``(spikes, isyn, vmem)`` layer state at the start of the step
            :param Tuple[np.ndarray, np.ndarray] inputs_t: ``(spike_inputs_ts, noise_inputs_ts)``

            :return: ``(state, (irec_ts, spikes_ts, vmem_ts, isyn_ts))``
                state: (Tuple[np.ndarray, np.ndarray, np.ndarray]) Layer state at end of evolution
                irec_ts: (np.ndarray) Recurrent input received at each neuron over time [T, N]
                spikes_ts: (np.ndarray) Logical spiking raster for each neuron [T, N]
                vmem_ts: (np.ndarray) Membrane voltage of each neuron over time [T, N]
                isyn_ts: (np.ndarray) Synaptic input current received by each neuron over time [T, N]
            """
            # - Unpack inputs
            (sp_in_t, noise_in_t) = inputs_t

            # - Unpack state
            spikes, isyn, vmem = state

            # - Apply synaptic and recurrent input
            isyn = isyn + sp_in_t
            irec = np.dot(spikes, self.w_rec).reshape(self.size_out, self.n_synapses)
            isyn = isyn + irec

            # - Decay synaptic and membrane state
            vmem *= alpha
            isyn *= beta

            # - Integrate membrane potentials
            vmem = vmem + isyn.sum(1) + noise_in_t + self.bias

            # - Detect next spikes (with custom gradient)
            spikes = step_pwl(vmem, self.threshold, 0.5, self.max_spikes_per_dt)

            # - Apply subtractive membrane reset
            vmem = vmem - spikes * self.threshold

            # - Ensure gradients are reasonable
            # vmem = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, vmem)
            # spikes = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, spikes)
            # irec = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, irec)
            # isyn = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, isyn)

            # - Return state and outputs
            return (spikes, isyn, vmem), (irec, spikes, vmem, isyn)

        # - Map over batches
        @jax.vmap
        def scan_time(spikes, isyn, vmem, input_data, noise_ts):
            return scan(forward, (spikes, isyn, vmem), (input_data, noise_ts))

        # - Evolve over spiking inputs
        state, (irec_ts, spikes_ts, vmem_ts, isyn_ts) = scan_time(
            spikes, isyn, vmem, input_data, noise_ts
        )

        # - Generate return arguments
        outputs = spikes_ts
        states = {
            "spikes": spikes_ts[0, -1],
            "isyn": isyn_ts[0, -1],
            "vmem": vmem_ts[0, -1],
            "rng_key": key1,
        }

        record_dict = {
            "irec": irec_ts,
            "spikes": spikes_ts,
            "isyn": isyn_ts,
            "vmem": vmem_ts,
        }

        # - Clip parameter gradients
        # self.tau_mem = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, self.tau_mem)
        # self.tau_syn = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, self.tau_syn)
        # self.bias = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, self.bias)
        # self.threshold = clip_gradient(-GRADIENT_LIMIT, GRADIENT_LIMIT, self.threshold)

        # - Return outputs
        return outputs, states, record_dict

    def as_graph(self) -> GraphModuleBase:
        # - Generate a GraphModule for the neurons
        neurons = LIFNeuronWithSynsRealValue._factory(
            self.size_in,
            self.size_out,
            f"{type(self).__name__}_{self.name}_{id(self)}",
            self,
            self.tau_mem,
            self.tau_syn,
            self.threshold,
            self.bias,
            self.dt,
        )

        # - Include recurrent weights if present
        if self._has_rec:
            # - Weights are connected over the existing input and output nodes
            w_rec_graph = LinearWeights(
                neurons.output_nodes,
                neurons.input_nodes,
                f"{type(self).__name__}_recurrent_{self.name}_{id(self)}",
                self,
                self.w_rec,
            )

        # - Return a graph containing the neurons and optional recurrent weights
        return as_GraphHolder(neurons)

    def _wrap_recorded_state(self, state_dict: dict, t_start: float = 0.0) -> dict:
        args = {"dt": self.dt, "t_start": t_start}

        return {
            "vmem": TSContinuous.from_clocked(
                np.squeeze(state_dict["vmem"][0]), name="$V_{mem}$", **args
            ),
            "isyn": TSContinuous.from_clocked(
                np.squeeze(state_dict["isyn"][0]), name="$I_{syn}$", **args
            ),
            "irec": TSContinuous.from_clocked(
                np.squeeze(state_dict["irec"][0]), name="$I_{rec}$", **args
            ),
            "spikes": TSEvent.from_raster(
                np.squeeze(state_dict["spikes"][0]), name="Spikes", **args
            ),
        }
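

# - Illustrative usage sketch, not part of the module source. Shapes follow the
#   docstrings above: `shape=(4, 2)` gives 2 neurons with 2 input synapses each,
#   and `_auto_batch` adds a leading batch dimension to the outputs.
if __name__ == "__main__":
    mod = LIFJax(shape=(4, 2), has_rec=True)

    # - Generate a random spike raster: 100 time-steps over 4 input channels
    T = 100
    spikes_in = np.array(onp.random.rand(T, 4) < 0.1, dtype=float)

    # - Evolve the module, recording internal state
    output, new_state, record = mod(spikes_in, record=True)
    print(output.shape)          # (1, 100, 2): batch, time, neurons
    print(record["vmem"].shape)  # (1, 100, 2): recorded membrane traces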