Source code for nn.modules.torch.updown_torch

"""
Feedforward layer that converts each analogue input channel to one spiking up and one down channel
"""

from typing import Optional, Union, Tuple, Any
import numpy as np
from rockpool.nn.modules.torch.torch_module import TorchModule
import torch
import rockpool.parameters as rp
from rockpool.typehints import FloatVector, P_float, P_tensor

__all__ = ["UpDownTorch"]


class StepPWL(torch.autograd.Function):
    """
    Heaviside step function with piece-wise linear derivative to use as spike-generation surrogate

    :param torch.Tensor data: Input data
    :param float thr: Threshold value

    :return torch.Tensor: output value and gradient function
    """

    @staticmethod
    def forward(ctx, data):
        ctx.save_for_backward(data)
        return torch.clamp(torch.floor(data), 0)

    @staticmethod
    def backward(ctx, grad_output):
        (data,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[data < 0.5] = 0
        # - Since there are two inputs, we need to give two outputs to backpropagate.
        return grad_input


## - UpDownTorch - Class: Define a spiking feedforward layer to convert analogue inputs to up and down channels
[docs]class UpDownTorch(TorchModule): """ Feedforward layer that converts each analogue input channel to one spiking up and one spiking down channel. This module orients itself on the ADC implementation in 'A Neuromorphic Event-Based Neural Recording System for Smart Brain-Machine-Interfaces', Corradi et al. 2015. While following the same idea, the dynamics and non-idealities of the module are not modeled, instead this module strives to be an nominal implementation of the idea of an up-down ADM. The spike generation is dependent on whether the input value surpasses/falls below the up/down threshold relative of the reference value. If a threshold is reached, a spike will be emitted and the threshhold value added (in case of an up spike) resp. subtracted from the refernece value. This module also allows for setting a refractory period, which is activated after a spike was emitted on either of the output channels, during which further spike emitting is supressed. """ ## - Constructor
[docs] def __init__( self, shape: tuple = None, thr_up: Optional[FloatVector] = 1e-3, thr_down: Optional[FloatVector] = 1e-3, n_ref_steps: int = 0, repeat_output: int = 1, dt: float = 1e-3, device=None, dtype=None, *args, **kwargs, ): """ Instantiate a spiking feedforward layer to convert analogue inputs to up and down spike channels. Args: shape (tuple): A single dimension ``(N_in,)``, which defines the number of input channels. The output is always given as ``N_out = 2 * N_in``. thr_up (Optional[FloatVector]): Thresholds for creating up-spikes. Default: ``0.001`` thr_down (Optional[FloatVector]): Thresholds for creating down-spikes. Default: ``0.001`` n_ref_steps (float): Determines the duration of the refractory period as multiple of `dt` (`t_ref=n_ref_steps*dt`). During the refractory period the module doesn't emit any spikes. Default: ``0`` repeat_output (int): Repeat each output spike x times. dt (float): The time step for the forward-Euler ODE solver in seconds. Default: ``1ms`` device: Defines the device on which the model will be processed. dtype: Defines the data type of the tensors saved as attributes. device: Defines the device on which the module will be processed. dtype: Defines the data type of the tensors saved as attributes. """ if np.size(shape) == 1: shape_in = np.array(shape).item() shape = (shape_in, 2 * shape_in) else: raise ValueError("`shape` must be a one-element tuple `(Nin,)`.") # - Call super constructor super().__init__( shape=shape, spiking_input=False, spiking_output=True, *args, **kwargs, ) # - Default tensor construction parameters self._factory_kwargs = {"device": device, "dtype": dtype} # - Store layer parameters self.repeat_output: P_float = rp.SimulationParameter(repeat_output) self.n_ref_steps: P_float = rp.SimulationParameter(n_ref_steps) if np.size(thr_up) == 1: thr_up = torch.ones((1, self.size_in), **self._factory_kwargs) * thr_up else: thr_up = thr_up.view(1, -1) self.thr_up: P_tensor = rp.Parameter( thr_up, family="thresholds", ) """ (Tensor) Thresholds for creating up-spikes `(N_in,)` """ if np.size(thr_down) == 1: thr_down = torch.ones((1, self.size_in), **self._factory_kwargs) * thr_down else: thr_down = thr_up.view(1, -1) self.thr_down: P_tensor = rp.Parameter( thr_down, family="thresholds", ) """ (Tensor) Thresholds for creating down-spikes `(N_in,)` """
[docs] def evolve( self, input_data: torch.Tensor, record: bool = False, ) -> Tuple[Any, Any, Any]: # - Evolve with superclass evolution output_data, states, _ = super().evolve(input_data, record) # - Build state record record_dict = ( { "analog_value": self._analog_value_rec, } if record else {} ) return output_data, states, record_dict
[docs] def forward(self, data: torch.Tensor) -> torch.Tensor: """ forward method for processing data through this layer Convert each analog input channel to an up and down spike channel. Args: data (torch.Tensor): Data takes the shape of `(batch, time_steps, n_channels)` Raises: ValueError: Input has wrong input channel dimension. Returns: torch.Tensor: Output of spikes with the shape `(batch, time_steps, 2*n_channels)`, where the `2*n`-th output channel the up channel and the `(2*n + 1)`-th output channel the down channel of the `n`-th input channel are. """ n_batches, time_steps, n_channels = data.shape if n_channels != self.size_in: raise ValueError( "Input has wrong input channel dimension. It is {}, must be {}".format( n_channels, self.size_in ) ) # - Extend thresholds by batches thr_up = torch.ones(n_batches, 1) @ self.thr_up thr_down = torch.ones(n_batches, 1) @ self.thr_down """ Counter, for how many steps of dt is the module still in refractory period. Has to be counted for each batch and channel individually. """ remaining_ref_steps = torch.zeros(n_batches, n_channels) # - Reference value from where we observe whether the signal surpasses any thresholds analog_value = data[:, 0, :].detach() step_pwl = StepPWL.apply # - Set up state record and output self._analog_value_rec = torch.zeros( n_batches, time_steps, n_channels, **self._factory_kwargs ) out_spikes = torch.zeros( n_batches, time_steps, self.size_out, **self._factory_kwargs ) # - Loop over time for t in range(time_steps): # - Record the state self._analog_value_rec[:, t, :] = analog_value.detach() # - Get the difference between the last analog value saved diff_values = data[:, t, :] - analog_value # - Calculate the spike outputs up_channels = step_pwl(diff_values / thr_up) # - Enter the negative thr_down so that it checks for changes going below this threshold. down_channels = step_pwl(diff_values / (-thr_down)) if self.n_ref_steps > 0: # - Remove the spikes of all channels that are still in the refractory period up_channels[remaining_ref_steps > 0] = 0 down_channels[remaining_ref_steps > 0] = 0 # - Limit the amount of emitted spikes to 1, since the refractory period supresses all spikes after the first one up_channels[up_channels >= 1] = 1 down_channels[down_channels >= 1] = 1 # - Reset the refractory counter back to the full time when either an up or a down spike was emitted remaining_ref_steps[ (up_channels + down_channels) > 0 ] = self.n_ref_steps # - Set the reference value to the last input for all channels which are in refractory period analog_value[remaining_ref_steps > 0] = (data[:, t, :])[ remaining_ref_steps > 0 ] # - Count down the refractory counters remaining_ref_steps -= 1 else: # - Add (resp. subtract) the thresholds for every emitted spike analog_value = analog_value + up_channels * thr_up analog_value = analog_value - down_channels * thr_down # - Interleave up and down channels so we have 2*k as up and 2*k + 1 as down channel of the k-th input channel # - Multiply by repeat_output to get the desired multiple of spikes out_spikes[:, t, :] = self.repeat_output * torch.stack( (up_channels, down_channels), dim=2, ).view(n_batches, 2 * n_channels) return out_spikes