from typing import Optional, Union, Tuple, Any
import numpy as np
from rockpool.nn.modules.torch.torch_module import TorchModule
import torch
import rockpool.parameters as rp
from rockpool.typehints import FloatVector, P_float, P_tensor

# - Public API of this module: only the up/down conversion layer is exported
__all__ = ["UpDownTorch"]

class StepPWL(torch.autograd.Function):
    """
    Heaviside step function with piece-wise linear derivative to use as spike-generation surrogate.

    The forward pass returns ``floor(data)`` clamped below at ``0``, i.e. the number of whole
    units by which ``data`` is positive (callers in this module pass values already divided by
    a threshold, so this counts thresholds crossed). The backward pass lets the incoming
    gradient through unchanged wherever ``data >= 0.5`` and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, data: torch.Tensor) -> torch.Tensor:
        """
        Compute the clamped step output.

        Args:
            ctx: Autograd context used to stash tensors for the backward pass.
            data (torch.Tensor): Input data.

        Returns:
            torch.Tensor: ``floor(data)`` clamped below at ``0``.
        """
        # - backward() needs the input to build the surrogate-gradient mask
        ctx.save_for_backward(data)
        return torch.clamp(torch.floor(data), 0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Surrogate gradient: pass ``grad_output`` through where ``data >= 0.5``, zero elsewhere.

        Args:
            ctx: Autograd context holding the input saved in ``forward``.
            grad_output (torch.Tensor): Gradient w.r.t. the forward output.

        Returns:
            torch.Tensor: Gradient w.r.t. ``data``.
        """
        (data,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[data < 0.5] = 0
        # - forward() takes a single tensor input, so a single gradient is returned
        return grad_input

## - UpDownTorch - Class: Define a spiking feedforward layer to convert analogue inputs to up and down channels
class UpDownTorch(TorchModule):
    """
    Feedforward layer that converts each analogue input channel to one spiking up and one spiking down channel.

    This module orients itself on the ADC implementation in 'A Neuromorphic Event-Based Neural Recording
    System for Smart Brain-Machine-Interfaces', Corradi et al. 2015. While following the same idea, the
    dynamics and non-idealities of the circuit are not modeled; instead this module strives to be a nominal
    implementation of the idea of an up-down ADM.

    Spike generation depends on whether the input value rises above / falls below the up / down threshold
    relative to a reference value. If a threshold is reached, a spike is emitted and the threshold value is
    added to (for an up spike) or subtracted from (for a down spike) the reference value.

    This module also allows setting a refractory period, which is activated after a spike was emitted on
    either of the output channels and during which further spike emission is suppressed. While a channel is
    refractory, its reference value is set to the current input instead.
    """

    ## - Constructor
    def __init__(
        self,
        shape: tuple = None,
        thr_up: Optional[FloatVector] = 1e-3,
        thr_down: Optional[FloatVector] = 1e-3,
        n_ref_steps: int = 0,
        repeat_output: int = 1,
        dt: float = 1e-3,
        device=None,
        dtype=None,
        *args,
        **kwargs,
    ):
        """
        Instantiate a spiking feedforward layer to convert analogue inputs to up and down spike channels.

        Args:
            shape (tuple): A single dimension ``(N_in,)``, which defines the number of input channels. The output is always given as ``N_out = 2 * N_in``.
            thr_up (Optional[FloatVector]): Thresholds for creating up-spikes, either a scalar or one value per input channel. Default: ``0.001``
            thr_down (Optional[FloatVector]): Thresholds for creating down-spikes, either a scalar or one value per input channel. Default: ``0.001``
            n_ref_steps (int): Duration of the refractory period as a multiple of `dt` (`t_ref = n_ref_steps * dt`). During the refractory period the module doesn't emit any spikes. Default: ``0``
            repeat_output (int): Repeat each output spike ``repeat_output`` times. Default: ``1``
            dt (float): The simulation time step in seconds. Default: ``1ms``
            device: Defines the device on which the module will be processed.
            dtype: Defines the data type of the tensors saved as attributes.

        Raises:
            ValueError: If ``shape`` is not a one-element tuple, or a threshold vector does not have one entry per input channel.
        """
        if np.size(shape) != 1:
            raise ValueError("`shape` must be a one-element tuple `(Nin,)`.")
        shape_in = np.array(shape).item()
        shape = (shape_in, 2 * shape_in)

        # - Call super constructor
        super().__init__(
            shape=shape,
            spiking_input=False,
            spiking_output=True,
            *args,
            **kwargs,
        )

        # - Default tensor construction parameters
        self._factory_kwargs = {"device": device, "dtype": dtype}

        # - Store layer parameters
        self.dt: P_float = rp.SimulationParameter(dt)
        """ (float) Simulation time step in seconds """

        self.repeat_output: P_float = rp.SimulationParameter(repeat_output)
        self.n_ref_steps: P_float = rp.SimulationParameter(n_ref_steps)

        self.thr_up: P_tensor = rp.Parameter(
            self._to_threshold_tensor(thr_up, "thr_up"),
            family="thresholds",
        )
        """ (Tensor) Thresholds for creating up-spikes `(1, N_in)` """

        self.thr_down: P_tensor = rp.Parameter(
            self._to_threshold_tensor(thr_down, "thr_down"),
            family="thresholds",
        )
        """ (Tensor) Thresholds for creating down-spikes `(1, N_in)` """

    def _to_threshold_tensor(self, thr, name: str = "threshold") -> torch.Tensor:
        """
        Convert a scalar or per-channel threshold specification into a ``(1, N_in)`` tensor.

        Args:
            thr: A scalar (applied to every input channel) or a vector-like (list, numpy array, tensor) with one entry per input channel.
            name (str): Name of the argument, used in the error message.

        Returns:
            torch.Tensor: Threshold tensor of shape ``(1, N_in)``.

        Raises:
            ValueError: If a vector with a number of entries other than ``1`` or ``N_in`` is given.
        """
        # - `as_tensor` accepts Python scalars, lists, numpy arrays and tensors alike
        thr = torch.as_tensor(thr, **self._factory_kwargs)

        if thr.numel() == 1:
            # - Broadcast a single threshold to all input channels
            return torch.ones((1, self.size_in), **self._factory_kwargs) * thr.reshape(())

        if thr.numel() != self.size_in:
            raise ValueError(
                f"`{name}` must be a scalar or have {self.size_in} elements, got {thr.numel()}."
            )

        return thr.reshape(1, -1)

    def evolve(
        self,
        input_data: torch.Tensor,
        record: bool = False,
    ) -> Tuple[Any, Any, Any]:
        """
        Evolve the module over the provided input.

        Args:
            input_data (torch.Tensor): Analog input data, handed to the superclass ``evolve``.
            record (bool): If ``True``, include the reference-value trace under ``"analog_value"`` in the record dictionary. Default: ``False``

        Returns:
            Tuple[Any, Any, Any]: ``(output_data, states, record_dict)``
        """
        # - Evolve with superclass evolution
        output_data, states, _ = super().evolve(input_data, record)

        # - Build state record
        record_dict = {"analog_value": self._analog_value_rec} if record else {}

        return output_data, states, record_dict

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Convert each analog input channel to an up and a down spike channel.

        The reference-value trace of the last call is stored in ``self._analog_value_rec`` with shape
        ``(batch, time_steps, n_channels)``.

        Args:
            data (torch.Tensor): Input data with shape ``(batch, time_steps, n_channels)``.

        Raises:
            ValueError: Input has wrong input channel dimension.

        Returns:
            torch.Tensor: Output spikes with shape ``(batch, time_steps, 2 * n_channels)``. For the ``n``-th input channel,
            output channel ``2 * n`` is the up channel and output channel ``2 * n + 1`` is the down channel.
        """
        n_batches, time_steps, n_channels = data.shape
        if n_channels != self.size_in:
            raise ValueError(
                "Input has wrong input channel dimension. It is {}, must be {}".format(
                    n_channels, self.size_in
                )
            )

        # - Allocate helper tensors on the input's device so GPU inputs work even if the module was built without `device`
        device = data.device
        dtype = self._factory_kwargs["dtype"]

        # - Set up state record and output
        self._analog_value_rec = torch.zeros(
            n_batches, time_steps, n_channels, device=device, dtype=dtype
        )
        out_spikes = torch.zeros(
            n_batches, time_steps, self.size_out, device=device, dtype=dtype
        )

        # - Nothing to convert on an empty time axis
        if time_steps == 0:
            return out_spikes

        # - Extend thresholds by batches: (1, N_in) -> (batch, N_in)
        thr_up = self.thr_up.expand(n_batches, -1)
        thr_down = self.thr_down.expand(n_batches, -1)

        # - Counter, for how many steps of dt each batch/channel is still in its refractory period
        remaining_ref_steps = torch.zeros(n_batches, n_channels, device=device)

        # - Reference value from where we observe whether the signal surpasses any thresholds.
        #   `clone()` is required: `detach()` alone shares storage with `data`, so the in-place
        #   update in the refractory branch would overwrite the caller's input.
        analog_value = data[:, 0, :].detach().clone()

        step_pwl = StepPWL.apply

        # - Loop over time
        for t in range(time_steps):
            # - Record the state
            self._analog_value_rec[:, t, :] = analog_value.detach()

            # - Get the difference to the current reference value
            diff_values = data[:, t, :] - analog_value

            # - Calculate the spike outputs (number of thresholds crossed)
            up_channels = step_pwl(diff_values / thr_up)
            # - Divide by the negative thr_down so that it checks for changes going below this threshold
            down_channels = step_pwl(diff_values / (-thr_down))

            if self.n_ref_steps > 0:
                # - Remove the spikes of all channels that are still in the refractory period
                refractory = remaining_ref_steps > 0
                up_channels[refractory] = 0
                down_channels[refractory] = 0

                # - Limit the amount of emitted spikes to 1, since the refractory period suppresses all spikes after the first one
                up_channels[up_channels >= 1] = 1
                down_channels[down_channels >= 1] = 1

                # - Reset the refractory counter to the full duration where either an up or a down spike was emitted
                remaining_ref_steps[(up_channels + down_channels) > 0] = self.n_ref_steps

                # - Set the reference value to the current input for all channels in the refractory period
                refractory = remaining_ref_steps > 0
                analog_value[refractory] = data[:, t, :][refractory]

                # - Count down the refractory counters
                remaining_ref_steps -= 1
            else:
                # - Add (resp. subtract) the thresholds for every emitted spike
                analog_value = analog_value + up_channels * thr_up
                analog_value = analog_value - down_channels * thr_down

            # - Interleave up and down channels so that 2*k is the up and 2*k + 1 the down channel of the k-th input channel
            # - Multiply by repeat_output to get the desired multiple of spikes
            out_spikes[:, t, :] = self.repeat_output * torch.stack(
                (up_channels, down_channels),
                dim=2,
            ).reshape(n_batches, 2 * n_channels)

        return out_spikes