Source code for nn.modules.sinabs.lif_exodus

"""
Implement a LIF Module, using an Exodus backend

Exodus is an accelerated CUDA-based simulator for LIF-like neuron dynamics, supporting gradient calculations.

This module implements the classes :py:class:`.LIFExodus`, :py:class:`.ExpSynExodus` and :py:class:`.LIFMembraneExodus`.
"""

from rockpool.nn.modules.torch.lif_torch import LIFBaseTorch
import torch
import warnings

from rockpool.typehints import *
from rockpool.parameters import Constant

from rockpool.graph import GraphModuleBase

from rockpool.utilities.backend_management import (
    backend_available,
    missing_backend_shim,
)

# - Import the Exodus / sinabs backend objects if available; otherwise bind each
#   name to a shim produced by `missing_backend_shim`, so that every name used by
#   the modules below is always defined at module level
if backend_available("sinabs"):
    from sinabs.activation import Heaviside, SingleExponential

    if backend_available("sinabs.exodus"):
        from sinabs.exodus.spike import IntegrateAndFire
        from sinabs.exodus.leaky import LeakyIntegrator
    else:
        IntegrateAndFire = missing_backend_shim("IntegrateAndFire", "sinabs.exodus")
        LeakyIntegrator = missing_backend_shim("LeakyIntegrator", "sinabs.exodus")

else:
    Heaviside = missing_backend_shim("Heaviside", "sinabs")
    SingleExponential = missing_backend_shim("SingleExponential", "sinabs")

    # - `sinabs.exodus` is imported from within `sinabs`, so it is also unavailable
    #   here. Without these shims, modules that do not touch `Heaviside` during
    #   construction (e.g. `ExpSynExodus`, `LIFMembraneExodus`) would fail later in
    #   `forward` with a bare `NameError`.
    IntegrateAndFire = missing_backend_shim("IntegrateAndFire", "sinabs.exodus")
    LeakyIntegrator = missing_backend_shim("LeakyIntegrator", "sinabs.exodus")


__all__ = ["LIFExodus", "LIFMembraneExodus", "LIFSlayer", "ExpSynExodus"]


