# Source code for nn.modules.native.filter_bank

"""
Modules implementing filter banks
"""

from typing import Union, Iterable
from itertools import product
from multiprocessing import Pool

import numpy as np
from scipy.signal import butter, sosfilt, sosfreqz

from rockpool.nn.modules.module import Module
from rockpool.parameters import SimulationParameter

from typing import Optional, Tuple
from rockpool.typehints import P_int, P_float, P_bool

# - Public API of this module: only the two concrete filter banks are exported;
#   `FilterBankBase` is not part of `__all__`
__all__ = ["ButterFilter", "ButterMelFilter"]


class FilterBankBase(Module):
    """
    Super-class to create a filter bank layer.

    This class does not build any band-pass filter itself. It stores the
    settings shared by the concrete filter banks, builds the optional envelope
    low-pass filter and implements :py:meth:`evolve`. Sub-classes are expected
    to fill ``_filters`` and ``_chunks`` after calling this constructor.
    """

    ## - Constructor
    def __init__(
        self,
        shape: Union[tuple, int] = (1, 64),
        fs: float = 44100.0,
        cutoff_fs: Optional[float] = 100.0,
        order: int = 2,
        mean_subtraction: bool = False,
        normalize: bool = False,
        num_workers: int = 1,
        use_lowpass: bool = True,
        *args,
        **kwargs,
    ):
        """
        :param tuple shape:             module shape ``(1, N)``, or an integer ``N`` giving the
                                        number of filters. Default: ``(1, 64)``
        :param float fs:                input signal sampling frequency in Hz
        :param float cutoff_fs:         cutoff frequency in Hz of the low-pass filter used to
                                        extract the envelope of the filter outputs. May only be
                                        ``None`` when ``use_lowpass`` is ``False``.
                                        Default: ``100 Hz``
        :param int order:               filter order. Default: ``2``
        :param bool mean_subtraction:   subtract the mean of the output from the output. The mean
                                        is a single value computed over all channels and samples.
                                        Default ``False``
        :param bool normalize:          divide the output by its maximum absolute value, taken
                                        over all channels and samples (i.e. filter responses in
                                        the range [-1, 1]). Default: ``False``
        :param int num_workers:         number of chunks the filters are split into by the
                                        sub-classes. The worker pool is currently disabled, so
                                        filtering runs serially in the calling process.
                                        Default: ``1``
        :param bool use_lowpass:        Iff ``True``, rectify and low-pass filter the output of
                                        the band-pass filters. Default: ``True``

        :raises ValueError:     if ``shape[0]`` is not ``1``, or if ``use_lowpass`` is ``True``
                                while ``cutoff_fs`` is ``None``
        :raises AssertionError: if ``fs``, ``cutoff_fs``, ``order`` or ``num_workers`` has an
                                invalid type or value
        """

        # - Correct the shape, if passed as an integer
        if not isinstance(shape, Iterable):
            shape = (1, shape)

        if shape[0] != 1:
            raise ValueError("The input dimension (`shape[0]`) must be `1`")

        # - Initialise the superclass
        super().__init__(
            shape=shape,
            *args,
            **kwargs,
        )

        # - The type is checked before the value, so that a non-numeric argument
        #   produces the message below rather than a `TypeError` from the comparison
        assert (
            isinstance(fs, (int, float)) and fs > 0.0
        ), f"`fs` must be a strictly positive float (given: {fs})"
        self.fs: P_float = SimulationParameter(fs, shape=())
        """ (float) Input sampling frequency in Hz """

        if cutoff_fs is not None:
            assert (
                isinstance(cutoff_fs, (int, float)) and 0.0 < cutoff_fs < self.fs / 2
            ), f"`cutoff_fs` must be greater than 0 and lesser than `fs`/2 (given: {cutoff_fs})"
        elif use_lowpass:
            # - Without a cutoff frequency the low-pass filter below cannot be designed
            raise ValueError(
                "`cutoff_fs` must be provided when `use_lowpass` is `True`"
            )
        self.cutoff_fs: P_float = SimulationParameter(cutoff_fs, shape=())
        """ (float) Post-filtering output low-pass cutoff frequency in Hz """

        assert (
            isinstance(order, int) and order > 0
        ), f"`order` must be a strictly positive integer (given: {order})"
        self.order: P_int = SimulationParameter(order, shape=())
        """ (int) Filter order """

        assert (
            isinstance(num_workers, int) and num_workers > 0
        ), f"`num_workers` must be a strictly postive integer (given: {num_workers})"
        self.num_workers: P_int = SimulationParameter(num_workers, shape=())
        """ (int) Number of workers to use in filtering """

        self.mean_subtraction: P_bool = SimulationParameter(mean_subtraction, shape=())
        """ (bool) Iff ``True``, subtract the mean filter output value from each output """

        self.normalize: P_bool = SimulationParameter(normalize, shape=())
        """ (bool) Iff ``True``, collectively normalise the filter outputs [-1, 1] """

        self.use_lowpass: P_bool = SimulationParameter(use_lowpass, shape=())
        """ (bool) Iff ``True``, perform a low-pass filter after filtering """

        # - Build low-pass filter (cutoff expressed as a fraction of the Nyquist frequency)
        self._filter_lowpass = (
            butter(
                self.order,
                self.cutoff_fs / (self.fs / 2),
                analog=False,
                btype="low",
                output="sos",
            )
            if use_lowpass
            else None
        )

        # - Initialise chunks and filters; these are populated by the sub-classes
        self._chunks: list = []
        self._filters: list = []

        # - The worker pool is disabled: filtering is mapped serially in `evolve`.
        #   The attribute is still defined so that `_terminate` is always safe to call.
        self._pool = None

    def _terminate(self):
        """Close the worker pool, if one has been created"""

        # - `getattr` guards against instances on which `_pool` was never set
        pool = getattr(self, "_pool", None)
        if pool is not None:
            pool.close()

    @staticmethod
    def _generate_chunks(l, n) -> list:
        """
        Split a sequence into consecutive chunks

        :param l:       sequence to split (must support ``len`` and slicing)
        :param int n:   maximum number of items per chunk

        :return list:   list of slices of ``l``; the last one may hold fewer than ``n`` items
        """

        # - Slicing clamps at the end of the sequence, so the final chunk is simply shorter
        return [l[start : start + n] for start in range(0, len(l), n)]

    @staticmethod
    def _process_filters(args) -> list:
        """
        Apply one chunk of filters to a signal

        :param tuple args:  ``(filters, (signal, filter_lowpass))``, where ``filters`` is an
                            iterable of second-order-section filters, ``signal`` is the signal to
                            filter and ``filter_lowpass`` is either ``None`` or a
                            second-order-section low-pass filter

        :return list:       one filtered signal per entry of ``filters``
        """

        filters, (signal, filter_lowpass) = args

        filters_output = []
        for sos in filters:
            filtered = sosfilt(sos, signal)
            if filter_lowpass is not None:
                # - Envelope extraction: full-wave rectification followed by the low-pass
                filtered = sosfilt(filter_lowpass, np.abs(filtered))
            filters_output.append(filtered)
        return filters_output

    def evolve(
        self,
        input: np.ndarray,
        *args,
        **kwargs,
    ) -> Tuple[np.ndarray, dict, dict]:
        """
        Evolve the state of the filterbanks, given an input

        Only ``input.T[0]`` is filtered. Extra positional and keyword arguments are ignored.

        :param np.ndarray input:   Raw input signal. NOTE(review): presumably shaped
                                   ``(T, 1)``, so that ``input.T[0]`` is the single input
                                   channel -- confirm against callers

        :return: ``(output, {}, {})``, where ``output`` holds one column per filter, followed by
                 two empty dictionaries
        """

        # - Build one job per chunk of filters; every job shares the same signal and low-pass
        jobs = list(product(self._chunks, [(input.T[0], self._filter_lowpass)]))

        # - Run the jobs serially (the worker pool is disabled)
        res = list(map(self._process_filters, jobs))

        # - Combine the results: (filters, samples) -> (samples, filters)
        filtOutput = np.concatenate(res).T

        # - Normalise the filter outputs by the global peak. An all-zero output is left
        #   untouched instead of being turned into NaNs by a division by zero
        if self.normalize:
            peak = np.max(np.abs(filtOutput))
            if peak > 0:
                filtOutput /= peak

        # - Mean-subtract the filter outputs. This is one global mean over all channels and
        #   samples. NOTE(review): earlier documentation said "per channel" -- confirm which
        #   behaviour is intended before changing it
        if self.mean_subtraction:
            filtOutput -= np.mean(filtOutput)

        # - Return outputs
        return filtOutput, {}, {}


