Source code for devices.xylo.syns63300.imuif.spike_encoder

"""
This module implements spike encoding for signals coming either from the filters or directly from the IMU sensor.
"""

from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from rockpool.devices.xylo.syns63300.imuif.utils import (
    type_check,
    unsigned_bit_range_check,
)
from rockpool.devices.xylo.syns63300.imuif.params import NUM_BITS_SPIKE
from rockpool.nn.modules.module import Module
from rockpool.parameters import SimulationParameter


# Names exported by ``from <this module> import *``: the two spike encoder classes defined below.
__all__ = ["ScaleSpikeEncoder", "IAFSpikeEncoder"]


class ScaleSpikeEncoder(Module):
    """
    Encode spikes as follows

    (i) Apply full-wave rectification to the input signal
    (ii) Down-scale the input signal by right-bit-shift by `num_scale_bits` (i.e. integer division by 2^num_scale_bits)
    (iii) Truncate the output so that it can fit within `NUM_BITS_SPIKE`
    """

    def __init__(
        self,
        shape: Union[Tuple, int] = (15, 15),
        num_scale_bits: Union[List[int], int] = 5,
    ) -> None:
        """
        Object constructor

        Args:
            shape (Union[Tuple, int]): The number of input and output channels. Defaults to (15, 15).
            num_scale_bits (Union[List[int], int]): number of right-bit-shifts needed for down-scaling the input signal.
                Provide a single integer to use the same shift for every output channel, or a list with one
                shift per output channel. Defaults to ``5`` for each neuron.
        """
        super().__init__(shape=shape, spiking_input=False, spiking_output=True)

        # A scalar (Python int or NumPy integer scalar) is replicated for every output channel
        if isinstance(num_scale_bits, (int, np.integer)):
            num_scale_bits = [int(num_scale_bits)] * self.size_out

        # Validate every shift against a 5-bit unsigned range
        for bits in num_scale_bits:
            unsigned_bit_range_check(bits, n_bits=5)

        self.num_scale_bits = SimulationParameter(
            num_scale_bits,
            shape=(self.size_out,),
        )
        """(int) Number of right-bit-shifts needed for down-scaling the input signal"""

    @type_check
    def evolve(
        self, input_data: np.ndarray, record: bool = False
    ) -> Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]:
        """Processes the input signal and return the encoded spikes

        Args:
            input_data (np.ndarray): filtered data recorded from IMU sensor. It should be in integer format. (BxTx15)
            record (bool): Unused.

        Returns:
            Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]:
                Encoded spikes, same shape as the batched input (BxTxC)
                empty dictionary
                empty dictionary
        """
        input_data, _ = self._auto_batch(input_data)

        # Cast to int64 first, then to Python objects
        # NOTE(review): presumably done so the arithmetic below is not limited to fixed-width integers -- confirm
        input_data = np.array(input_data, dtype=np.int64).astype(object)

        # Full-wave rectification
        output_data = np.abs(input_data)

        # Scale each channel by its own right-shift
        for channel, shift in enumerate(self.num_scale_bits):
            output_data[:, :, channel] = output_data[:, :, channel] >> shift

        # Saturate at the largest value representable with NUM_BITS_SPIKE bits
        max_count = (1 << NUM_BITS_SPIKE) - 1
        output_data[output_data > max_count] = max_count

        return output_data, {}, {}
class IAFSpikeEncoder(Module):
    r"""
    Synchronous integrate and fire spike encoder

    More specifically, denoting a specific filter output by $$y(t)$$, this module computes $$s(t) = \sum_{n=0}^t |y(n)|$$ and as soon as it passes the firing threshold $$\theta$$, say at time $$t=t_0$$, it produces a spike at $$t=t_0$$ and reduces the counter by $$\theta$$ and keeps accumulating the input signal:

    $$s(t) = s(t_0) - \theta + \sum_{n=t_0+1}^t |y(n)|$$

    until the next threshold crossing and spike generation occurs.
    """

    def __init__(
        self,
        shape: Optional[Union[Tuple, int]] = (15, 15),
        threshold: Union[List[int], int] = 1024,
    ) -> None:
        """
        Instantiate an IAF spike encoder

        Args:
            shape (Optional[Union[Tuple, int]]): the input and output shape of the module. Defaults to ``(15, 15)``.
            threshold (Union[int, List[int]]): the thresholds of the IAF neurons (quantized unsigned integer). Provide a single integer to use the same threshold for each neuron, or a list of thresholds with one for each neuron. Default: ``1024`` for each neuron.
        """
        super().__init__(shape=shape, spiking_input=False, spiking_output=True)

        # A scalar (Python int or NumPy integer scalar) is replicated for every output channel
        if isinstance(threshold, (int, np.integer)):
            threshold = [int(threshold)] * self.size_out

        # Validate every threshold against a 31-bit unsigned range
        # NOTE(review): if this check admits 0, `evolve` would floor-divide by zero -- confirm
        for th in threshold:
            unsigned_bit_range_check(th, n_bits=31)

        self.threshold = SimulationParameter(threshold, shape=(self.size_out,))
        """ (np.ndarray) The unsigned integer thresholds of the IAF neurons"""

    @type_check
    def evolve(
        self, input_data: np.ndarray, record: bool = False
    ) -> Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]:
        """Processes the input signal and return the encoded spikes

        Args:
            input_data (np.ndarray): filtered data recorded from IMU sensor. It should be in integer format. (BxTx15)
            record (bool, optional): Unused.

        Returns:
            Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]:
                Encoded spikes (BxTx15), each entry clipped to 0 or 1
                empty dictionary
                empty dictionary
        """
        input_data, _ = self._auto_batch(input_data)
        batch_size, _, num_channels = input_data.shape

        # Cast to int64 first, then to Python objects
        # NOTE(review): presumably done so the running sum is not limited to fixed-width integers -- confirm
        input_data = np.array(input_data, dtype=np.int64).astype(object)

        # Full-wave rectification
        output_data = np.abs(input_data)

        # Running sum of the rectified signal along the time axis
        output_data = np.cumsum(output_data, axis=1)

        # Number of threshold crossings accumulated so far, per channel
        for channel, th in enumerate(self.threshold):
            output_data[:, :, channel] = output_data[:, :, channel] // th

        # Prepend a zero time-step so the difference below keeps T steps
        # (for 3-D arrays np.hstack joins along axis 1, i.e. the time axis)
        leading_zeros = np.zeros((batch_size, 1, num_channels), dtype=object)
        num_spikes = np.hstack([leading_zeros, output_data])

        # New crossings in each time-step
        spikes = np.diff(num_spikes, axis=1)

        # If there is more than one spike per dt, truncate it to 1
        np.clip(spikes, 0, 1, out=spikes)

        return spikes, {}, {}