# -----------------------------------------------------------
# simulates divisive normalization for a system with input spikes
# coming in possibly several channels
#
# (C) Saeid Haghighatshoar
# email: saeid.haghighatshoar@synsense.ai
#
# finalized and tested on 13.09.2021
# -----------------------------------------------------------
from rockpool.timeseries import TSEvent
from rockpool.nn.modules.module import Module
from rockpool.typehints import P_int, P_float, P_ndarray
from rockpool.parameters import State, SimulationParameter
import numpy as np
import warnings
from typing import Tuple, Union
from enum import IntEnum
# - Locate the base directory for SYNS61201, for use in generating the LFSR code
import inspect
import pathlib as pl
import rockpool.typehints as th
# - Base directory of the SYNS61201 device package: <rockpool>/devices/xylo/syns61201.
#   Located relative to the installed `rockpool.typehints` module, so the path is
#   valid regardless of where the package is installed. Used below as the default
#   location of the pre-generated LFSR sequence file.
basedir = (
    pl.Path(*pl.Path(inspect.getfile(th)).parts[:-1]) / "devices" / "xylo" / "syns61201"
)
# - Define the public API of this module
__all__ = [
    "LowPassMode",
    "DivisiveNormalisation",
    "DivisiveNormalisationNoLFSR",
    "build_lfsr",
]
# - Enumeration selecting the low-pass filter implementation:
#   UNDERFLOW_PROTECT (optimised for low input event frequencies) or
#   OVERFLOW_PROTECT (optimised for high input frequencies).
LowPassMode = IntEnum(
    "LowPassMode",
    "UNDERFLOW_PROTECT OVERFLOW_PROTECT",
    module=__name__,
    qualname="rockpool.devices.xylo.xylo_divisive_normalisation.LowPassMode",
)
def build_lfsr(filename) -> np.ndarray:
    """
    Load a pre-generated LFSR state sequence from a text file.

    Each line of the file holds one LFSR state written as a binary string.
    The final entry is discarded: the file includes the all-zero state, but a
    real LFSR can never occupy it (the next state would also be all-zero, so
    the register would be stuck forever).

    Args:
        filename: Path of the text file containing binary LFSR states.

    Returns:
        np.ndarray: Integer LFSR code sequence, one entry per retained line.
    """
    # read the raw binary strings, one state per line
    with open(filename) as lfsr_file:
        state_lines = lfsr_file.readlines()
    # parse every line as a base-2 integer
    states = [int(state, 2) for state in state_lines]
    # drop the trailing (invalid all-zero) entry
    return np.array(states[:-1], dtype="int")