class ButterMelFilter(FilterBankBase):
    """
    Define a Butterworth filter bank (mel spacing) filtering layer with continuous sampled output
    """

    ## - Constructor
    def __init__(
        self,
        shape: Union[tuple, int] = (1, 64),
        fs: float = 44100.0,
        cutoff_fs: float = 100.0,
        filter_width: float = 2.0,
        mean_subtraction: bool = False,
        normalize: bool = False,
        order: int = 2,
        num_workers: int = 1,
        plot: bool = False,
        use_lowpass: bool = True,
        *args,
        **kwargs,
    ):
        """
        Layer which applies the butterworth filter in MEL scale to a one-dimensional input signal.

        The lower band edges are spaced evenly on the mel scale, starting at ``cutoff_fs``.
        Each band spans ``[f, f * (1 + filter_width / N)]``, where ``N`` is the number of filters.

        :param tuple shape:             Module shape ``(1, N)``
        :param float fs:                input signal sampling frequency
        :param float cutoff_fs:         lowpass frequency to get only the enveloppe of filters
                                        output. Also the lowest frequency of the filter bank.
                                        Default: ``100 Hz``. Don't set it yourself unless you
                                        know what you're doing.
        :param float filter_width:      The width of the filters which is scaled with the number
                                        of filters. This determines the overlap between channels.
                                        Must be strictly positive. Default: ``2.``
        :param bool mean_subtraction:   subtract the mean of the output signals; the base class
                                        computes a single mean over all channels.
                                        Default ``False``
        :param bool normalize:          divide output signals by their maximum absolute value
                                        (i.e. filter responses in the range [-1, 1]).
                                        Default: ``False``
        :param int order:               filter order. Default: ``2``
        :param int num_workers:         number of chunks the filters are split into.
                                        Default: ``1``
        :param bool plot:               Plots the filter response. Default: ``False``
        :param bool use_lowpass:        Iff ``True``, return the filtered rectified smoothed
                                        signal. Default: ``True``. If ``False``, simply perform
                                        the band-pass filtering.
        :param args:                    forwarded to the superclass constructor
        :param kwargs:                  forwarded to the superclass constructor (earlier
                                        documentation lists a ``name`` argument here)

        :raises ValueError: if ``filter_width`` is not strictly positive, or if the highest band
                            edge reaches the Nyquist frequency
        """

        # - A non-positive width would make the band edges coincide or cross
        if not filter_width > 0:
            raise ValueError(
                f"`filter_width` must be strictly positive (given: {filter_width})"
            )

        # - Call super constructor
        super().__init__(
            shape=shape,
            fs=fs,
            cutoff_fs=cutoff_fs,
            order=order,
            mean_subtraction=mean_subtraction,
            normalize=normalize,
            num_workers=num_workers,
            use_lowpass=use_lowpass,
            *args,
            **kwargs,
        )

        def hz2mel(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
            """Takes value from hz and returns mel"""
            return 2595 * np.log10(1 + x / 700)

        def mel2hz(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
            """Takes value from mel and returns hz"""
            return 700 * (10 ** (x / 2595) - 1)

        # - Relative bandwidth of each filter
        filter_bandwidth = filter_width / self.shape[-1]

        # - Lower band edges, evenly spaced on the mel scale. The upper limit is chosen so
        #   that the upper edge of the last band stays below the Nyquist frequency
        low_freq = hz2mel(self.cutoff_fs)
        high_freq = hz2mel((self.fs / 2) / (1 + filter_bandwidth) - 1)
        freqs = mel2hz(np.linspace(low_freq, high_freq, self.shape[-1]))

        if np.max(freqs * (1 + filter_bandwidth) / (self.fs / 2)) >= 1.0:
            raise ValueError(
                "{} `{}`: `cutoff_fs` is too large (given: {})".format(
                    self.__class__.__name__, self.name, self.cutoff_fs
                )
            )

        # - Band edges as fractions of the Nyquist frequency; rows are (low, high)
        freq_bands = np.array([freqs, freqs * (1 + filter_bandwidth)]) / (self.fs / 2)

        # - Build the filters, one per column of `freq_bands`
        self._filters = [
            butter(self.order, band, analog=False, btype="band", output="sos")
            for band in freq_bands.T
        ]

        # - Generate chunks
        chunk_size = int(np.ceil(self.shape[-1] / num_workers))
        self._chunks = self._generate_chunks(self._filters, chunk_size)

        if plot:
            self._plot_filter_response()

    def _plot_filter_response(self):
        """Plot the gain (in dB) of every filter in ``_filters`` against frequency (in Hz)"""

        # - matplotlib is imported lazily, so it is only needed when plotting is requested
        import matplotlib.pyplot as plt
        from matplotlib import cm

        colors = cm.Blues(np.linspace(0.5, 1, len(self._filters)))

        plt.figure(figsize=(16, 10))
        for color, filt in zip(colors, self._filters):
            w, h = sosfreqz(filt, worN=1024)
            # - Clamp the magnitude so that the logarithm stays finite
            db = 20 * np.log10(np.maximum(np.abs(h), 1e-5))
            plt.plot((self.fs / 2) * w / np.pi, db, color=color)

        plt.xlabel("Frequency (Hz)")
        plt.ylabel("Gain (db)")
        plt.ylim([-10, 2])
        plt.xlim([0, self.fs / 2])
        plt.tight_layout()
        plt.show(block=True)


class ButterFilter(FilterBankBase):
    """
    Define a Butterworth filter bank filtering layer with continuous output
    """

    ## - Constructor
    def __init__(
        self,
        frequency: Union[float, np.ndarray],
        bandwidth: Union[float, np.ndarray],
        fs: float = 44100.0,
        order: int = 2,
        mean_subtraction: bool = False,
        normalize: bool = False,
        num_workers: int = 1,
        use_lowpass: bool = True,
        *args,
        **kwargs,
    ):
        """
        Layer which applies the butterworth filter to a one-dimensional input signal.

        Filter ``i`` passes the band
        ``[frequency[i] - bandwidth[i] / 2, frequency[i] + bandwidth[i] / 2]``. The lower edge
        of the band with the lowest centre frequency is used as the cutoff of the envelope
        low-pass filter.

        :param array frequency:             centre frequencies of the filters, in Hz. The size
                                            determines the number of filters
        :param (float, array) bandwidth:    bandwidth of the filters, in Hz. Either a single
                                            value shared by all filters, or one value per filter
        :param float fs:                    input signal sampling frequency in Hz.
                                            Default: 44100.
        :param int order:                   filter order. Default: ``2``
        :param bool mean_subtraction:       subtract the mean of the output signals; the base
                                            class computes a single mean over all channels.
                                            Default ``False``
        :param bool normalize:              divide output signals by their maximum absolute
                                            value. Default: ``False``
        :param int num_workers:             number of chunks the filters are split into.
                                            Default: ``1``
        :param bool use_lowpass:            Iff ``True``, rectify and low-pass filter the output
                                            of the band-pass filters. Default: ``True``
        :param args:                        forwarded to the superclass constructor
        :param kwargs:                      forwarded to the superclass constructor (earlier
                                            documentation lists a ``name`` argument here)

        :raises ValueError: if ``frequency`` is empty, if ``bandwidth`` is neither a single value
                            nor of the same size as ``frequency``, or if any band edge does not
                            lie strictly between 0 and ``fs / 2``
        """

        # - Check input arguments; both are flattened so that they combine element-wise
        frequency = np.array(frequency).reshape((np.size(frequency),))
        bandwidth = np.asarray(bandwidth).reshape((np.size(bandwidth),))

        if frequency.size == 0:
            raise ValueError("`frequency` must contain at least one value")

        # - A single bandwidth is shared by all the filters
        if bandwidth.size == 1:
            bandwidth = np.ones(frequency.shape) * bandwidth

        if np.size(frequency) != np.size(bandwidth):
            raise ValueError(
                f"`bandwidth` must be either a scalar or of the same size than `frequency`. \nGot {np.size(frequency)} and {np.size(bandwidth)}"
            )

        if np.any(frequency - bandwidth / 2 <= 0.0):
            raise ValueError("`frequency` must be greater than `bandwidth` / 2")

        # - The upper edge must stay strictly below Nyquist: a normalised critical
        #   frequency of exactly 1 is not valid for a digital filter
        if np.any(frequency + bandwidth / 2 >= fs / 2):
            raise ValueError("`frequency` must be lesser than (`fs` - `bandwidth`) / 2")

        # - The envelope low-pass cutoff is the lower edge of the lowest band
        idx = np.argmin(frequency)
        cutoff_fs = frequency[idx] - bandwidth[idx] / 2

        # - Call super constructor
        super().__init__(
            shape=np.size(frequency),
            fs=fs,
            cutoff_fs=cutoff_fs,
            order=order,
            mean_subtraction=mean_subtraction,
            normalize=normalize,
            num_workers=num_workers,
            use_lowpass=use_lowpass,
            *args,
            **kwargs,
        )

        # - Add parameters
        self.frequency: P_float = SimulationParameter(frequency)
        """ (np.ndarray) Vector of centre frequencies for the filters, in Hz """

        self.bandwidth: P_float = SimulationParameter(bandwidth)
        """ (np.ndarray) Vector of bandwidths of each filter, in Hz"""

        # - Band edges as fractions of the Nyquist frequency; rows are (low, high)
        freq_bands = np.array(
            [
                self.frequency - self.bandwidth / 2,
                self.frequency + self.bandwidth / 2,
            ]
        ) / (self.fs / 2)

        # - Build the filters, one per column of `freq_bands`
        self._filters = [
            butter(self.order, band, analog=False, btype="band", output="sos")
            for band in freq_bands.T
        ]

        # - Generate chunks
        chunk_size = int(np.ceil(self.shape[-1] / num_workers))
        self._chunks = self._generate_chunks(self._filters, chunk_size)