class LIFExodus(LIFBaseTorch):
    """
    Leaky integrate-and-fire spiking neuron module, simulated using the Exodus CUDA backend

    Warnings:
        A CUDA device is required. Recurrent weights and injected noise are not supported, and thresholds are not trainable.
    """

    def __init__(
        self,
        shape: tuple,
        tau_mem: P_float = 0.02,
        tau_syn: P_float = 0.05,
        threshold: P_float = 1.0,
        learning_window: P_float = 0.5,
        bias: P_float = 0.0,
        has_rec: bool = False,
        noise_std: P_float = 0.0,
        *args,
        **kwargs,
    ):
        """
        Instantiate an LIF module using the Exodus backend

        Uses the Exodus accelerated CUDA module to implement an LIF neuron. A CUDA device is required to instantiate this module. The output of evolving this module is the neuron spike events; synaptic currents and membrane potentials are available using the ``record = True`` argument to :py:meth:`~.LIFExodus.evolve`.

        Warnings:
            Exodus does not currently support training thresholds.

            Exodus does not support noise injection.

        Examples:
            Instantiate an LIF module with 2 neurons, with 2 synapses each (4 input channels).

            >>> mod = LIFExodus((4, 2))

            Specify the membrane and synapse time constants, as well as time-step ``dt``.

            >>> mod = LIFExodus((4, 2), tau_mem = 30e-3, tau_syn = 10e-3, dt = 10e-3)

            Pass the model and data to the same cuda device, since it is required to use CUDA on this module.

            >>> data = torch.ones((1, 10, 4))
            >>> device = 'cuda:1'
            >>> mod.to(device)
            >>> data = data.to(device)
            >>> output = mod(data)

        Args:
            shape (tuple): The shape of this module
            tau_mem (float): An optional array with concrete initialisation data for the membrane time constants. If not provided, 20ms will be used by default.
            tau_syn (float): An optional array with concrete initialisation data for the synaptic time constants. If not provided, 50ms will be used by default.
            threshold (float): An optional array specifying the firing threshold of each neuron. If not provided, ``1.`` will be used by default. Thresholds are stored as constants and are not trained.
            learning_window (float): Cutoff value for the surrogate gradient. Default: 0.5
            bias (float): An optional array with concrete initialisation data for the per-neuron bias, added onto each neuron's membrane input at every time step. Default: ``0.``
            has_rec (bool): Must be ``False``; recurrent weights are not supported.
            noise_std (float): Must be ``0.``; noise injection is not supported.
            dt (float): Time step in seconds. Default: 1 ms.

        Raises:
            ValueError: If ``has_rec`` is ``True`` or ``noise_std`` is non-zero.
            EnvironmentError: If CUDA is not available.
        """
        if has_rec:
            raise ValueError("`LIFExodus` does not support recurrent weights.")

        if noise_std != 0.0:
            raise ValueError("`LIFExodus` does not support injected noise.")

        # - Initialise superclass; the threshold is wrapped as a constant, as Exodus cannot train it
        super().__init__(
            shape=shape,
            tau_syn=tau_syn,
            tau_mem=tau_mem,
            threshold=Constant(threshold),
            bias=bias,
            has_rec=False,
            noise_std=0.0,
            learning_window=learning_window,
            *args,
            **kwargs,
        )

        # - Assign the surrogate gradient function
        self.spike_generation_fn = Heaviside(self.learning_window)

        # - Check that CUDA is available
        if not torch.cuda.is_available():
            raise EnvironmentError("CUDA is required for exodus-backed modules.")

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        forward method for processing data through this layer

        Adds synaptic inputs to the synaptic states and mimics the Leaky Integrate and Fire dynamics

        Args:
            data (torch.Tensor): Data takes the shape of (batch, time_steps, n_neurons * n_synapses)

        Returns:
            torch.Tensor: Output spikes with the shape (batch, time_steps, n_neurons)
        """
        # - Replicate data and states out by batches
        data, (vmem, isyn, spikes) = self._auto_batch(
            data, (self.vmem, self.isyn, self.spikes)
        )

        # - Get input data size
        n_batches, time_steps, _ = data.shape

        # - Broadcast parameters to full size for this module
        beta = self.beta.expand((n_batches, self.n_neurons, self.n_synapses)).flatten()
        alpha = self.alpha.expand((n_batches, self.n_neurons)).flatten().contiguous()
        membrane_subtract = self.threshold.expand((n_batches, self.n_neurons)).flatten()
        threshold = (
            self.threshold.expand((n_batches, self.n_neurons)).flatten().contiguous()
        )

        # Bring data into format expected by exodus: (batches*neurons*synapses, timesteps)
        data = data.movedim(1, -1).flatten(0, -2)

        # Decay data by one timestep to match xylo behavior
        data = beta.unsqueeze(-1) * data

        # Synaptic dynamics: Calculate I_syn and bring to shape
        # (batches*neurons, synapses, timesteps)
        isyn_exodus = LeakyIntegrator.apply(
            data.contiguous(),  # Input
            beta.contiguous(),  # beta
            isyn.flatten().contiguous(),  # initial state
        ).reshape(-1, self.n_synapses, time_steps)

        # Per-neuron bias, broadcast to (batches*neurons, timesteps). It is added once to
        # the summed synaptic input rather than to each synapse, so the bias drive onto
        # the membrane does not scale with the number of synapses.
        bias = (
            self.bias.reshape((1, -1, 1))
            .expand((n_batches, self.n_neurons, time_steps))
            .flatten(0, 1)
        )
        membrane_input = (isyn_exodus.sum(1) + bias).contiguous()

        # Membrane dynamics: Calculate v_mem
        spikes, vmem_exodus = IntegrateAndFire.apply(
            membrane_input,  # input
            alpha.contiguous(),  # alpha
            vmem.flatten().contiguous(),  # init state
            threshold,  # threshold
            membrane_subtract.contiguous(),  # membrane subtract
            None,  # threshold low
            self.spike_generation_fn,
            None if torch.isinf(self.max_spikes_per_dt) else self.max_spikes_per_dt,
        )

        # Subtract spikes from Vmem as exodus subtracts them starting from the next timestep
        vmem_exodus.data = vmem_exodus.data - spikes.data * threshold.unsqueeze(-1)

        # Bring states to rockpool dimensions
        isyn_exodus = (
            isyn_exodus.reshape(n_batches, self.n_neurons, self.n_synapses, time_steps)
            .movedim(-1, 1)
            .to(data.device)
        )
        vmem_exodus = (
            vmem_exodus.reshape(n_batches, self.n_neurons, time_steps)
            .movedim(-1, 1)
            .to(data.device)
        )
        spikes = (
            spikes.reshape(n_batches, self.n_neurons, time_steps)
            .movedim(-1, 1)
            .to(data.device)
        )

        self._record_dict["vmem"] = vmem_exodus
        self._record_dict["isyn"] = isyn_exodus
        self._record_dict["spikes"] = spikes

        # - Persist the final state of the first batch element
        self.vmem = vmem_exodus[0, -1].detach()
        self.isyn = isyn_exodus[0, -1].detach()
        self.spikes = spikes[0, -1].detach()

        return self._record_dict["spikes"]
class ExpSynExodus(LIFBaseTorch):
    """
    Exponential synapse module, simulated using the Exodus CUDA backend

    Evolving this module returns the synaptic currents. A CUDA device is required; noise injection is not supported.
    """

    def __init__(
        self,
        shape: tuple,
        tau: P_float = 0.05,
        noise_std: P_float = 0.0,
        dt: float = 1e-3,
        *args,
        **kwargs,
    ):
        """
        Instantiate an exponential synapse module using the Exodus backend

        Uses the Exodus accelerated CUDA module to implement an exponential synapse. A CUDA device is required to instantiate this module. The output of evolving this module is the synaptic currents.

        Warning:
            Exodus does not support noise injection.

        Examples:
            Instantiate an exponential synapse module with 2 synapses.

            >>> mod = ExpSynExodus(2)

            Specify the synaptic time constants, as well as time-step ``dt``.

            >>> mod = ExpSynExodus(2, tau = 10e-3, dt = 10e-3)

            Specify multiple synaptic time constants.

            >>> mod = ExpSynExodus(2, tau = [10e-3, 20e-3])

            Pass the model and data to the same cuda device, since it is required to use CUDA on this module.

            >>> data = torch.ones((1, 10, 2))
            >>> device = 'cuda:1'
            >>> mod.to(device)
            >>> data = data.to(device)
            >>> output = mod(data)

        Args:
            shape (tuple): The shape of this module
            tau (float): An optional array with concrete initialisation data for the synaptic time constants. If not provided, 50ms will be used by default.
            noise_std (float): Must be ``0.``; noise injection is not supported.
            dt (float): Time step in seconds. Default: 1 ms.

        Raises:
            TypeError: If any of ``threshold``, ``has_rec``, ``bias`` or ``tau_mem`` is passed.
            ValueError: If ``noise_std`` is non-zero.
            EnvironmentError: If CUDA is not available.
        """
        # - Reject LIF parameters that have no meaning for a pure synapse module.
        #   (`noise_std` is a named parameter, so it can never appear in `kwargs`;
        #   it is validated separately below.)
        unused_arguments = ["threshold", "has_rec", "bias", "tau_mem"]
        test_args = [arg in kwargs for arg in unused_arguments]
        if any(test_args):
            error_args = [arg for (arg, t) in zip(unused_arguments, test_args) if t]
            raise TypeError(
                f"The argument(s) {error_args} is/are not used in ExpSynExodus."
            )

        if noise_std != 0.0:
            raise ValueError("`ExpSynExodus` does not support injected noise.")

        # - Initialise superclass
        super().__init__(
            shape=shape,
            tau_syn=tau,
            has_rec=False,
            noise_std=0.0,
            dt=dt,
            *args,
            **kwargs,
        )

        # - Remove LIFBaseTorch attributes that do not apply
        delattr(self, "tau_mem")
        delattr(self, "vmem")
        delattr(self, "threshold")
        delattr(self, "bias")
        delattr(self, "learning_window")
        delattr(self, "spikes")
        delattr(self, "spike_generation_fn")
        delattr(self, "max_spikes_per_dt")

        # - Check that CUDA is available
        if not torch.cuda.is_available():
            raise EnvironmentError("CUDA is required for exodus-backed modules.")

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        forward method for processing data through this layer

        Adds inputs to the synaptic states

        Args:
            data (torch.Tensor): Data takes the shape of (batch, time_steps, n_synapses)

        Returns:
            torch.Tensor: Synaptic currents with the shape (batch, time_steps, size_out)
        """
        # - Replicate data and states out by batches
        data, (isyn,) = self._auto_batch(data, (self.isyn,))

        # - Get input data size
        n_batches, time_steps, _ = data.shape

        # - Broadcast parameters to full size for this module
        beta = self.beta.expand((n_batches, self.n_neurons, self.n_synapses)).flatten()

        # Bring data into format expected by exodus: (batches*neurons*synapses, timesteps)
        data = data.movedim(1, -1).flatten(0, -2)

        # Decay data by one timestep to match xylo behavior
        data = beta.unsqueeze(-1) * data

        # Synaptic dynamics: Calculate I_syn and bring to shape
        # (batches*neurons, synapses, timesteps)
        isyn_exodus = LeakyIntegrator.apply(
            data.contiguous(),  # Input
            beta.contiguous(),  # beta
            isyn.flatten().contiguous(),  # initial state
        ).reshape(-1, self.n_synapses, time_steps)

        # Bring states to rockpool dimensions
        isyn_exodus = (
            isyn_exodus.reshape(n_batches, self.n_neurons, self.n_synapses, time_steps)
            .movedim(-1, 1)
            .to(data.device)
        )

        # Save synaptic currents and return
        self._record_dict["isyn"] = isyn_exodus.reshape(
            n_batches, time_steps, self.size_out
        )

        # - Persist the final state of the first batch element
        self.isyn = isyn_exodus[0, -1].detach()

        return self._record_dict["isyn"]
