Source code for devices.xylo.syns63300.imuif_sim

"""
Simulation of an analog IMU signal filtering front-end

Defines :py:class:`.IMUIFSim` module.

See Also:
    For example usage of the :py:class:`.IMUIFSim` Module, see :ref:`/devices/xylo-imu/imu-if.ipynb`
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn

import numpy as np
from rockpool.devices.xylo.syns63300.imuif.params import DEFAULT_FILTER_BANDS
from rockpool.devices.xylo.syns63300.imuif.utils import type_check
from rockpool.devices.xylo.syns63300.imuif import (
    IAFSpikeEncoder,
    ScaleSpikeEncoder,
    RotationRemoval,
    FilterBank,
    BandPassFilter,
)
from rockpool.nn.combinators import Sequential
from rockpool.nn.modules.module import Module
from rockpool.utilities.backend_management import backend_available

from rockpool.parameters import SimulationParameter

if backend_available("samna"):
    from samna.xyloImu.configuration import InputInterfaceConfig
else:
    InputInterfaceConfig = Any

__all__ = ["IMUIFSim"]


[docs]class IMUIFSim(Module): """ A :py:class:`.Module` that simulates the IMU signal preprocessing on Xylo IMU This module simulates the Xylo IMU front-end stage. This is a signal-to-event core that consists of rotation removal units, low-pass filters, and a spike generator. The module takes in a 3D IMU signal and outputs a spike train. See Also: For example usage of the :py:class:`.IMUIFSim` Module, see :ref:`/devices/xylo-imu/imu-if.ipynb` """
[docs] def __init__( self, shape: Optional[Union[Tuple, int]] = (3, 15), select_iaf_output: bool = False, bypass_jsvd: bool = False, filter_list: Optional[List[BandPassFilter]] = None, scale_values: Union[List[int], int] = 5, iaf_threshold_values: Union[List[int], int] = 1024, num_avg_bitshift: int = 4, SAH_period: int = 10, sampling_freq: float = 200.0, ) -> None: """ Object constructor Args: shape (Optional[Union[Tuple, int]], optional): the shape of the input-output transformation. Defaults to (3, 15). select_iaf_output (bool, optional): If true, the output of the module is encoded using IAF spike encoding. If false, the output of the module is encoded using scale spike encoding. Defaults to False. bypass_jsvd (bool, optional): If true, the module does not perform the rotation removal stage. Defaults to False. filter_list (Optional[List[BandPassFilter]], optional): the list of filters of the filterbank. Note that first 5 filters will apply to the first input channel, the second 5 filters apply to the second input channel, and the last 5 filters will apply to the 3rd input channel. Defaults to None, when it's None, the default values apply. scale_values (Union[List[int], int], optional): number of right-bit-shifts needed for down-scaling the input signal (per channel). Defaults to [5,5,5,5,5,5,5,5,5,5,5,5,5,5,5]. iaf_threshold_values (Union[List[int], int], optional): the thresholds of the IAF neurons (quantized). Default to [1024,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024,1024]. num_avg_bitshift (int): number of bit shifts used in the low-pass filter implementation. Default to 4. The effective window length of the low-pass filter will be `2**num_avg_bitshift` SAH_period (int): Sampling period that the signal is sampled and held, in number of timesteps. Defaults to 10. sampling_freq (float): Sampling frequency of the IMU interface. Default: ``200.`` """ super().__init__(shape=shape, spiking_input=False, spiking_output=True) if filter_list is None: filter_list = [ BandPassFilter.from_specification(*band) for band in DEFAULT_FILTER_BANDS ] if len(filter_list) != self.size_out: raise ValueError( f"the number of filters {len(filter_list)} does not match the number of output channels {self.size_out}" ) filter_bank = FilterBank(shape, *filter_list) if select_iaf_output: spike_encoder = IAFSpikeEncoder( shape=(self.size_out, self.size_out), threshold=iaf_threshold_values ) else: spike_encoder = ScaleSpikeEncoder( shape=(self.size_out, self.size_out), num_scale_bits=scale_values ) if bypass_jsvd: mod_IMUIF = Sequential(filter_bank, spike_encoder) else: rotation_removal = RotationRemoval( shape=(self.size_in, self.size_in), num_avg_bitshift=num_avg_bitshift, sampling_period=SAH_period, ) mod_IMUIF = Sequential(rotation_removal, filter_bank, spike_encoder) self.model = mod_IMUIF """The sequential module that simulates the IMU front-end""" self.dt = SimulationParameter(1 / sampling_freq) """ (float) Time-step of the encoding simulation in seconds """
[docs] @type_check def evolve( self, input_data: np.ndarray, record: bool = False ) -> Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]: """Processes the input IMU signal sample-by-sample and generate spikes Args: input_data (np.ndarray): batched input data recorded from IMU sensor. It should be in integer format. (BxTx3) record (bool, optional): If True, the intermediate results are recorded and returned. Defaults to False. Returns: Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]: output empty dictionary empty dictionary """ # - Auto batch and shape check input_data, _ = self._auto_batch(input_data) # - Ensure data is the correct dtype input_data = np.array(input_data, dtype=np.int64).astype(object) # - Encode data and return return self.model(input_data, record=record)
[docs] @classmethod def from_config(cls, config: InputInterfaceConfig) -> IMUIFSim: """Obtain an instance of IMUIFSim from a samna configuration object Args: config (InputInterfaceConfig): a samna object that encapsulates the hardware configuration such as register values Raises: TypeError: if the input is not the samna configuration object that is expected Returns: IMUIFSim: an instance of IMUIFSim """ if not isinstance(config, InputInterfaceConfig): raise TypeError( f"config must be an instance of `samna.xyloImu.configuration.InputInterfaceConfig`. We got {type(config)}" ) if config.enable != True: warn("IMUIF is not enabled in configuration!") # We could not use `config.delay_threshold` here because it does not affect the simulation filter_list = [] for i, (B_b, B_wf, B_af, a1, a2) in enumerate( zip( config.bpf_bb_values, config.bpf_bwf_values, config.bpf_baf_values, config.bpf_a1_values, config.bpf_a2_values, ) ): filter_list.append( BandPassFilter( B_b=B_b, B_wf=B_wf, B_af=B_af, a1=a1, a2=a2, ) ) return cls( shape=(3, 15), select_iaf_output=config.select_iaf_output, bypass_jsvd=config.bypass_jsvd, filter_list=filter_list, scale_values=config.scale_values, iaf_threshold_values=config.iaf_threshold_values, num_avg_bitshift=config.estimator_k_setting, SAH_period=config.update_matrix_threshold, )
[docs] @classmethod def from_specification(cls, *args, **kwargs) -> IMUIFSim: raise NotImplementedError( "Here we do not have any high-level specification that's different than __init__ parameters." )
[docs] def export_config(self) -> InputInterfaceConfig: """Export the current configuration of the IMUIF module Returns: InputInterfaceConfig: a samna object that encapsulates the hardware configuration such as register values """ default_config = InputInterfaceConfig() bypass_jsvd = True scale_values = default_config.scale_values iaf_threshold_values = default_config.iaf_threshold_values estimator_k_setting = default_config.estimator_k_setting update_matrix_threshold = default_config.update_matrix_threshold for module in self.model: if isinstance(module, FilterBank): bpf_bb_values = module.B_b_list bpf_bwf_values = module.B_wf_list bpf_baf_values = module.B_af_list bpf_a1_values = [a1 for a1 in module.a1_list] bpf_a2_values = [a2 for a2 in module.a2_list] elif isinstance(module, IAFSpikeEncoder): select_iaf_output = True iaf_threshold_values = module.threshold elif isinstance(module, ScaleSpikeEncoder): select_iaf_output = False scale_values = module.num_scale_bits elif isinstance(module, RotationRemoval): bypass_jsvd = False estimator_k_setting = module.num_avg_bitshift update_matrix_threshold = module.sampling_period else: raise TypeError( f"module {module} is not recognized as a module of IMUIF" ) # We could not use `config.delay_threshold` here because it does not affect the simulation config = InputInterfaceConfig( enable=True, bpf_a1_values=bpf_a1_values, bpf_a2_values=bpf_a2_values, bpf_baf_values=bpf_baf_values, bpf_bb_values=bpf_bb_values, bpf_bwf_values=bpf_bwf_values, bypass_jsvd=bypass_jsvd, estimator_k_setting=estimator_k_setting, iaf_threshold_values=iaf_threshold_values, scale_values=scale_values, select_iaf_output=select_iaf_output, update_matrix_threshold=update_matrix_threshold, ) return config