class DivisiveNormalisation(Module):
    """
    A digital divisive normalization block

    Input events are counted per frame (counter ``E``), the frame counts are
    low-pass filtered by a bit-shift filter (state ``M``), and ``M`` is then
    used as an adaptive threshold for an integrate-and-fire (IAF) stage,
    yielding an output spike train with normalized rates.
    """

    def __init__(
        self,
        shape: Union[int, Tuple[int]] = 1,
        *args,
        bits_counter: int = 10,
        E_frame_counter: np.ndarray = None,
        IAF_counter: np.ndarray = None,
        bits_lowpass: int = 16,
        bits_shift_lowpass: int = 5,
        M_lowpass_state: np.ndarray = None,
        dt: P_float = 0.1e-3,
        frame_dt: float = 50e-3,
        bits_lfsr: int = 10,
        code_lfsr: np.ndarray = None,
        p_local: int = 12,
        low_pass_mode: LowPassMode = LowPassMode.UNDERFLOW_PROTECT,
        **kwargs,
    ):
        """
        Args:
            shape (tuple): The number of channels ``N`` for this Module
            bits_counter (int): Bit-width of frame counter E. Default: ``10``
            E_frame_counter (np.ndarray): Initialisation for frame counter E ``(N,)``. Default: ``None``, initialise to zero.
            IAF_counter (np.ndarray): Initialisation for IAF state ``(N,)``. Default: ``None``, initialise to zero.
            bits_lowpass (int): Number of bits used by the low-pass filter ``M``. Default: ``16``
            bits_shift_lowpass (int): Dash bitshift averaging parameter to use in low-pass filter. Default: ``5``
            M_lowpass_state (np.ndarray): Initialisation for low-pass state ``(N,)``. Default: ``None``, initialise to zero.
            dt (float): Global clock step in seconds. Default: 0.1ms
            frame_dt (float): Frame clock step in seconds. Default: 50ms
            bits_lfsr (int): Bit-width of LFSR. Default: ``10``
            code_lfsr (np.ndarray): LFSR sequence to use. Default: Load a pre-defined LFSR sequence.
            p_local (int): Factor to multiply spike rate E. Factor ``p`` is given by ``p_local / 2``. Rounded to an even integer if needed.
            low_pass_mode (LowPassMode): Specify how to compute the low-pass filter M. Possible values are ``LowPassMode.UNDERFLOW_PROTECT`` and ``LowPassMode.OVERFLOW_PROTECT``. Default: ``LowPassMode.UNDERFLOW_PROTECT``, optimised for low input event frequencies. ``LowPassMode.OVERFLOW_PROTECT`` is optimal for high input frequencies.

        Raises:
            ValueError: If ``code_lfsr`` does not contain ``2**bits_lfsr - 1`` entries, or if ``low_pass_mode`` is not a valid ``LowPassMode`` member.
        """
        # intialize the Module superclass
        super().__init__(
            shape, spiking_input=True, spiking_output=True, *args, **kwargs
        )
        # initialize the value of the counter or set to zero if not specified
        self.E_frame_counter: P_ndarray = State(
            E_frame_counter,
            shape=(self.size_in),
            init_func=lambda s: np.zeros(s, "int"),
        )
        """ np.ndarray: Spike rate per frame ``(N,)`` """
        self.bits_counter: P_int = SimulationParameter(bits_counter)
        """ int: Number of bits of spike rate counter E """
        # initialize the state of IAF_counter
        self.IAF_counter: P_ndarray = State(
            IAF_counter, shape=(self.size_in), init_func=lambda s: np.zeros(s, "int")
        )
        """ np.ndarray: IAF state ``(N,)`` """
        # set the parameters of the low-pass filter (implemented by bit-shifts)
        self.bits_lowpass: P_int = SimulationParameter(bits_lowpass)
        """ int: Number of bits in low-pass filter state M """
        self.bits_shift_lowpass: P_int = SimulationParameter(bits_shift_lowpass)
        """ int: Dash decay parameter for low-pass filter M """
        # at the moment we implement averaging filter such that its value is
        # always a positive integer. We can do much better by using some
        # fixed-point implementation by using some extra bits as decimals.
        # This improvment is left for future.
        self.M_lowpass_state: P_ndarray = State(
            M_lowpass_state,
            shape=(self.size_in),
            init_func=lambda s: np.zeros(s, "int"),
        )
        """ np.ndarray: Low-pass filter state M ``(N,)`` """
        ##
        # set the global clock frequency
        self.dt: P_float = SimulationParameter(dt)
        """ float: Global clock step in seconds """
        # set the period of the frame
        self.frame_dt: P_float = SimulationParameter(frame_dt)
        """ float: Frame clock step in seconds """
        # set the number of bits and also code for LFSR
        self.bits_lfsr: P_int = SimulationParameter(bits_lfsr)
        """ int: Number of LFSR bits """
        if code_lfsr is None:
            # - Load default LFSR sequence shipped with the SYNS61201 device files
            code_lfsr = build_lfsr(basedir / "lfsr_data.txt")
        # a full-period LFSR with b bits visits every state except all-zero,
        # i.e. exactly 2**b - 1 states
        if code_lfsr.size != 2**self.bits_lfsr - 1:
            raise ValueError(
                f"Length of LFSR is not compatible with its number of bits. Expected {2 ** self.bits_lfsr - 1} entries, found {code_lfsr.size}."
            )
        self.code_lfsr: P_ndarray = SimulationParameter(
            np.copy(code_lfsr).reshape((-1,))
        )
        """ np.ndarray: LFSR sequence to use for pRNG """
        self.lfsr_index: P_int = State(0)
        """ int: Current index into LFSR sequence """
        # set the ratio between the rate of local and global clocks
        # note that because of return-to-zero pulses, the spike rate increases
        # by only p_local/2 -> set p_local to be an even number
        self.p_local: P_int = int((1 + p_local) / 2) * 2
        """ int: Factor to scale up internal spike generation."""
        if self.p_local != p_local:
            warnings.warn(f"`p_local` = {p_local} was rounded to an even integer!")
        if low_pass_mode not in LowPassMode:
            raise ValueError(
                f"Unexpected value for `low_pass_mode`: {low_pass_mode}. Expected {[str(e) for e in LowPassMode]}"
            )
        self.low_pass_mode = SimulationParameter(low_pass_mode)
        """ LowPassMode: Specifies which mode to use for low-pass filtering """ 
