# Source code for nn.modules.torch.lif_neuron_torch

"""
Implement a LIF Neuron Module, using a Torch backend
"""
from typing import Union, Tuple, Any
import numpy as np
from rockpool.nn.modules.torch.torch_module import TorchModule
import torch
import rockpool.parameters as rp

from rockpool.typehints import P_float, P_tensor, P_int

# - Names exported by ``from <module> import *``
__all__ = ["LIFNeuronTorch"]

# - Type alias for arguments that accept either a scalar float or a tensor
#   (used for the ``tau_mem`` and ``bias`` arguments of ``LIFNeuronTorch``)
FloatVector = Union[float, torch.Tensor]


class StepPWL(torch.autograd.Function):
    """
    Spike-generation function with a piece-wise surrogate gradient

    Forward pass: computes ``max(floor(data + 1), 0)`` element-wise. This is ``0``
    for ``data < 0``, ``1`` for ``0 <= data < 1``, ``2`` for ``1 <= data < 2``, and
    so on, i.e. a (possibly multi-)spike count with a threshold at ``0``.

    Backward pass: the incoming gradient is passed through unchanged wherever
    ``data >= -0.5`` and is blocked (set to zero) wherever ``data < -0.5``.

    :param torch.Tensor data: Input value (e.g. a membrane potential relative to threshold)

    :return torch.Tensor: Spike counts, carrying the surrogate gradient described above
    """

    @staticmethod
    def forward(ctx, data):
        # - The surrogate gradient depends on the input, so keep it for backward
        ctx.save_for_backward(data)
        shifted = data + 1
        return shifted.floor().clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        # - Zero the gradient well below threshold; pass it straight through elsewhere
        blocked = membrane < -0.5
        return grad_output.masked_fill(blocked, 0)


