Source code for devices.xylo.syns63300.imuif.filterbank

"""
Hardware butterworth filter implementation for the Xylo IMU.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from rockpool.devices.xylo.syns63300.imuif.utils import (
    type_check,
    unsigned_bit_range_check,
    signed_bit_range_check,
)

from rockpool.devices.xylo.syns63300.imuif.params import (
    NUM_BITS,
    B_WORST_CASE,
    FILTER_ORDER,
    CLOCK_RATE,
    DEFAULT_FILTER_BANDS,
)

from rockpool.nn.modules.module import Module
from scipy.signal import butter
from numpy.linalg import norm

EPS = 0.001
"""Epsilon for floating point comparison"""

__all__ = ["FilterBank", "BandPassFilter"]


[docs]@dataclass class BandPassFilter: """ Class that instantiates a single quantised band-pass filter, as implemented on Xylo IMU hardware This class will design the best filter that meets a pass-band specification, using the class method :py:meth:`~.BandPassFilter.from_specification`. Alternatively, you can specify the `a1` and `a2` taps directly as a signed integer representation, as well as specifying the number of bits needed for specifying the filter. """ B_b: int = 6 """Bits needed for scaling b0""" B_wf: int = 8 """Bits needed for fractional part of the filter output""" B_af: int = 9 """Bits needed for encoding the fractional parts of taps""" a1: int = -36565 """Integer representation of a1 tap""" a2: int = 31754 """Integer representation of a2 tap""" scale_out: Optional[float] = None """A virtual scaling factor that is applied to the output of the filter, NOT IMPLEMENTED ON HARDWARE! That shows the surplus scaling needed in the output (accepted range is [0.5, 1.0])""" def __post_init__(self) -> None: """ Check the validity of the parameters. """ self.B_w = NUM_BITS + B_WORST_CASE + self.B_wf """Total number of bits devoted to storing the values computed by the AR-filter.""" self.B_out = NUM_BITS + B_WORST_CASE """Total number of bits needed for storing the values computed by the WHOLE filter.""" if self.scale_out is not None: if self.scale_out < 0.5 or self.scale_out > 1.0: raise ValueError( f"scale_out should be in the range [0.5, 1.0]. Got {self.scale_out}" ) unsigned_bit_range_check(self.B_b, n_bits=4) unsigned_bit_range_check(self.B_wf, n_bits=4) unsigned_bit_range_check(self.B_af, n_bits=4) signed_bit_range_check(self.a1, n_bits=17) signed_bit_range_check(self.a2, n_bits=17)
[docs] @type_check def compute_AR(self, signal: np.ndarray) -> np.ndarray: """ Compute the AR part of the filter in the block-diagram with the given parameters. Args: signal (np.ndarray): the quantized input signal in python-object integer format. Raises: OverflowError: if any overflow happens during the filter computation. Returns: np.ndarray: the output signal of the AR filter. """ # check that the input is within the valid range of block-diagram unsigned_bit_range_check(np.max(np.abs(signal)), n_bits=NUM_BITS - 1) output = [] # w[n], w[n-1], w[n-2] w = [0, 0, 0] for val in signal: # Computation after the clock edge w_new = (val << self.B_wf) + ( (-self.a2 * w[2] - self.a1 * w[1]) >> self.B_af ) w_new = w_new >> self.B_b w[0] = w_new # register shift at the rising edge of the clock w[1], w[2] = w[0], w[1] output.append(w[0]) # check the overflow: here we have the integer version unsigned_bit_range_check(np.max(np.abs(output)), n_bits=self.B_w - 1) # convert into numpy return np.array(output, dtype=np.int64).astype(object)
[docs] @type_check def compute_MA(self, signal: np.ndarray) -> np.ndarray: """ Compute the MA part of the filter in the block-diagram representation. Args: signal (np.ndarray): input signal (in this case output of AR part) of datatype `pyton.object`. Raises: OverflowError: if any overflow happens during the filter computation. Returns: np.ndarray: quantized filtered output signal. """ # check dimension if signal.ndim > 1: raise ValueError("input signal should be 1-dim.") sig_out = np.copy(signal) sig_out[2:] = sig_out[2:] - signal[:-2] # apply the last B_wf bitshift to get rid of additional scaling needed to avoid dead-zone in the AR part sig_out = sig_out >> self.B_wf # check the validity of the computed output unsigned_bit_range_check(np.max(np.abs(sig_out)), n_bits=self.B_out - 1) return np.array(sig_out, dtype=np.int64).astype(object)
@type_check def __call__(self, signal: np.ndarray): """ Combine the filtering done in the AR and MA part of the block-diagram representation. Args: sig_in (np.ndarray): quantized input signal of python.object integer type. Raises: OverflowError: if any overflow happens during the filter computation. Returns: np.nadarray: quantized filtered output signal. """ signal = self.compute_AR(signal) signal = self.compute_MA(signal) return signal
[docs] @classmethod def from_specification( cls, low_cut_off: float, high_cut_off: float, fs: float = CLOCK_RATE ) -> "BandPassFilter": """ Create a filter with the given upper and lower cut-off frequencies. Note that the hardware filter WOULD NOT BE EXACTLY THE SAME as the one specified here. This script finds the closest one possible Args: low_cut_off (float): The low cut-off frequency of the band-pass filter. high_cut_off (float): The high cut-off frequency of the band-pass filter. fs (float): The clock rate of the chip running the filters (in Hz). Defaults to 200. Returns: BandPassFilter: the filter with the given cut-off frequencies. """ if low_cut_off >= high_cut_off: raise ValueError( f"Low cut-off frequency should be smaller than the high cut-off frequency." ) elif low_cut_off <= 0: raise ValueError(f"Low cut-off frequency should be positive.") # IIR filter coefficients b, a = butter( N=FILTER_ORDER, Wn=(low_cut_off, high_cut_off), btype="bandpass", analog=False, output="ba", fs=fs, ) # --- Sanity Check --- # if np.max(np.abs(b)) >= 1: raise ValueError( "all the coefficients of MA part `b` should be less than 1!" ) if a[0] != 1: raise ValueError( "AR coefficients: `a` should be in standard format with a[0]=1!" ) if np.max(np.abs(a)) >= 2: raise ValueError( "AR coefficients seem to be invalid: make sure that all values a[.] are in the range (-2,2)!" ) b_norm = b / abs(b[0]) b_norm_expected = np.array([1, 0, -1]) if ( norm(b_norm - b_norm_expected) > EPS and norm(b_norm + b_norm_expected) > EPS ): raise ValueError( "in Butterworth filters used in Xylo-IMU the normalize MA part should be of the form [1, 0, -1]!" ) # compute the closest power of 2 larger that than b[0] B_b = int(np.log2(1 / abs(b[0]))) B_wf = B_WORST_CASE - 1 B_af = NUM_BITS - B_b - 1 # quantize the a-taps of the filter a_taps = (2 ** (B_af + B_b) * a).astype(np.int64) a1 = a_taps[1] a2 = a_taps[2] scale_out = b[0] * (2**B_b) return cls(B_b=B_b, B_wf=B_wf, B_af=B_af, a1=a1, a2=a2, scale_out=scale_out)
[docs]class FilterBank(Module): """ This class builds the block-diagram version of the filters, which is exactly as it is done in HW. NOTE: Here we have considered a collection of `candidate` band-pass filters that have the potential to be chosen and implemented by the algorithm team. Here we make sure that all those filters work properly. """
[docs] def __init__( self, shape: Optional[Union[Tuple, int]] = (3, 15), *args: Union[List[BandPassFilter], List[Tuple[float]]], ) -> None: """Build a FilterBank simulation by specifying pass bands for individual filters .. Examples:: # - Generate a default filterbank >>> FilterBank() # - Specify three filters, single input channel >>> Filterbank(3, (0.1, 5), (5, 10), (10, 20)) # - Combine existing band-pass filters into a filter bank >>> bpf1 = BandPassFilter(...) >>> bpf2 = BandPassFilter(...) >>> bpf3 = BandPassFilter(...) >>> FilterBank(3, bpf1, bpf2, bpf3) Args: shape (Optional[Union[Tuple, int]], optional): The number of input and output channels. Defaults to `(3, 15)`. *args: A list of `BandPassFilter`s to register to the filterbank. Defaults to `None`; use a default filterbank configuration. """ # - If an integer number of filters is provided, use one input channel only if isinstance(shape, int): shape = (1, shape) # - Check that input and output shapes match if shape[1] % shape[0] != 0: raise ValueError( f"The number of output channels should be a multiple of the number of input channels." ) super().__init__(shape=shape, spiking_input=False, spiking_output=False) if not args: args = [ BandPassFilter.from_specification(*band) for band in DEFAULT_FILTER_BANDS ] self._filters = [] for arg in args: if isinstance(arg, BandPassFilter): self._filters.append(arg) else: raise TypeError( f"Expected `BandPassFilter` or specifications, got {type(arg)} instead." ) if shape[1] != len(self._filters): raise ValueError( f"The output size ({shape[1]}) must match the number of filters {len(self._filters)} to compute filtered output!" ) self._channel_mapping = np.sort( [i % self.size_in for i in range(self.size_out)] ) """ Mapping from IMU channels to filter channels. Equal number of filters per channel by default, with all filters for each input channel in order of input channel. e.g. [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] or similar"""
[docs] @classmethod def from_specification( cls, shape: Tuple[int] = (3, 15), *args: List[Tuple[float]] ) -> "FilterBank": """ Create a filter bank with the given frequency bands. Args: *args (List[Tuple[float]]): A list of tuples containing the lower and upper cut-off frequencies of the filters. Returns: FilterBank: the filter bank with the given frequency bands. """ if not args: args = DEFAULT_FILTER_BANDS for arg in args: if not isinstance(arg, tuple): raise TypeError(f"Expected tuple, got {type(arg)} instead.") elif not (len(arg) == 2 or len(arg) == 3): raise ValueError( f"Expected tuple of length 2 or 3, got {len(arg)} instead." ) filter_list = [BandPassFilter.from_specification(*band) for band in args] return cls(shape, *filter_list)
[docs] @type_check def evolve( self, input_data: np.ndarray, record: bool = False ) -> Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]: """ Compute the output of all filters for an input signal. Combine the filtering done in the `AR` and `MA` part of the block-diagram representation. Args: input_data (np.ndarray): the quantized input signal of datatype python.object integer. (BxTxC) Returns: Tuple[np.ndarray, Dict[str, Any], Dict[str, Any]]: np.ndarray: the filtered output signal of all filters (BxTxC) dict: empty record dictionary. dict: empty state dictionary. """ # -- Batch processing input_data, _ = self._auto_batch(input_data) input_data = np.array(input_data, dtype=np.int64).astype(object) # -- Revert and repeat the input signal in the beginning to avoid boundary effects __B, __T, __C = input_data.shape __input_data_rev = np.flip(input_data, axis=1) input_data = np.concatenate((__input_data_rev, input_data), axis=1) # -- Filter data_out = [] # iterate over batches for signal in input_data: channel_out = [] for __filter, __ch in zip(self._filters, self._channel_mapping): out = __filter(signal.T[__ch]) channel_out.append(out) data_out.append(channel_out) # convert into numpy data_out = np.asarray(data_out, dtype=object) data_out = data_out.transpose(0, 2, 1) # BxTxC # -- Cut the margin data_out = data_out[:, __T:, :] return data_out, {}, {}
@property def B_b_list(self) -> List[int]: """List of B_b values of all filters""" return [f.B_b for f in self._filters] @property def B_wf_list(self) -> List[int]: """List of B_wf values of all filters""" return [f.B_wf for f in self._filters] @property def B_af_list(self) -> List[int]: """List of B_af values of all filters""" return [f.B_af for f in self._filters] @property def a1_list(self) -> List[int]: """List of a1 values of all filters""" return [f.a1 for f in self._filters] @property def a2_list(self) -> List[int]: """List of a2 values of all filters""" return [f.a2 for f in self._filters]