"""
Implement an LIF module, using a Torch backend

Provides the :py:class:`.LIFBaseTorch` base class for LIF torch modules, and the :py:class:`.LIFTorch` module.
"""
from typing import Union, Tuple, Callable, Optional, Any
import numpy as np
from rockpool.nn.modules.torch.torch_module import TorchModule
import torch
import torch.nn.functional as F
import torch.nn.init as init
import rockpool.parameters as rp
from rockpool.typehints import *
from rockpool.graph import (
GraphModuleBase,
as_GraphHolder,
LIFNeuronWithSynsRealValue,
LinearWeights,
)
__all__ = ["LIFTorch"]
class StepPWL(torch.autograd.Function):
"""
    Heaviside step function with a piece-wise linear surrogate gradient, for spike generation
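
    Example:
        A minimal sketch of the forward behaviour; the surrogate gradient only affects the backward pass::

            x = torch.tensor([0.4, 1.2, 2.7])
            n_spikes = StepPWL.apply(x)  # default threshold of 1.0
            # n_spikes is tensor([0., 1., 2.])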
"""
@staticmethod
def forward(
ctx,
x,
threshold=torch.tensor(1.0),
window=torch.tensor(0.5),
max_spikes_per_dt=torch.tensor(2.0**16),
):
ctx.save_for_backward(x, threshold)
ctx.window = window
nr_spikes = ((x >= threshold) * torch.floor(x / threshold)).float()
clamp_bool = (nr_spikes > max_spikes_per_dt).float()
nr_spikes -= (nr_spikes - max_spikes_per_dt.float()) * clamp_bool
return nr_spikes
@staticmethod
def backward(ctx, grad_output):
x, threshold = ctx.saved_tensors
grad_x = grad_threshold = grad_window = grad_max_spikes_per_dt = None
mask = x >= (threshold - ctx.window)
if ctx.needs_input_grad[0]:
grad_x = grad_output / threshold * mask
if ctx.needs_input_grad[1]:
grad_threshold = -x * grad_output / (threshold**2) * mask
return grad_x, grad_threshold, grad_window, grad_max_spikes_per_dt
class PeriodicExponential(torch.autograd.Function):
"""
    Heaviside step function with a periodic exponential surrogate gradient, for spike generation with subtractive membrane reset
"""
@staticmethod
def forward(
ctx,
data,
threshold=1.0,
window=0.5,
max_spikes_per_dt=torch.tensor(2.0**16),
):
ctx.save_for_backward(data.clone())
ctx.threshold = threshold
ctx.window = window
ctx.max_spikes_per_dt = max_spikes_per_dt
nr_spikes = ((data >= threshold) * torch.floor(data / threshold)).float()
clamp_bool = (nr_spikes > max_spikes_per_dt).float()
nr_spikes -= (nr_spikes - max_spikes_per_dt.float()) * clamp_bool
return nr_spikes
@staticmethod
def backward(ctx, grad_output):
(membranePotential,) = ctx.saved_tensors
vmem_shifted = membranePotential - ctx.threshold / 2
nr_spikes_shifted = torch.clamp(
torch.div(vmem_shifted, ctx.threshold, rounding_mode="floor"),
max=ctx.max_spikes_per_dt - 1,
)
vmem_periodic = vmem_shifted - nr_spikes_shifted * ctx.threshold
vmem_below = vmem_shifted * (membranePotential < ctx.threshold)
vmem_above = vmem_periodic * (membranePotential >= ctx.threshold)
vmem_new = vmem_above + vmem_below
spikePdf = (
torch.exp(-torch.abs(vmem_new - ctx.threshold / 2) / ctx.window)
/ ctx.threshold
)
return (
grad_output * spikePdf,
grad_output * -spikePdf * membranePotential / ctx.threshold,
None,
None,
)
# - Surrogate functions to use in learning
def sigmoid(x: FloatVector, threshold: FloatVector) -> FloatVector:
"""
    Sigmoid surrogate function, implemented as a shifted and scaled ``tanh``

    :param FloatVector x: Input value
    :param FloatVector threshold: Firing threshold

    :return FloatVector: Output value
"""
return torch.tanh(x + 1 - threshold) / 2 + 0.5
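# - Conversion helpers between the three leak parameterisations: time constants ('taus'),
#   decay factors ('decays'; `alpha` / `beta`) and bitshift values ('bitshifts'; `dash_mem` / `dash_syn`)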
def decay_to_tau(dt, *decays):
return tuple([-(dt / torch.log(decay).to(decay.device)) for decay in decays])
def tau_to_decay(dt, *taus):
return tuple([torch.exp(-dt / tau).to(tau.device) for tau in taus])
def tau_to_bitshift(dt, *taus):
return tuple([-torch.log2(1 - torch.exp(-dt / tau)).to(tau.device) for tau in taus])
def bitshift_to_tau(dt, *dashes):
return tuple(
[-dt / torch.log(1 - 1 / (2**dash)).to(dash.device) for dash in dashes]
)
def decay_to_bitshift(*decays):
return tuple([-torch.log2(1 - decay).to(decay.device) for decay in decays])
def bitshift_to_decay(*dashes):
return tuple([(1 - 1 / (2**dash)).to(dash.device) for dash in dashes])
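# - Illustrative relationship between the parameterisations (values approximate):
#   for dt = 1 ms and tau = 20 ms, alpha = exp(-dt / tau) ≈ 0.9512 and
#   dash = -log2(1 - alpha) ≈ 4.36, so that alpha ≈ 1 - 2 ** -dash.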
class LIFBaseTorch(TorchModule):
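    """
    Base class for LIF neuron modules with a Torch backend, handling the 'taus', 'decays' and 'bitshifts' leak parameterisations
    """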
def __init__(
self,
shape: tuple,
leak_mode: str = "taus",
tau_mem: Optional[Union[FloatVector, P_float]] = None,
tau_syn: Optional[Union[FloatVector, P_float]] = None,
alpha: Optional[Union[FloatVector, P_float]] = None,
beta: Optional[Union[FloatVector, P_float]] = None,
dash_mem: Optional[Union[IntVector, P_float]] = None,
dash_syn: Optional[Union[IntVector, P_float]] = None,
bias: Optional[FloatVector] = None,
threshold: Optional[FloatVector] = None,
has_rec: P_bool = False,
w_rec: torch.Tensor = None,
noise_std: P_float = 0.0,
spike_generation_fn: torch.autograd.Function = StepPWL,
learning_window: P_float = 0.5,
max_spikes_per_dt: P_int = torch.tensor(2.0**16),
weight_init_func: Optional[
Callable[[Tuple], torch.tensor]
] = lambda s: init.kaiming_uniform_(torch.empty(s)),
dt: P_float = 1e-3,
*args,
**kwargs,
):
"""
Instantiate an LIF module
Note:
            On instantiation, the user can specify how the decay parameters of the module are defined; either as time constant values (:py:attr:`.tau_mem` and :py:attr:`.tau_syn`), as decay factors (:py:attr:`.alpha` and :py:attr:`.beta`) or as bitshift values (:py:attr:`.dash_mem` and :py:attr:`.dash_syn`).
            This is specified using the ``leak_mode`` argument on initialisation.
            By default this is set to ``'taus'``, in which case the time constants are direct parameters, and are trainable by default.
            If ``'taus'``, :py:attr:`.tau_mem` and :py:attr:`.tau_syn` are used as model parameters.
            If ``'decays'``, :py:attr:`.alpha` and :py:attr:`.beta` are used as model parameters, where :math:`\\alpha = \\exp(-dt / \\tau_{mem})` and :math:`\\beta = \\exp(-dt / \\tau_{syn})`.
            If ``'bitshifts'``, :py:attr:`.dash_mem` and :py:attr:`.dash_syn` are used as model parameters. These are the bitshift equivalents of the decays, such that :math:`\\alpha = 1 - 1 / 2^{dash\\_mem}`.
            If decay parameters are passed as :py:func:`.Constant` on instantiation of the module, they will be set as non-trainable parameters.
Args:
            shape (tuple): Either a single dimension ``(Nout,)``, which defines a feed-forward layer of LIF modules with an equal number of synapses and neurons, or two dimensions ``(Nin, Nout)``, which defines a layer of ``Nin`` input synapses and ``Nout`` LIF neurons.
            leak_mode (str): Specifies which parameter family defines the leaks. Must be one of ``{'taus', 'decays', 'bitshifts'}``. Default: ``'taus'``.
tau_mem (Optional[FloatVector]): An optional array with concrete initialisation data for the membrane time constants. If not provided, 20ms will be used by default.
tau_syn (Optional[FloatVector]): An optional array with concrete initialisation data for the synaptic time constants. If not provided, 20ms will be used by default.
alpha (Optional[FloatVector]): An optional array with concrete initialisation data for the membrane decays. If not provided, 0.5 will be used by default.
beta (Optional[FloatVector]): An optional array with concrete initialisation data for the synaptic decays. If not provided, 0.5 will be used by default.
dash_mem (Optional[FloatVector]): An optional array with concrete initialisation data for the membrane bitshifts. If not provided, 1 will be used by default.
dash_syn (Optional[FloatVector]): An optional array with concrete initialisation data for the synaptic bitshifts. If not provided, 1 will be used by default.
bias (Optional[FloatVector]): An optional array with concrete initialisation data for the neuron bias currents. If not provided, ``0.0`` will be used by default.
threshold (FloatVector): An optional array specifying the firing threshold of each neuron. If not provided, ``1.`` will be used by default.
has_rec (bool): When ``True`` the module provides a trainable recurrent weight matrix. Default ``False``, module is feed-forward.
w_rec (torch.Tensor): If the module is initialised in recurrent mode, you can provide a concrete initialisation for the recurrent weights, which must be a matrix with shape ``(Nout, Nin)``. If the model is not initialised in recurrent mode, then you may not provide ``w_rec``.
noise_std (float): The std. dev. of the noise added to membrane state variables at each time-step. Default: ``0.0`` (no noise)
            spike_generation_fn (Callable): Function to call for spike production. Usually a simple threshold crossing, implementing a surrogate gradient in the backward call (:py:class:`.StepPWL` or :py:class:`.PeriodicExponential`).
learning_window (float): Cutoff value for the surrogate gradient.
max_spikes_per_dt (float): The maximum number of events that will be produced in a single time-step. Default: ``2**16``.
            weight_init_func (Optional[Callable[[Tuple], torch.tensor]]): The initialisation function to use when generating recurrent weights. Default: Kaiming initialisation
dt (float): The time step for the forward-Euler ODE solver. Default: 1ms
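
        Examples:
            A minimal sketch of constructing :py:class:`.LIFTorch` modules with each leak parameterisation (shapes and values are illustrative)::

                from rockpool.nn.modules import LIFTorch

                # - 4 input synapses onto 2 neurons, parameterised by time constants (the default)
                mod_taus = LIFTorch((4, 2), tau_mem=20e-3, tau_syn=10e-3)

                # - The same module, parameterised by decay factors
                mod_decays = LIFTorch((4, 2), leak_mode="decays", alpha=0.95, beta=0.9)

                # - The same module, parameterised by bitshift values
                mod_dash = LIFTorch((4, 2), leak_mode="bitshifts", dash_mem=4, dash_syn=3)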
"""
# - Check training mode
if leak_mode not in [
"taus",
"decays",
"bitshifts",
]:
raise ValueError(
"Training of time constants in `LIFTorch` neurons can be done only in one of the following modes: 'taus', 'decays', 'bitshifts'. `leak_mode` must be one of these values."
)
# - Check shape argument
if np.size(shape) == 1:
shape = (np.array(shape).item(), np.array(shape).item())
if np.size(shape) > 2:
raise ValueError(
"`shape` must be a one- or two-element tuple `(Nin, Nout)`."
)
# - Initialise superclass
super().__init__(
shape=shape,
spiking_input=True,
spiking_output=True,
*args,
**kwargs,
)
# - Initialise dummy parameters list
self._dummy_params = ()
self.leak_mode = rp.SimulationParameter(leak_mode)
""" (str) The mode by which leaks are determined for this module. """
self.n_neurons = self.size_out
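        """ (int) Number of neurons """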
self.n_synapses: P_int = shape[0] // shape[1]
""" (int) Number of input synapses per neuron """
self.dt: P_float = rp.SimulationParameter(dt)
""" (float) Euler simulator time-step in seconds"""
# - To-float-tensor conversion utility
to_float_tensor = lambda x: torch.as_tensor(x, dtype=torch.float)
# - Initialise recurrent weights
w_rec_shape = (self.size_out, self.size_in)
self._has_rec: bool = rp.SimulationParameter(has_rec)
if has_rec:
self.w_rec: P_tensor = rp.Parameter(
w_rec,
shape=w_rec_shape,
init_func=weight_init_func,
family="weights",
cast_fn=to_float_tensor,
)
""" (Tensor) Recurrent weights `(Nout, Nin)` """
else:
if w_rec is not None:
raise ValueError("`w_rec` may not be provided if `has_rec` is `False`")
self.w_rec: P_ndarray = rp.SimulationParameter(
torch.zeros((self.size_out, self.size_in))
)
self.noise_std: P_float = rp.SimulationParameter(noise_std)
""" (float) Noise std.dev. injected onto the membrane of each neuron during evolution """
if self.leak_mode == "taus":
            if any(p is not None for p in (alpha, beta, dash_mem, dash_syn)):
                raise ValueError(
                    "The current leak mode is 'taus'; only parameters from this family (`tau_mem` and `tau_syn`) may be provided directly."
                )
self.tau_mem: P_tensor = rp.Parameter(
tau_mem,
family="taus",
shape=[(self.size_out,), ()],
init_func=lambda s: torch.ones(s) * 20e-3,
cast_fn=to_float_tensor,
)
""" (Tensor) Membrane time constants `(Nout,)` or `()` """
self.tau_syn: P_tensor = rp.Parameter(
tau_syn,
family="taus",
shape=[
(
self.size_out,
self.n_synapses,
),
(
1,
self.n_synapses,
),
(),
],
init_func=lambda s: torch.ones(s) * 20e-3,
cast_fn=to_float_tensor,
)
""" (Tensor) Synaptic time constants `(Nin,)` or `()` """
self._dummy_params = ("alpha", "beta", "dash_syn", "dash_mem")
elif self.leak_mode == "decays":
            if any(p is not None for p in (tau_mem, tau_syn, dash_mem, dash_syn)):
                raise ValueError(
                    "The current leak mode is 'decays'; only parameters from this family (`alpha` and `beta`) may be provided directly."
                )
self.alpha: P_tensor = rp.Parameter(
alpha,
family="decays",
shape=[(self.size_out,), ()],
init_func=lambda s: torch.ones(s) * 0.5,
cast_fn=to_float_tensor,
)
""" (Tensor) Membrane decay factor `(Nout,)` or `()` """
self.beta: P_tensor = rp.Parameter(
beta,
family="decays",
shape=[
(
self.size_out,
self.n_synapses,
),
(
1,
self.n_synapses,
),
(),
],
init_func=lambda s: torch.ones(s) * 0.5,
cast_fn=to_float_tensor,
)
""" (Tensor) Synaptic decay factor `(Nin,)` or `()` """
self._dummy_params = ("tau_syn", "tau_mem", "dash_syn", "dash_mem")
elif self.leak_mode == "bitshifts":
            if any(p is not None for p in (alpha, beta, tau_mem, tau_syn)):
                raise ValueError(
                    "The current leak mode is 'bitshifts'; only parameters from this family (`dash_mem` and `dash_syn`) may be provided directly."
                )
self.dash_mem: P_tensor = rp.Parameter(
dash_mem,
family="bitshifts",
shape=[(self.size_out,), ()],
init_func=lambda s: torch.ones(s),
cast_fn=to_float_tensor,
)
""" (Tensor) membrane bitshift in xylo `(Nout,)` or `()` """
self.dash_syn: P_tensor = rp.Parameter(
dash_syn,
family="bitshifts",
shape=[
(
self.size_out,
self.n_synapses,
),
(
1,
self.n_synapses,
),
(),
],
init_func=lambda s: torch.ones(s),
cast_fn=to_float_tensor,
)
""" (Tensor) synaptic bitshift in xylo `(Nout,)` or `()` """
self._dummy_params = ("alpha", "beta", "tau_syn", "tau_mem")
self.bias: P_tensor = rp.Parameter(
bias,
shape=[(self.size_out,), ()],
family="bias",
init_func=torch.zeros,
cast_fn=to_float_tensor,
)
""" (Tensor) Neuron biases `(Nout,)` or `()` """
self.threshold: P_tensor = rp.Parameter(
threshold,
shape=[(self.size_out,), ()],
family="thresholds",
init_func=torch.ones,
cast_fn=to_float_tensor,
)
""" (Tensor) Firing threshold for each neuron `(Nout,)` """
self.learning_window: P_tensor = rp.SimulationParameter(
learning_window,
cast_fn=to_float_tensor,
)
""" (float) Learning window cutoff for surrogate gradient function """
self.vmem: P_tensor = rp.State(
shape=self.size_out, init_func=torch.zeros, cast_fn=to_float_tensor
)
""" (Tensor) Membrane potentials `(Nout,)` """
self.isyn: P_tensor = rp.State(
shape=(self.size_out, self.n_synapses),
init_func=torch.zeros,
cast_fn=to_float_tensor,
)
""" (Tensor) Synaptic currents `(Nin,)` """
self.spikes: P_tensor = rp.State(
shape=self.size_out, init_func=torch.zeros, cast_fn=to_float_tensor
)
""" (Tensor) Spikes `(Nin,)` """
self.spike_generation_fn: P_Callable = rp.SimulationParameter(
spike_generation_fn.apply
)
""" (Callable) Spike generation function with surrograte gradient """
self.max_spikes_per_dt: P_float = rp.SimulationParameter(
max_spikes_per_dt, cast_fn=to_float_tensor
)
""" (float) Maximum number of events that can be produced in each time-step """
# - Placeholders for state recordings
self._record_dict = {}
self._record = False
def __getattr__(self, name: str) -> Union[torch.Tensor, TorchModule]:
"""
Overridden __getattr__ to manage access to decay parameters
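        Reading one of the "dummy" leak parameters (e.g. ``.alpha`` or ``.dash_mem`` when ``leak_mode == 'taus'``) returns a value computed on the fly from the concrete parameters of the current leak mode.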
"""
if name in object.__getattribute__(self, "_dummy_params"):
all_TCs = self._get_all_leak_params()
return all_TCs[name]
return super().__getattr__(name)
def evolve(
self, input_data: torch.Tensor, record: bool = False
) -> Tuple[Any, Any, Any]:
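        """
        Evolve the module over an input raster, returning the output events, the updated module state, and a record dictionary (populated when ``record=True``)
        """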
# - Keep track of "record" flag for use by `forward` method
self._record = record
# - Evolve with superclass evolution
output_data, _, _ = super().evolve(input_data, record)
# - Obtain state record dictionary
record_dict = self._record_dict if record else {}
# - Clear record in order to avoid non-leaf tensors hanging around
self._record_dict = {}
return output_data, self.state(), record_dict
def as_graph(self) -> GraphModuleBase:
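        """
        Export this module as a computational graph, as a :py:class:`.LIFNeuronWithSynsRealValue` node with optional recurrent :py:class:`.LinearWeights`
        """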
# - Get neuron parameters for export
tau_mem = self.tau_mem.expand((self.size_out,)).flatten().detach().cpu().numpy()
tau_syn = (
self.tau_syn.expand((self.size_out, self.n_synapses))
.flatten()
.detach()
.cpu()
.numpy()
)
threshold = (
self.threshold.expand((self.size_out,)).flatten().detach().cpu().numpy()
)
bias = self.bias.expand((self.size_out,)).flatten().detach().cpu().numpy()
# - Generate a GraphModule for the neurons
neurons = LIFNeuronWithSynsRealValue._factory(
self.size_in,
self.size_out,
f"{type(self).__name__}_{self.name}_{id(self)}",
self,
tau_mem,
tau_syn,
threshold,
bias,
self.dt,
)
# - Include recurrent weights if present
if self._has_rec:
# - Weights are connected over the existing input and output nodes
w_rec_graph = LinearWeights(
neurons.output_nodes,
neurons.input_nodes,
f"{type(self).__name__}_recurrent_{self.name}_{id(self)}",
self,
self.w_rec.detach().cpu().numpy(),
)
# - Return a graph containing neurons and optional weights
return as_GraphHolder(neurons)
def _get_all_leak_params(self):
"""
Calculate and return all decay parameters, depending on leak mode
"""
if self.leak_mode == "taus":
# - Compute decay parameters based on taus
tau_mem = getattr(
self, "tau_mem", torch.tensor(torch.nan).repeat(self.size_out)
)
tau_syn = getattr(
self, "tau_syn", torch.tensor(torch.nan).repeat(self.size_out)
)
alpha, beta = tau_to_decay(self.dt, tau_mem, tau_syn)
dash_mem, dash_syn = tau_to_bitshift(self.dt, tau_mem, tau_syn)
elif self.leak_mode == "decays":
# - Compute decay parameters based on decay constants
alpha = getattr(
self, "alpha", torch.tensor(torch.nan).repeat(self.size_out)
)
beta = getattr(self, "beta", torch.tensor(torch.nan).repeat(self.size_out))
tau_mem, tau_syn = decay_to_tau(self.dt, alpha, beta)
dash_mem, dash_syn = decay_to_bitshift(alpha, beta)
elif self.leak_mode == "bitshifts":
# - Compute decay parameters based on bitshift values
dash_mem = getattr(
self, "dash_mem", torch.tensor(torch.nan).repeat(self.size_out)
)
dash_syn = getattr(
self, "dash_syn", torch.tensor(torch.nan).repeat(self.size_out)
)
tau_mem, tau_syn = bitshift_to_tau(self.dt, dash_mem, dash_syn)
alpha, beta = bitshift_to_decay(dash_mem, dash_syn)
# - Return all parameters
return {
"tau_mem": tau_mem,
"tau_syn": tau_syn,
"alpha": alpha,
"beta": beta,
"dash_mem": dash_mem,
"dash_syn": dash_syn,
}
def __setattr__(self, key, value: Any):
"""
Overridden __setattr__ to manage access to decay parameters
"""
if hasattr(self, "_dummy_params") and key in self._dummy_params:
self._set_leak_param(key, value)
return super().__setattr__(key, value)
def _set_leak_param(self, name, value):
"""
Set the value of a named decay parameter, depending on leak mode
"""
if self.leak_mode == "taus":
# - Compute tau from `name`
if name == "alpha":
return setattr(self, "tau_mem", tau_to_decay(self.dt, value)[0])
elif name == "beta":
return setattr(self, "tau_syn", tau_to_decay(self.dt, value)[0])
elif name == "dash_mem":
return setattr(self, "tau_mem", tau_to_bitshift(self.dt, value)[0])
elif name == "dash_syn":
return setattr(self, "tau_syn", tau_to_bitshift(self.dt, value)[0])
elif self.leak_mode == "decays":
if name == "tau_mem":
return setattr(self, "alpha", decay_to_tau(self.dt, value)[0])
elif name == "tau_syn":
return setattr(self, "beta", decay_to_tau(self.dt, value)[0])
elif name == "dash_mem":
return setattr(self, "alpha", decay_to_bitshift(value)[0])
elif name == "dash_syn":
return setattr(self, "beta", decay_to_bitshift(value)[0])
elif self.leak_mode == "bitshifts":
if name == "tau_mem":
return setattr(self, "dash_mem", tau_to_bitshift(self.dt, value)[0])
elif name == "tau_syn":
return setattr(self, "dash_syn", tau_to_bitshift(self.dt, value)[0])
elif name == "alpha":
return setattr(self, "dash_mem", decay_to_bitshift(value)[0])
elif name == "beta":
return setattr(self, "dash_syn", decay_to_bitshift(value)[0])
class LIFTorch(LIFBaseTorch):
"""
A leaky integrate-and-fire spiking neuron model with a Torch backend
This module implements the update equations:
.. math ::
I_{syn} += S_{in}(t) + S_{rec} \\cdot W_{rec}
I_{syn} *= \\exp(-dt / \\tau_{syn})
V_{mem} *= \\exp(-dt / \\tau_{mem})
V_{mem} += I_{syn} + b + \\sigma \\zeta(t)
    where :math:`S_{in}(t)` is a vector containing ``1`` (or a weighted spike) for each input channel that emits a spike at time :math:`t`; :math:`b` is an :math:`N`-vector of bias currents for each neuron (default ``0.``); :math:`\\sigma\\zeta(t)` is a Wiener noise process with standard deviation :math:`\\sigma` after 1s; :math:`\\tau_{mem}` and :math:`\\tau_{syn}` are the membrane and synaptic time constants, respectively; :math:`S_{rec}(t)` is a vector containing ``1`` for each neuron that emitted a spike in the last time-step; and :math:`W_{rec}` is the recurrent weight matrix, if recurrent weights are used.
:On spiking:
        When the membrane potential for neuron :math:`j`, :math:`V_{mem, j}`, exceeds the threshold voltage :math:`V_{thr}`, the neuron emits a spike. The spiking neuron subtracts its own threshold on reset.
.. math ::
V_{mem, j} > V_{thr} \\rightarrow S_{rec,j} = 1
V_{mem, j} = V_{mem, j} - V_{thr}
    Neurons therefore share a common resting potential of ``0``, have individual firing thresholds, and perform a subtractive reset of :math:`V_{thr}` on spiking.
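
    :Examples:
        A minimal usage sketch (shapes and values are illustrative)::

            import torch
            from rockpool.nn.modules import LIFTorch

            # - 4 input synapses onto 2 neurons (2 synapses per neuron)
            mod = LIFTorch((4, 2))

            # - A random Poisson-like input raster of shape (batch, time, Nin)
            spikes_in = (torch.rand(1, 100, 4) < 0.1).float()

            # - Evolve, returning output events, the new state and a state record
            spikes_out, state, recording = mod(spikes_in, record=True)
            # `spikes_out` has shape (1, 100, 2)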
"""
    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
"""
forward method for processing data through this layer
        Adds synaptic inputs to the synaptic states and simulates the leaky integrate-and-fire dynamics

        Args:
            input_data (torch.Tensor): Input data with shape ``(batch, time_steps, Nin)``

        Returns:
            torch.Tensor: Output spike raster with shape ``(batch, time_steps, Nout)``
"""
# - Auto-batch over input data
input_data, (vmem, spikes, isyn) = self._auto_batch(
input_data,
(self.vmem, self.spikes, self.isyn),
(
(self.size_out,),
(self.size_out,),
(self.size_out, self.n_synapses),
),
)
n_batches, n_timesteps, _ = input_data.shape
# - Reshape data over separate input synapses
input_data = input_data.reshape(
n_batches, n_timesteps, self.size_out, self.n_synapses
)
# - Set up state record and output
if self._record:
self._record_dict["vmem"] = torch.zeros(
n_batches, n_timesteps, self.size_out
)
self._record_dict["isyn"] = torch.zeros(
n_batches, n_timesteps, self.size_out, self.n_synapses
)
self._record_dict["irec"] = torch.zeros(
n_batches, n_timesteps, self.size_out, self.n_synapses
)
self._record_dict["spikes"] = torch.zeros(
n_batches, n_timesteps, self.size_out, device=input_data.device
)
noise_zeta = self.noise_std * torch.sqrt(torch.tensor(self.dt))
# - Generate membrane noise trace
noise_ts = noise_zeta * torch.randn(
(n_batches, n_timesteps, self.size_out), device=vmem.device
)
# - Loop over time
for t in range(n_timesteps):
# Integrate synaptic input
isyn = isyn + input_data[:, t]
# - Apply spikes over the recurrent weights
if hasattr(self, "w_rec"):
irec = F.linear(spikes, self.w_rec.T).reshape(
n_batches, self.size_out, self.n_synapses
)
isyn = isyn + irec
# Decay synaptic and membrane state
vmem *= self.alpha.to(vmem.device)
isyn *= self.beta.to(isyn.device)
# Integrate membrane state and apply noise
vmem = vmem + isyn.sum(2) + noise_ts[:, t, :] + self.bias
# - Spike generation
spikes = self.spike_generation_fn(
vmem, self.threshold, self.learning_window, self.max_spikes_per_dt
)
# - Apply subtractive membrane reset
vmem = vmem - spikes * self.threshold
# - Maintain state record
if self._record:
self._record_dict["vmem"][:, t] = vmem
self._record_dict["isyn"][:, t] = isyn
if hasattr(self, "w_rec"):
self._record_dict["irec"][:, t] = irec
# - Maintain output spike record
self._record_dict["spikes"][:, t] = spikes
# - Update states
self.vmem = vmem[0].detach()
self.isyn = isyn[0].detach()
self.spikes = spikes[0].detach()
# - Return output
return self._record_dict["spikes"]