class LIFNeuronTorch(TorchModule):
    """
    A leaky integrate-and-fire spiking neuron model

    This module implements the dynamics:

    .. math ::

        \\tau_{mem} \\dot{V}_{mem} + V_{mem} = I_{in} + b + \\sigma\\zeta(t)

    where :math:`I_{in}(t)` is a :math:`N` vector containing a continuous input current for each neuron; :math:`b` is a bias current for each neuron; :math:`\\sigma\\zeta(t)` is a white-noise process with standard deviation :math:`\\sigma` (``noise_std``) injected independently onto each neuron's membrane; and :math:`\\tau_{mem}` is the membrane time constant.

    The dynamics are integrated with a forward-Euler step of size ``dt``. Note that the noise term is not scaled by ``dt / tau_mem``: on every time step a fresh sample :math:`\\zeta_t \\sim \\mathcal{N}(0, 1)` is drawn and :math:`\\sigma\\zeta_t` is added directly to each membrane potential:

    .. math ::

        V_{mem} \\leftarrow V_{mem} + \\frac{dt}{\\tau_{mem}} (I_{in} + b - V_{mem}) + \\sigma\\zeta_t

    :On spiking:

    When the membrane potential for neuron :math:`j`, :math:`V_{mem, j}`, reaches the threshold voltage :math:`V_{thr} = 0`, the neuron emits :math:`S_j` spikes (one for :math:`0 \\leq V_{mem, j} < 1`, two for :math:`1 \\leq V_{mem, j} < 2`, etc.), and the membrane is reset subtractively by the number of spikes emitted.

    .. math ::

        V_{mem, j} \\geq V_{thr} \\rightarrow S_{j} = \\lfloor V_{mem, j} + 1 \\rfloor

        V_{mem, j} = V_{mem, j} - S_{j}

    Neurons therefore share a common resting potential of ``0.``, a firing threshold of ``0.``, and a subtractive reset of ``-1`` per emitted spike. Neurons each have an optional bias current `.bias` (default: ``0.``).
    """

    def __init__(
        self,
        shape: tuple = None,
        tau_mem: FloatVector = 0.1,
        bias: FloatVector = 0.0,
        dt: float = 1e-3,
        noise_std: float = 0.0,
        *args,
        **kwargs,
    ):
        """
        Instantiate an LIF Neuron module

        Args:
            shape (tuple): Number of neurons that will be created. Must be one-dimensional. Example: ``shape = (5,)``
            tau_mem (FloatVector): An optional array with concrete initialisation data for the membrane time constants. If not provided, 100ms will be used by default.
            bias (FloatVector): An optional array with concrete initialisation data for the neuron bias currents. If not provided, 0.0 will be used by default.
            dt (float): The time step for the forward-Euler ODE solver. Default: 1ms
            noise_std (float): The std. dev. of the noise added to membrane state variables at each time-step. Default: 0.0
            *args: Additional positional arguments, forwarded to :py:class:`TorchModule`
            **kwargs: Additional keyword arguments, forwarded to :py:class:`TorchModule`

        Raises:
            ValueError: If ``shape`` has more than one dimension
        """
        # - Check shape argument: normalise a scalar / length-1 shape to a 1-tuple
        if np.size(shape) == 1:
            shape = (np.array(shape).item(),)

        if np.size(shape) > 1:
            raise ValueError("`shape` must have only one dimension for LIFNeuronTorch")

        # - Initialize super-class
        super().__init__(
            shape=shape,
            spiking_input=False,
            spiking_output=True,
            *args,
            **kwargs,
        )

        self.n_neurons: P_int = rp.SimulationParameter(shape[0])
        """ (int) Number of neurons """

        # - Reset and thresholds
        #   NOTE(review): these are not read by `forward` below, which hard-codes a
        #   threshold of 0 and a reset of -1 via `StepPWL`; kept for compatibility.
        self._v_thresh: float = 0.0
        self._v_reset: float = -1.0

        # - Intialise parameters
        self.noise_std: P_float = rp.SimulationParameter(noise_std)
        """ (float) Std. Dev. of noise injected into neurons on each time-step """

        to_float_tensor = lambda x: torch.as_tensor(x, dtype=torch.float)

        self.tau_mem: P_tensor = rp.Parameter(
            tau_mem,
            family="taus",
            shape=[(self.n_neurons,), ()],
            init_func=lambda s: torch.ones(s) * 100e-3,
            cast_fn=to_float_tensor,
        )
        """ (Tensor) Membrane time constants `(Nout,)` or `()` """

        self.bias: P_tensor = rp.Parameter(
            bias,
            shape=[(self.size_out,), ()],
            family="bias",
            init_func=torch.zeros,
            cast_fn=to_float_tensor,
        )
        """ (Tensor) Neuron biases `(Nout,)` or `()` """

        self.dt: P_float = rp.SimulationParameter(dt)
        """ (float) Simulation time-step in seconds """

        self.vmem: P_tensor = rp.State(
            shape=self.n_neurons, init_func=torch.zeros, cast_fn=to_float_tensor
        )
        """ (Tensor) Membrane potentials `(Nout,)` """

        # - Attributes for recording state; populated by `forward`
        self._vmem_rec = None
        """ (torch.Tensor) Record of previous evolution """

        self._spikes_rec = None
        """ (torch.Tensor) Record of spikes emitted during previous evolution """

    def evolve(
        self, input_data: torch.Tensor, record: bool = False
    ) -> Tuple[Any, Any, Any]:
        """
        Evolve the module over some input data

        Args:
            input_data (torch.Tensor): Input currents for the neurons
            record (bool): If ``True``, return the recorded membrane potentials and spikes

        Returns:
            (output_data, states, record_dict): ``record_dict`` contains ``"vmem"`` and
            ``"spikes"`` when ``record`` is ``True``, and is empty otherwise. The recorded
            tensors are those left by the most recent call to :py:meth:`forward`
            (presumably triggered by ``TorchModule.evolve`` — confirm).
        """
        # - Call super-class evolve; its own record dictionary is superseded below
        output_data, states, _ = super().evolve(input_data, record)

        # - Fill record dictionary
        record_dict = (
            {
                "vmem": self._vmem_rec,
                "spikes": self._spikes_rec,
            }
            if record
            else {}
        )

        # - Return output
        return output_data, states, record_dict

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Forward method for processing data through this layer

        Integrates the input currents onto the neuron membranes following the leaky
        integrate-and-fire dynamics, emitting spikes and applying a subtractive reset.

        Only the final membrane state of the first batch element is retained in
        ``.vmem``. Membrane potentials (after reset) and spikes for every time step
        are stored for recording.

        Args:
            data (torch.Tensor): Data takes the shape of (batch, time_steps, n_neurons)

        Raises:
            ValueError: Input has wrong neuron dimensions.

        Returns:
            torch.Tensor: Output spikes with the shape (batch, time_steps, n_neurons)
        """
        # - Validate data shape
        n_batches, time_steps, n_neurons = data.shape

        if n_neurons != self.size_out:
            raise ValueError(
                f"Input has wrong neuron dimensions. It is {n_neurons}, must be {self.size_out}"
            )

        device = data.device

        # - Expand state out by batch dimension (allocated directly on the input's device)
        vmem = torch.ones(n_batches, self.n_neurons, device=device) * self.vmem

        # - Hoist loop-invariant quantities out of the time loop
        alpha = self.dt / self.tau_mem
        bias = self.bias
        noise_std = self.noise_std
        step_pwl = StepPWL.apply

        # - Initialise state and spike records
        self._vmem_rec = torch.zeros(data.shape, device=device)
        self._spikes_rec = torch.zeros(
            n_batches, time_steps, self.n_neurons, device=device
        )

        # - Loop over time
        for t in range(time_steps):
            # - Forward-Euler update of the membrane potential, plus per-step noise
            dvmem = data[:, t, :] + bias - vmem
            vmem = (
                vmem
                + alpha * dvmem
                + torch.randn(vmem.shape, device=device) * noise_std
            )

            # - Compute spikes (possibly several per step) and reset subtractively
            out_spikes = step_pwl(vmem)
            vmem = vmem - out_spikes

            # - Record state (post-reset) and spikes
            self._vmem_rec[:, t, :] = vmem
            self._spikes_rec[:, t] = out_spikes

        # - Only retain state for first batch element
        self.vmem = vmem[0].detach()

        # - The membrane record is for inspection only; cut it from the autograd graph
        self._vmem_rec.detach_()

        return self._spikes_rec