[docs]    def _low_pass_underflow_protect(
        self, E_t: np.ndarray, M_t: np.ndarray
    ) -> np.ndarray:
        """
        Implement one low-pass filter time-step, with underflow protection
        Args:
            E_t (np.ndarray): Input rates for this frame ``(N,)`
            M_t (np.ndarray): Current low-pass state from previous frame ``(N,)``
        Returns:
            np.ndarray: Low-pass state for the next frame ``(N,)``
        """
        return (E_t + (M_t << self.bits_shift_lowpass) - M_t) >> self.bits_shift_lowpass 
[docs]    def _low_pass_overflow_protect(
        self, E_t: np.ndarray, M_t: np.ndarray
    ) -> np.ndarray:
        """
        Implement one low-pass filter time-step, with overflow protection
        Args:
            E_t (np.ndarray): Input rates for this frame ``(N,)`
            M_t (np.ndarray): Current low-pass state from previous frame ``(N,)``
        Returns:
            np.ndarray: Low-pass state for the next frame ``(N,)``
        """
        return (E_t >> self.bits_shift_lowpass) + M_t - (M_t >> self.bits_shift_lowpass) 
[docs]    def evolve(
        self, input_spike: np.ndarray, record: bool = False
    ) -> Tuple[np.ndarray, dict]:
        """
        This class simulates divisive normalization for an input spike signal
        with one or several channels.
        The output of the simulation is another spike signal with normalized rates.
        """
        # check the dimensionality first
        if input_spike.shape[1] != self.size_in:
            raise ValueError(
                f"Input size {input_spike.shape} did not match number of channels {self.size_in}"
            )
        # - Convert input spikes with duration 'dt' to frames of duration 'frame_dt'-> counter output
        # - output is counter output E(t) of duration 'frame_dt'
        # - input : (N, self.size_in) -> N is units of 'dt'
        # - E: (n_frame, self.size_in) -> units of 'frame_dt'
        ts_input = TSEvent.from_raster(
            input_spike, dt=self.dt, num_channels=self.size_in
        )
        E = ts_input.raster(dt=self.frame_dt, add_events=True)
        num_frames = E.shape[0]
        # add the effect of initial values in E_frame_counter
        E[0, :] += self.E_frame_counter.astype(int)
        # clip the counter to take the limited number of bits into account
        E = np.clip(E, 0, 2**self.bits_counter)
        # Reset the value of E_frame_counter
        self.E_frame_counter = np.zeros(self.size_in, "int")
        # Perform low-pass filter on E(t)-> M(t)
        # M(t) = s * E(t) + (1-s) M(t-1)
        #  with s=1/2**bits_shift_lowpass
        # M: (n_frame, self.size_in) -> units of 'frame_dt'
        M = np.zeros((num_frames + 1, self.size_in), dtype="int")
        # - Select the low-pass implementation
        if self.low_pass_mode is LowPassMode.UNDERFLOW_PROTECT:
            low_pass = self._low_pass_underflow_protect
        elif self.low_pass_mode is LowPassMode.OVERFLOW_PROTECT:
            low_pass = self._low_pass_overflow_protect
        else:
            raise ValueError(
                f"Unexpected value for `.low_pass_mode`: {self.low_pass_mode}. Expected {[str(e) for e in LowPassMode]}"
            )
        # load the initialization of the filter
        M[0, :] = self.M_lowpass_state
        # - Perform the low-pass filtering
        for t in range(num_frames):
            M[t + 1, :] = low_pass(E[t, :], M[t, :])
        # - Trim the first entry (initial state)
        M = M[1:, :]
        # take the limited number of counter bits into account
        # we should make sure that the controller does not allow count-back to zero
        # i.e., it keeps the value of the counter at its maximum
        M = np.clip(M, 0, 2**self.bits_lowpass - 1)
        self.M_lowpass_state = M[-1, :]
        # use the value of E(t) at each frame t to produce a pseudo-random
        # Poisson spike train by comparing E(t) with the LFSR output
        # as the value of LFSR varies with global clock rate f_s, we have 'frame_dt*f_s'
        # samples in each frame
        # the timing of the output is in units of 'dt'
        # Number of global clock cycles within a frame period
        cycles_per_frame = int(np.ceil(self.frame_dt / self.dt))
        # whole number of LFSR cycles that are needed for comparison with E(t)
        # over all frames t=0, 1, ...
        lfsr_ticks_needed = cycles_per_frame * num_frames
        # number of LFSR periods needed
        num_lfsr_period = int(np.ceil(lfsr_ticks_needed / self.code_lfsr.size)) + 1
        # the slice of LFSR code used over this frame
        code_lfsr_frame = np.tile(self.code_lfsr, num_lfsr_period)[
            self.lfsr_index : self.lfsr_index + lfsr_ticks_needed
        ].reshape(
            num_frames, -1
        )  # (frames, cycles_per_frame)
        self.lfsr_index = (self.lfsr_index + lfsr_ticks_needed) % len(self.code_lfsr)
        # initialize the IAF_state for further inspection
        if record:
            IAF_state_saved = [[] for _ in range(self.size_in)]
        # initialise output spike variables
        output_spike_times = [[] for _ in range(self.size_in)]
        output_spike_channels = [[] for _ in range(self.size_in)]
        # perform operation per channel
        for ch in range(self.size_in):
            # for each channel
            E_ch = np.copy(E[:, ch])
            M_ch = np.copy(M[:, ch])
            # repeat Each E(t) 'cycle_per_frame' times -> E(t) is compared with LFSR slice this many time.
            E_ch = E_ch.reshape(num_frames, -1)  # (frames, 1)
            E_ch_rep = np.tile(
                E_ch, (1, cycles_per_frame)
            )  # (frames, cycles_per_frame,)
            # Spike train generated by SG: each row contains spikes generated in a specific frame
            # units of time are 'dt'
            S_sg = np.int_(E_ch_rep >= code_lfsr_frame)  # (frames, cycles_per_frame,)
            # multiply the frequency of spikes by a factor p_local/2
            # (i) unwrap all frames in time (column vec)
            # (ii) repeat along column to simulate the effect of local clock
            # (iii) zero-pad each pulse (row) to take the return-to-zero pulse shape into account
            # (iv) wrap again to have the expanded frame in a row
            # (i)-(ii)
            # repeat the spikes in each frame by a factor p_local/2
            # note that due to return-to-zero pulse, each spike needs to be zero padded at the output
            # each row in the following array contains a pulse produced by SG expanded by p_local
            # by the local clock generator
            S_local_before_pad = np.tile(
                S_sg.reshape(S_sg.size, -1), (1, int(self.p_local / 2))
            )  # (frames * cycles_per_frame, p_local/2)
            # (iii)
            # now zero-pad each row containing a pulse modulated by local clock
            # although the results are 'uint', we use 'int' to avoid unsigned difference issue
            S_local = np.zeros(
                (S_local_before_pad.shape[0], self.p_local),
                dtype="int",
            )  # (frames * cycles_per_frame, p_local)
            S_local[:, : int(self.p_local / 2)] = S_local_before_pad
            # (iv)
            # now reshape S_local so that pulses corresponding to a specific frame are in the same row
            # this is needed because the threshold of IAF 'M(t)' changes from frame to frame
            S_local = S_local.reshape(
                num_frames, -1
            )  # (frames, cycles_per_frame * p_local)
            # apply IAF with threshold M(t) at each frame t (each row of S_local)
            # due to surplus from frame t-> t+1, we need to do this frame by frame
            for t in range(num_frames):
                # find the largest integer less than the floating-value threshold M(t)
                # this way of thresholding works because the IAF is implemented by counter
                # so, we need to set the value of threshold to be an integer value
                # some care is needed when M(t)<1 because then IAF fires everytime a
                # spike comes from the local generator
                # we solve this by simply adding the threshold by 1
                # threshold = np.ceil(M_ch[t]).astype("int") + 1
                # in this implemntation: special case of integer-valued threshold M(t)
                threshold = M_ch[t] + 1
                # compute the cumulative number of firings starting from residual and take mode
                IAF_state = (
                    np.concatenate(([self.IAF_counter[ch]], np.cumsum(S_local[t, :])))
                    % threshold
                )
                # save if needed
                if record:
                    IAF_state_saved[ch].append(np.copy(IAF_state[1:]))
                # to find the firing times, we need to find those times for which IAF_state[t]-IAF_state[t+1]<0
                # +1 is needed because of the delay we added
                firing_times_in_frame = np.argwhere(
                    (IAF_state[1:] - IAF_state[0:-1]) < 0
                ).reshape(-1)
                # register these firing times
                # output_spike_ch[t, firing_times_in_frame] = 1
                output_spike_times[ch].append(
                    firing_times_in_frame * (self.dt / self.p_local) + t * self.frame_dt
                )
                # Save the IAF state for the next frame
                self.IAF_counter[ch] = np.copy(IAF_state[-1])
            # collect all the firing times in all frames in a single array
            if record:
                IAF_state_saved[ch] = np.concatenate(IAF_state_saved[ch])
            # - Build a channels list for the spikes for this channel
            output_spike_times[ch] = np.concatenate(output_spike_times[ch])
            output_spike_channels[ch] = ch * np.ones(len(output_spike_times[ch]))
        # - Sort times and channels
        sorted_indices = np.argsort(np.concatenate(output_spike_times))
        output_spike_times = np.concatenate(output_spike_times)[sorted_indices]
        output_spike_channels = np.concatenate(output_spike_channels)[sorted_indices]
        # - Build output spike raster via TSEvent
        # add_events=True -> allow multiple events in a single time-slot
        output_spike = TSEvent(
            output_spike_times,
            output_spike_channels,
            t_start=0.0,
            t_stop=num_frames * self.frame_dt,
            num_channels=self.size_in,
        ).raster(self.dt, add_events=True)
        # - Generate state record dictionary
        record_dict = (
            {
                "E": E,
                "M": M,
                "IAF_state": np.array(IAF_state_saved).T,
            }
            if record
            else {}
        )
        return output_spike, self.state(), record_dict  