class LIFMembraneExodus(LIFBaseTorch):
    """
    Non-spiking LIF membrane module (synapse plus leaky membrane), simulated using the Exodus CUDA backend

    Evolving this module returns the membrane potentials. A CUDA device is required.
    """

    def __init__(
        self,
        shape: tuple,
        tau_syn: P_float = 0.05,
        tau_mem: P_float = 0.02,
        bias: P_float = 0.0,
        *args,
        **kwargs,
    ):
        """
        Instantiate a module implementing an LIF membrane using the Exodus backend

        Uses the Exodus accelerated CUDA module to implement an LIF neuron membrane. A CUDA device is required to instantiate this module. The output of evolving this module is the neuron membrane potentials; synaptic currents are available using the ``record = True`` argument to :py:meth:`~.LIFMembraneExodus.evolve`.

        Warnings:
            Exodus does not support noise injection.

        Examples:
            Instantiate an LIF membrane module with 2 neurons, with 2 synapses each (4 input channels).

            >>> mod = LIFMembraneExodus((4, 2))

            Specify the membrane and synapse time constants, as well as time-step ``dt``.

            >>> mod = LIFMembraneExodus((4, 2), tau_mem = 30e-3, tau_syn = 10e-3, dt = 10e-3)

            Pass the model and data to the same cuda device, since it is required to use CUDA on this module.

            >>> data = torch.ones((1, 10, 4))
            >>> device = 'cuda:1'
            >>> mod.to(device)
            >>> data = data.to(device)
            >>> output = mod(data)

        Args:
            shape (tuple): The shape of this module
            tau_syn (float): An optional array with concrete initialisation data for the synapse time constants. If not provided, 50ms will be used by default.
            tau_mem (float): An optional array with concrete initialisation data for the membrane time constants. If not provided, 20ms will be used by default.
            bias (float): An optional array with concrete initialisation data for the per-neuron bias, added onto each neuron's membrane input at every time step. Default: ``0.``
            dt (float): Time-step of this module in seconds. Default: 1 ms.

        Raises:
            TypeError: If any of ``threshold``, ``has_rec`` or ``noise_std`` is passed.
            EnvironmentError: If CUDA is not available.
        """
        # - Reject spiking / recurrent / noise parameters, which this module does not use
        unused_arguments = ["threshold", "has_rec", "noise_std"]
        test_args = [arg in kwargs for arg in unused_arguments]
        if any(test_args):
            error_args = [arg for (arg, t) in zip(unused_arguments, test_args) if t]
            raise TypeError(
                f"The argument(s) {error_args} is/are not used in LIFMembraneExodus."
            )

        # - Initialise superclass
        super().__init__(
            shape=shape,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            bias=bias,
            has_rec=False,
            noise_std=0.0,
            *args,
            **kwargs,
        )

        # - Remove LIFBaseTorch attributes that do not apply
        delattr(self, "threshold")
        delattr(self, "learning_window")
        delattr(self, "spikes")
        delattr(self, "spike_generation_fn")
        delattr(self, "max_spikes_per_dt")

        # - Check that CUDA is available
        if not torch.cuda.is_available():
            raise EnvironmentError("CUDA is required for exodus-backed modules.")

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        forward method for processing data through this layer

        Adds synaptic inputs to the synaptic states and integrates the summed synaptic currents plus bias onto a leaky (non-spiking) membrane

        Args:
            data (torch.Tensor): Data takes the shape of (batch, time_steps, n_neurons * n_synapses)

        Returns:
            torch.Tensor: Membrane potentials with the shape (batch, time_steps, n_neurons)
        """
        # - Replicate data and states out by batches
        data, (vmem, isyn) = self._auto_batch(data, (self.vmem, self.isyn))

        # - Get input data size
        n_batches, time_steps, _ = data.shape

        # - Broadcast parameters to full size for this module
        beta = self.beta.expand((n_batches, self.n_neurons, self.n_synapses)).flatten()
        alpha = self.alpha.expand((n_batches, self.n_neurons)).flatten().contiguous()

        # Bring data into format expected by exodus: (batches*neurons*synapses, timesteps)
        data = data.movedim(1, -1).flatten(0, -2)

        # Decay data by one timestep to match xylo behavior
        data = beta.unsqueeze(-1) * data

        # Synaptic dynamics: Calculate I_syn and bring to shape
        # (batches*neurons, synapses, timesteps)
        isyn_exodus = LeakyIntegrator.apply(
            data.contiguous(),  # Input
            beta.contiguous(),  # beta
            isyn.flatten().contiguous(),  # initial state
        ).reshape(-1, self.n_synapses, time_steps)

        # Per-neuron bias, broadcast to (batches*neurons, timesteps). It is added to the
        # summed membrane input only; the synaptic currents (recorded and persisted as
        # state below) are left untouched, and the bias does not scale with the number
        # of synapses.
        bias = (
            self.bias.reshape((1, -1, 1))
            .expand((n_batches, self.n_neurons, time_steps))
            .flatten(0, 1)
        )
        membrane_input = (isyn_exodus.sum(1) + bias).contiguous()

        # Integrate onto a membrane
        vmem_exodus = LeakyIntegrator.apply(
            membrane_input,  # input
            alpha.contiguous(),  # alpha
            vmem.flatten().contiguous(),  # initial state
        )

        # Bring states to rockpool dimensions
        isyn_exodus = (
            isyn_exodus.reshape(n_batches, self.n_neurons, self.n_synapses, time_steps)
            .movedim(-1, 1)
            .to(data.device)
        )
        vmem_exodus = (
            vmem_exodus.reshape(n_batches, self.n_neurons, time_steps)
            .movedim(-1, 1)
            .to(data.device)
        )

        self._record_dict["vmem"] = vmem_exodus
        self._record_dict["isyn"] = isyn_exodus

        # - Persist the final state of the first batch element
        self.vmem = vmem_exodus[0, -1].detach()
        self.isyn = isyn_exodus[0, -1].detach()

        return vmem_exodus

    def as_graph(self) -> GraphModuleBase:
        """
        Graph conversion is not supported for this module

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
class LIFSlayer(LIFExodus):
    """DEPRECATED: An LIF module with an Exodus backend"""

    def __init__(self, *args, **kwargs):
        """
        Instantiate an LIF module with an Exodus backend

        All arguments are forwarded unchanged to :py:class:`LIFExodus`.

        Warnings:
            This module is deprecated. Use :py:class:`LIFExodus` instead.
        """
        # - Emit the deprecation notice before delegating construction to `LIFExodus`
        notice = "This module is deprecated. Use `LIFExodus` instead."
        warnings.warn(notice, DeprecationWarning)

        super().__init__(*args, **kwargs)