Source code for nn.modules.torch.exp_syn_torch

Implement an exponential synapse module, using a Torch backend

from typing import Optional, Tuple, Any, Union
import numpy as np
from rockpool.nn.modules.torch.torch_module import TorchModule
import torch
import rockpool.parameters as rp

import rockpool.typehints as rt

# - Names exported by `from <this module> import *`
__all__ = ["ExpSynTorch"]

class ExpSynTorch(TorchModule):
    r"""
    Exponential synapse module with a Torch backend

    This module implements a layer of exponential synapses, operating under the update equations

    .. math::

        I_{syn} = I_{syn} + i(t) + \sigma \sqrt{dt} \, \zeta_t

        I_{syn} = I_{syn} * \exp(-dt / \tau)

    where :math:`i(t)` is the instantaneous input; :math:`\tau` is the vector ``(N,)`` of time constants for each synapse in seconds; :math:`dt` is the update time-step in seconds; :math:`\sigma` is the std. deviation after 1s of a Wiener process; and :math:`\zeta_t` is standard-normal noise drawn independently for each synapse and time-step.

    The value of :math:`I_{syn}` returned for each time-step is the value *after* the decay has been applied.
    """

    def __init__(
        self,
        shape: Union[tuple, int],
        tau: Optional[rt.FloatVector] = None,
        noise_std: float = 0.0,
        dt: float = 1e-3,
        spiking_input: bool = True,
        spiking_output: bool = False,
        *args,
        **kwargs,
    ):
        """
        Instantiate an exp. synapse module

        Args:
            shape (Union[int, tuple]): The number of synapses in this module ``(N,)``.
            tau (Optional[np.ndarray]): Concrete initialisation data for the time constants of the synapses, in seconds. Default: 10 ms individual for all synapses.
            noise_std (float): The std. dev after 1s of noise added independently to each synapse. Default: ``0.0``.
            dt (float): The timestep of this module, in seconds. Default: 1 ms.
            spiking_input (bool): Forwarded unchanged to the super-class initialiser. Default: ``True``.
            spiking_output (bool): Forwarded unchanged to the super-class initialiser. Default: ``False``.
            *args: Additional positional arguments forwarded to the super-class initialiser.
            **kwargs: Additional keyword arguments forwarded to the super-class initialiser.
        """
        # - Initialise super class
        super().__init__(
            *args,
            shape=shape,
            spiking_input=spiking_input,
            spiking_output=spiking_output,
            **kwargs,
        )

        # - To-float-tensor conversion utility
        def to_float_tensor(x) -> torch.Tensor:
            return torch.as_tensor(x, dtype=torch.float)

        # - Initialise tau
        self.tau: rt.P_tensor = rp.Parameter(
            tau,
            shape=[(self.size_out,), ()],
            family="taus",
            init_func=lambda s: torch.ones(*s) * 10e-3,
            cast_fn=to_float_tensor,
        )
        """ (torch.Tensor) Time constants of each synapse in seconds ``(N,)`` or ``()`` """

        # - Initialise noise std dev
        self.noise_std: rt.P_tensor = rp.SimulationParameter(
            noise_std, cast_fn=to_float_tensor
        )
        """ (torch.Tensor) Noise std. dev after 1 second """

        # - Initialise state
        self.isyn: rt.P_tensor = rp.State(
            shape=(self.size_out,),
            init_func=lambda s: torch.zeros(*s),
        )
        """ (torch.Tensor) Synaptic current state for each synapse ``(N,)`` """

        # - Store dt
        self.dt: rt.P_float = rp.SimulationParameter(dt)
        """ (float) Simulation time-step in seconds """

    def evolve(
        self, input_data: torch.Tensor, record: bool = False
    ) -> Tuple[Any, Any, Any]:
        """
        Evolve the module over a batch of input data

        Delegates to the super-class ``evolve``, keeping only its output, and pairs that output with the current module state.

        Args:
            input_data (torch.Tensor): Input data to process.
            record (bool): Forwarded to the super-class ``evolve``. This module does not record anything itself, so the returned record dictionary is always empty. Default: ``False``.

        Returns:
            Tuple[Any, Any, Any]: ``(output_data, state, record_dict)``, where ``state`` is the result of ``self.state()`` and ``record_dict`` is an empty dictionary.
        """
        # - Evolve the module; the state and record returned by the super-class are discarded
        output_data, _, _ = super().evolve(input_data, record)

        # - Return the result of evolution
        return output_data, self.state(), {}

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Forward method for processing data through this layer

        Adds synaptic inputs (plus noise) to the synaptic states and applies the exponential decay at each time step.

        Parameters
        ----------
        data: Tensor
            Input data. After auto-batching it takes the shape of (batch, time_steps, N)

        Returns
        -------
        out: Tensor
            Synaptic currents at each time step, with the same shape as the batched input (batch, time_steps, N)
        """
        # - Auto-batch over input data
        data, (isyn,) = self._auto_batch(data, (self.isyn,))
        _, time_steps, _ = data.shape

        # - Build a tensor to compute and return internal state
        isyn_ts = torch.zeros(data.shape, device=data.device)

        # - Compute decay factor
        beta = torch.exp(-self.dt / self.tau)

        # - Inject noise into the input; it therefore passes through the decay step below
        noise_zeta = self.noise_std * torch.sqrt(torch.tensor(self.dt))
        data = data + noise_zeta * torch.randn(data.shape, device=data.device)

        # - Loop over time
        for t in range(time_steps):
            isyn += data[:, t, :]
            isyn *= beta

            # - Record the post-decay current for this time step
            isyn_ts[:, t, :] = isyn

        # - Store the final state; only the first batch entry is kept as module state
        self.isyn = isyn[0].detach()

        # - Return the evolved synaptic current
        return isyn_ts