class DivisiveNormalisationNoLFSR(DivisiveNormalisation):
    """
    Divisive normalisation block, with no LFSR spike generation but direct event passthrough
    """

    def __init__(self, *args, **kwargs):
        """
        Instantiate the block, then strip the LFSR-related attributes set up
        by the superclass: this variant feeds input events straight to the
        local spike generator, so no pseudo-random sequence is needed.
        """
        super().__init__(*args, **kwargs)
        # - Remove the unused LFSR machinery inherited from the superclass
        for lfsr_attr in ("bits_lfsr", "code_lfsr", "lfsr_index"):
            delattr(self, lfsr_attr)
    def evolve(
        self, input_spike: np.ndarray, record: bool = False
    ) -> Tuple[np.ndarray, dict, dict]:
        """
        Simulate divisive normalization for an input spike signal
        with possibly several channels.

        In this variant the input events are passed directly to the local spike
        generator, instead of being regenerated pseudo-randomly by an LFSR.

        Args:
            input_spike (np.ndarray): Input spike raster ``(T, N)``, with time bins of duration `dt`
            record (bool): If ``True``, also record and return the internal IAF state. Default: ``False``

        Returns:
            Tuple[np.ndarray, dict, dict]: (output spike raster, module state, record dictionary)

        Raises:
            ValueError: If the input raster does not have ``self.size_in`` channels
        """
        # check the dimensionality first
        if input_spike.shape[1] != self.size_in:
            raise ValueError(
                f"Input size {input_spike.shape} did not match number of channels {self.size_in}"
            )
        # - Convert input spikes with duration 'dt' to frames of duration 'frame_dt'
        # - E: (num_frames, self.size_in) spike count per frame, in units of 'frame_dt'
        ts_input = TSEvent.from_raster(
            input_spike,
            dt=self.dt,
            num_channels=self.size_in,
        )
        E = ts_input.raster(dt=self.frame_dt, add_events=True)
        num_frames = E.shape[0]
        # add the effect of initial values in E_frame_counter
        E[0, :] += self.E_frame_counter.astype(int)
        # saturate the counter at its maximum: a `bits_counter`-bit counter can
        # represent at most 2**bits_counter - 1 (consistent with the saturation
        # of the low-pass state M below)
        E = np.clip(E, 0, 2**self.bits_counter - 1)
        # Reset the value of E_frame_counter
        self.E_frame_counter = np.zeros(self.size_in, "int")
        # Perform low-pass filter on E(t) -> M(t)
        # M(t) = s * E(t) + (1-s) M(t-1), with s = 1 / 2**bits_shift_lowpass
        # M: (num_frames, self.size_in) -> units of 'frame_dt'
        M = np.zeros((num_frames + 1, self.size_in), dtype="int")
        # - Select the low-pass implementation
        if self.low_pass_mode is LowPassMode.UNDERFLOW_PROTECT:
            low_pass = self._low_pass_underflow_protect
        elif self.low_pass_mode is LowPassMode.OVERFLOW_PROTECT:
            low_pass = self._low_pass_overflow_protect
        else:
            raise ValueError(
                f"Unexpected value for `.low_pass_mode`: {self.low_pass_mode}. Expected {[str(e) for e in LowPassMode]}"
            )
        # load the initialization of the filter
        M[0, :] = self.M_lowpass_state
        # - Perform the low-pass filtering
        for t in range(num_frames):
            M[t + 1, :] = low_pass(E[t, :], M[t, :])
        # - Trim the first entry (initial state)
        M = M[1:, :]
        # saturate at the maximum value of a `bits_lowpass`-bit register; the
        # controller holds the counter at its maximum rather than wrapping to zero
        M = np.clip(M, 0, 2**self.bits_lowpass - 1)
        self.M_lowpass_state = M[-1, :]
        # Number of global clock cycles within a frame period
        cycles_per_frame = int(np.ceil(self.frame_dt / self.dt))
        # initialize the IAF_state record for further inspection
        if record:
            IAF_state_saved = [[] for _ in range(self.size_in)]
        # record output spikes and their channels
        output_spike_times = [[] for _ in range(self.size_in)]
        output_spike_channels = [[] for _ in range(self.size_in)]
        # perform operation per channel
        for ch in range(self.size_in):
            # copy the input spike signal and zero-pad it at the end so that its
            # length is an integer multiple of `cycles_per_frame`; the second
            # (all-zero) column supplies the return-to-zero pulse padding;
            # use 'int' to avoid unsigned/float difference issues downstream
            input_copy = np.zeros((num_frames * cycles_per_frame, 2), dtype="int")
            input_copy[: input_spike.shape[0], 0] = input_spike[:, ch]
            # expand each pulse by a factor p_local/2
            # NOTE: `repeats` must be an integer (float raises TypeError in
            # NumPy); `p_local` is guaranteed even, so `// 2` is exact
            input_copy = input_copy.repeat(
                self.p_local // 2
            )  # (p_local/2 * num_frames * cycles_per_frame) * 1
            # reshape the pulses into frames of size p_local * cycles_per_frame
            # S_local --> local spike generator output, one frame per row
            S_local = input_copy.reshape(
                -1, self.p_local * cycles_per_frame
            )  # num_frames * (p_local * cycles_per_frame)
            # Note: from here on, processing matches the LFSR implementation
            # apply the IAF with threshold M(t) frame by frame, carrying the
            # residual counter value across frame boundaries
            for t in range(num_frames):
                # integer-valued threshold; +1 guards the M(t) < 1 case, where
                # the IAF would otherwise fire on every local-generator spike
                threshold = M[t, ch] + 1
                # cumulative firing count starting from the residual, modulo threshold
                IAF_state = (
                    np.concatenate(([self.IAF_counter[ch]], np.cumsum(S_local[t, :])))
                    % threshold
                )
                # save if needed
                if record:
                    IAF_state_saved[ch].append(np.copy(IAF_state[1:]))
                # firing times are where the wrapped state decreases
                # (+1 offset compensates the prepended residual entry)
                firing_times_in_frame = np.argwhere(
                    (IAF_state[1:] - IAF_state[0:-1]) < 0
                ).reshape(-1)
                # register these firing times, converted to seconds
                output_spike_times[ch].append(
                    firing_times_in_frame * (self.dt / self.p_local) + t * self.frame_dt
                )
                # Save the IAF state for the next frame
                self.IAF_counter[ch] = np.copy(IAF_state[-1])
            # collect all the firing times in all frames in a single array
            if record:
                IAF_state_saved[ch] = np.concatenate(IAF_state_saved[ch])
            # - Build a channels list for the spikes for this channel
            output_spike_times[ch] = np.concatenate(output_spike_times[ch])
            output_spike_channels[ch] = ch * np.ones(
                len(output_spike_times[ch]), dtype="int"
            )
        # - Sort times and channels
        sorted_indices = np.argsort(np.concatenate(output_spike_times))
        output_spike_times = np.concatenate(output_spike_times)[sorted_indices]
        output_spike_channels = np.concatenate(output_spike_channels)[sorted_indices]
        # - Build output spike raster via TSEvent
        # add_events=True -> allow multiple events in a single time-slot
        output_spike = TSEvent(
            output_spike_times,
            output_spike_channels,
            t_start=0.0,
            t_stop=num_frames * self.frame_dt,
            num_channels=self.size_in,
        ).raster(self.dt, add_events=True)
        # - Generate state record dictionary
        record_dict = (
            {
                "E": E,
                "M": M,
                "IAF_state": np.array(IAF_state_saved).T,
            }
            if record
            else {}
        )
        return output_spike, self.state(), record_dict