# -----------------------------------------------------------
# simulates divisive normalization for a system with input spikes
# coming in possibly several channels
#
# (C) Saeid Haghighatshoar
# email: saeid.haghighatshoar@synsense.ai
#
# finalized and tested on 13.09.2021
# -----------------------------------------------------------
from rockpool.timeseries import TSEvent
from rockpool.nn.modules.module import Module
from rockpool.typehints import P_int, P_float, P_ndarray
from rockpool.parameters import State, SimulationParameter
import numpy as np
import warnings
import imp
import pathlib as pl
basedir = pl.Path(imp.find_module("rockpool")[1]) / "devices" / "xylo" / "syns61201"
from typing import Tuple, Union
from enum import IntEnum
__all__ = [
"LowPassMode",
"DivisiveNormalisation",
"DivisiveNormalisationNoLFSR",
"build_lfsr",
]
LowPassMode = IntEnum(
"LowPassMode",
"UNDERFLOW_PROTECT OVERFLOW_PROTECT",
module=__name__,
qualname="rockpool.devices.xylo.xylo_divisive_normalisation.LowPassMode",
)
def build_lfsr(filename) -> np.ndarray:
"""
This function reads the LFSR code in a binary format from the file "filename"
It return a Numpy array containing the integer values of LFSR code.
"""
# read the LFSR state and build the pseudo-random code
with open(filename) as f:
lines = f.readlines()
# for some reason state 'all-zero' is included in the file but
# LFSR state cannot be all-zero because then the next state also would be
# all-zero
code_lfsr = np.zeros(len(lines), dtype="int")
for i in range(len(lines)):
code_lfsr[i] = int(lines[i], 2)
# remove the last element (duplicate)
code_lfsr = code_lfsr[:-1]
return code_lfsr
[docs]class DivisiveNormalisation(Module):
"""
A digital divisive normalization block
"""
[docs] def __init__(
self,
shape: Union[int, Tuple[int]] = 1,
*args,
bits_counter: int = 10,
E_frame_counter: np.ndarray = None,
IAF_counter: np.ndarray = None,
bits_lowpass: int = 16,
bits_shift_lowpass: int = 5,
M_lowpass_state: int = None,
dt: P_float = 0.1e-3,
frame_dt: float = 50e-3,
bits_lfsr: int = 10,
code_lfsr: np.ndarray = None,
p_local: int = 12,
low_pass_mode: LowPassMode = LowPassMode.UNDERFLOW_PROTECT,
**kwargs,
):
"""
Args:
shape (tuple): The number of channels ``N`` for this Module
bits_counter (int): Bit-width of frame counter E. Defualt: ``10``
E_frame_counter (np.ndarray): Initialisation for frame counter E ``(N,)``. Default: ``None``, initialise to zero.
IAF_counter (np.ndarray): Initialisation for IAF state ``(N,)``. Defualt: ``None``, initialise to zero.
bits_lowpass (int): Number of bits used by the low-pass filter ``M``. Default: ``16``
bits_shift_lowpass (int): Dash bitshift averaging parameter to use in low-pass filter. Default: ``5``
M_lowpass_state (np.ndarray): Initialisation for low-pass state ``(N,)``. Default: ``None``, initialise to zero.
dt (float): Global clock step in seconds. Default: 0.1ms
frame_dt (float): Frame clock step in seconds. Default: 50ms
bits_lfsr (int): Bit-width of LFSR. Default: ``10``
code_lfsr (np.ndarray): LFSR sequence to use. Default: Load a pre-defined LFSR sequence.
p_local (int): Factor to multiply spike rate E. Factor ``p`` is given by ``p_local / 2``.
low_pass_mode (LowPassMode): Specify how to compute the low-pass filter M. Possible values are ``LowPassMode.UNDERFLOW_PROTECT`` and ``LowPassMode.OVERFLOW_PROTECT``. Default: ``LowPassMode.UNDERFLOW_PROTECT``, optimised for low input event frequencies. ``LowPassMode.OVERFLOW_PROTECT`` is optimal for high input frequencies.
"""
# intialize the Module superclass
super().__init__(
shape, spiking_input=True, spiking_output=True, *args, **kwargs
)
# initialize the value of the counter or set to zero if not specified
self.E_frame_counter: P_ndarray = State(
E_frame_counter,
shape=(self.size_in),
init_func=lambda s: np.zeros(s, "int"),
)
""" np.ndarray: Spike rate per frame ``(N,)`` """
self.bits_counter: P_int = SimulationParameter(bits_counter)
""" int: Number of bits of spike rate counter E """
# initialize the state of IAF_counter
self.IAF_counter: P_ndarray = State(
IAF_counter, shape=(self.size_in), init_func=lambda s: np.zeros(s, "int")
)
""" np.ndarray: IAF state ``(N,)`` """
# set the parameters of the low-pass filter (implemented by bit-shifts)
self.bits_lowpass: P_int = SimulationParameter(bits_lowpass)
""" int: Number of bits in low-pass filter state M """
self.bits_shift_lowpass: P_int = SimulationParameter(bits_shift_lowpass)
""" int: Dash decay parameter for low-pass filter M """
# at the moment we implement averaging filter such that its value is
# always a positive integer. We can do much better by using some
# fixed-point implementation by using some extra bits as decimals.
# This improvment is left for future.
self.M_lowpass_state: P_ndarray = State(
M_lowpass_state,
shape=(self.size_in),
init_func=lambda s: np.zeros(s, "int"),
)
""" np.ndarray: Low-pass filter state M ``(N,)`` """
##
# set the global clock frequency
self.dt: P_float = SimulationParameter(dt)
""" float: Global clock step in seconds """
# set the period of the frame
self.frame_dt: P_float = SimulationParameter(frame_dt)
""" float: Frame clock step in seconds """
# set the number of bits and also code for LFSR
self.bits_lfsr: P_int = SimulationParameter(bits_lfsr)
""" int: Number of LFSR bits """
if code_lfsr is None:
# - Load default LFSR sequence
code_lfsr = build_lfsr(basedir / "lfsr_data.txt")
if code_lfsr.size != 2**self.bits_lfsr - 1:
raise ValueError(
f"Length of LFSR is not compatible with its number of bits. Expected {2 ** self.bits_lfsr - 1} entries, found {code_lfsr.size}."
)
self.code_lfsr: P_ndarray = SimulationParameter(
np.copy(code_lfsr).reshape((-1,))
)
""" np.ndarray: LFSR sequence to use for pRNG """
self.lfsr_index: P_int = State(0)
""" int: Current index into LFSR sequence """
# set the ratio between the rate of local and global clocks
# note that because of return-to-zero pulses, the spike rate increases
# by only p_local/2 -> set p_local to be an even number
self.p_local: P_int = int((1 + p_local) / 2) * 2
""" int: Factor to scale up internal spike generation."""
if self.p_local != p_local:
warnings.warn(f"`p_local` = {p_local} was rounded to an even integer!")
if low_pass_mode not in LowPassMode:
raise ValueError(
f"Unexpected value for `low_pass_mode`: {low_pass_mode}. Expected {[str(e) for e in LowPassMode]}"
)
self.low_pass_mode = SimulationParameter(low_pass_mode)
""" LowPassMode: Specifies which mode to use for low-pass filtering """
[docs] def _low_pass_underflow_protect(
self, E_t: np.ndarray, M_t: np.ndarray
) -> np.ndarray:
"""
Implement one low-pass filter time-step, with underflow protection
Args:
E_t (np.ndarray): Input rates for this frame ``(N,)`
M_t (np.ndarray): Current low-pass state from previous frame ``(N,)``
Returns:
np.ndarray: Low-pass state for the next frame ``(N,)``
"""
return (E_t + (M_t << self.bits_shift_lowpass) - M_t) >> self.bits_shift_lowpass
[docs] def _low_pass_overflow_protect(
self, E_t: np.ndarray, M_t: np.ndarray
) -> np.ndarray:
"""
Implement one low-pass filter time-step, with overflow protection
Args:
E_t (np.ndarray): Input rates for this frame ``(N,)`
M_t (np.ndarray): Current low-pass state from previous frame ``(N,)``
Returns:
np.ndarray: Low-pass state for the next frame ``(N,)``
"""
return (E_t >> self.bits_shift_lowpass) + M_t - (M_t >> self.bits_shift_lowpass)
[docs] def evolve(
self, input_spike: np.ndarray, record: bool = False
) -> Tuple[np.ndarray, dict]:
"""
This class simulates divisive normalization for an input spike signal
with one or several channels.
The output of the simulation is another spike signal with normalized rates.
"""
# check the dimensionality first
if input_spike.shape[1] != self.size_in:
raise ValueError(
f"Input size {input_spike.shape} did not match number of channels {self.size_in}"
)
# - Convert input spikes with duration 'dt' to frames of duration 'frame_dt'-> counter output
# - output is counter output E(t) of duration 'frame_dt'
# - input : (N, self.size_in) -> N is units of 'dt'
# - E: (n_frame, self.size_in) -> units of 'frame_dt'
ts_input = TSEvent.from_raster(
input_spike, dt=self.dt, num_channels=self.size_in
)
E = ts_input.raster(dt=self.frame_dt, add_events=True)
num_frames = E.shape[0]
# add the effect of initial values in E_frame_counter
E[0, :] += self.E_frame_counter.astype(int)
# clip the counter to take the limited number of bits into account
E = np.clip(E, 0, 2**self.bits_counter)
# Reset the value of E_frame_counter
self.E_frame_counter = np.zeros(self.size_in, "int")
# Perform low-pass filter on E(t)-> M(t)
# M(t) = s * E(t) + (1-s) M(t-1)
# with s=1/2**bits_shift_lowpass
# M: (n_frame, self.size_in) -> units of 'frame_dt'
M = np.zeros((num_frames + 1, self.size_in), dtype="int")
# - Select the low-pass implementation
if self.low_pass_mode is LowPassMode.UNDERFLOW_PROTECT:
low_pass = self._low_pass_underflow_protect
elif self.low_pass_mode is LowPassMode.OVERFLOW_PROTECT:
low_pass = self._low_pass_overflow_protect
else:
raise ValueError(
f"Unexpected value for `.low_pass_mode`: {self.low_pass_mode}. Expected {[str(e) for e in LowPassMode]}"
)
# load the initialization of the filter
M[0, :] = self.M_lowpass_state
# - Perform the low-pass filtering
for t in range(num_frames):
M[t + 1, :] = low_pass(E[t, :], M[t, :])
# - Trim the first entry (initial state)
M = M[1:, :]
# take the limited number of counter bits into account
# we should make sure that the controller does not allow count-back to zero
# i.e., it keeps the value of the counter at its maximum
M = np.clip(M, 0, 2**self.bits_lowpass - 1)
self.M_lowpass_state = M[-1, :]
# use the value of E(t) at each frame t to produce a pseudo-random
# Poisson spike train by comparing E(t) with the LFSR output
# as the value of LFSR varies with global clock rate f_s, we have 'frame_dt*f_s'
# samples in each frame
# the timing of the output is in units of 'dt'
# Number of global clock cycles within a frame period
cycles_per_frame = int(np.ceil(self.frame_dt / self.dt))
# whole number of LFSR cycles that are needed for comparison with E(t)
# over all frames t=0, 1, ...
lfsr_ticks_needed = cycles_per_frame * num_frames
# number of LFSR periods needed
num_lfsr_period = int(np.ceil(lfsr_ticks_needed / self.code_lfsr.size)) + 1
# the slice of LFSR code used over this frame
code_lfsr_frame = np.tile(self.code_lfsr, num_lfsr_period)[
self.lfsr_index : self.lfsr_index + lfsr_ticks_needed
].reshape(
num_frames, -1
) # (frames, cycles_per_frame)
self.lfsr_index = (self.lfsr_index + lfsr_ticks_needed) % len(self.code_lfsr)
# initialize the IAF_state for further inspection
if record:
IAF_state_saved = [[] for _ in range(self.size_in)]
# initialise output spike variables
output_spike_times = [[] for _ in range(self.size_in)]
output_spike_channels = [[] for _ in range(self.size_in)]
# perform operation per channel
for ch in range(self.size_in):
# for each channel
E_ch = np.copy(E[:, ch])
M_ch = np.copy(M[:, ch])
# repeat Each E(t) 'cycle_per_frame' times -> E(t) is compared with LFSR slice this many time.
E_ch = E_ch.reshape(num_frames, -1) # (frames, 1)
E_ch_rep = np.tile(
E_ch, (1, cycles_per_frame)
) # (frames, cycles_per_frame,)
# Spike train generated by SG: each row contains spikes generated in a specific frame
# units of time are 'dt'
S_sg = np.int_(E_ch_rep >= code_lfsr_frame) # (frames, cycles_per_frame,)
# multiply the frequency of spikes by a factor p_local/2
# (i) unwrap all frames in time (column vec)
# (ii) repeat along column to simulate the effect of local clock
# (iii) zero-pad each pulse (row) to take the return-to-zero pulse shape into account
# (iv) wrap again to have the expanded frame in a row
# (i)-(ii)
# repeat the spikes in each frame by a factor p_local/2
# note that due to return-to-zero pulse, each spike needs to be zero padded at the output
# each row in the following array contains a pulse produced by SG expanded by p_local
# by the local clock generator
S_local_before_pad = np.tile(
S_sg.reshape(S_sg.size, -1), (1, int(self.p_local / 2))
) # (frames * cycles_per_frame, p_local/2)
# (iii)
# now zero-pad each row containing a pulse modulated by local clock
# although the results are 'uint', we use 'int' to avoid unsigned difference issue
S_local = np.zeros(
(S_local_before_pad.shape[0], self.p_local),
dtype="int",
) # (frames * cycles_per_frame, p_local)
S_local[:, : int(self.p_local / 2)] = S_local_before_pad
# (iv)
# now reshape S_local so that pulses corresponding to a specific frame are in the same row
# this is needed because the threshold of IAF 'M(t)' changes from frame to frame
S_local = S_local.reshape(
num_frames, -1
) # (frames, cycles_per_frame * p_local)
# apply IAF with threshold M(t) at each frame t (each row of S_local)
# due to surplus from frame t-> t+1, we need to do this frame by frame
for t in range(num_frames):
# find the largest integer less than the floating-value threshold M(t)
# this way of thresholding works because the IAF is implemented by counter
# so, we need to set the value of threshold to be an integer value
# some care is needed when M(t)<1 because then IAF fires everytime a
# spike comes from the local generator
# we solve this by simply adding the threshold by 1
# threshold = np.ceil(M_ch[t]).astype("int") + 1
# in this implemntation: special case of integer-valued threshold M(t)
threshold = M_ch[t] + 1
# compute the cumulative number of firings starting from residual and take mode
IAF_state = (
np.concatenate(([self.IAF_counter[ch]], np.cumsum(S_local[t, :])))
% threshold
)
# save if needed
if record:
IAF_state_saved[ch].append(np.copy(IAF_state[1:]))
# to find the firing times, we need to find those times for which IAF_state[t]-IAF_state[t+1]<0
# +1 is needed because of the delay we added
firing_times_in_frame = np.argwhere(
(IAF_state[1:] - IAF_state[0:-1]) < 0
).reshape(-1)
# register these firing times
# output_spike_ch[t, firing_times_in_frame] = 1
output_spike_times[ch].append(
firing_times_in_frame * (self.dt / self.p_local) + t * self.frame_dt
)
# Save the IAF state for the next frame
self.IAF_counter[ch] = np.copy(IAF_state[-1])
# collect all the firing times in all frames in a single array
if record:
IAF_state_saved[ch] = np.concatenate(IAF_state_saved[ch])
# - Build a channels list for the spikes for this channel
output_spike_times[ch] = np.concatenate(output_spike_times[ch])
output_spike_channels[ch] = ch * np.ones(len(output_spike_times[ch]))
# - Sort times and channels
sorted_indices = np.argsort(np.concatenate(output_spike_times))
output_spike_times = np.concatenate(output_spike_times)[sorted_indices]
output_spike_channels = np.concatenate(output_spike_channels)[sorted_indices]
# - Build output spike raster via TSEvent
# add_events=True -> allow multiple events in a single time-slot
output_spike = TSEvent(
output_spike_times,
output_spike_channels,
t_start=0.0,
t_stop=num_frames * self.frame_dt,
num_channels=self.size_in,
).raster(self.dt, add_events=True)
# - Generate state record dictionary
record_dict = (
{
"E": E,
"M": M,
"IAF_state": np.array(IAF_state_saved).T,
}
if record
else {}
)
return output_spike, self.state(), record_dict
class DivisiveNormalisationNoLFSR(DivisiveNormalisation):
"""
Divisive normalisation block, with no LFSR spike generation but direct event passthrough
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
delattr(self, "bits_lfsr")
delattr(self, "code_lfsr")
delattr(self, "lfsr_index")
def evolve(
self, input_spike: np.ndarray, record: bool = False
) -> Tuple[np.ndarray, dict]:
"""
This class simulates divisive normalization for an input spike signal
with possibly several channels.
The output of the simulation is another spike signal with normalized rates.
"""
# check the dimensionality first
if input_spike.shape[1] != self.size_in:
raise ValueError(
f"Input size {input_spike.shape} did not match number of channels {self.size_in}"
)
# - Convert input spikes with duration 'dt' to frames of duration 'frame_dt'-> counter output
# - output is counter output E(t) of duration 'frame_dt'
# - input : (N, self.size_in) -> N is units of 'dt'
# - E: (n_frame, self.size_in) -> units of 'frame_dt'
ts_input = TSEvent.from_raster(
input_spike,
dt=self.dt,
num_channels=self.size_in,
)
E = ts_input.raster(dt=self.frame_dt, add_events=True)
num_frames = E.shape[0]
# add the effect of initial values in E_frame_counter
E[0, :] += self.E_frame_counter.astype(int)
# clip the counter to take the limited number of bits into account
E = np.clip(E, 0, 2**self.bits_counter)
# Reset the value of E_frame_counter
self.E_frame_counter = np.zeros(self.size_in, "int")
# Perform low-pass filter on E(t)-> M(t)
# M(t) = s * E(t) + (1-s) M(t-1)
# with s=1/2**bits_shift_lowpass
# M: (n_frame, self.size_in) -> units of 'frame_dt'
M = np.zeros((num_frames + 1, self.size_in), dtype="int")
# - Select the low-pass implementation
if self.low_pass_mode is LowPassMode.UNDERFLOW_PROTECT:
low_pass = self._low_pass_underflow_protect
elif self.low_pass_mode is LowPassMode.OVERFLOW_PROTECT:
low_pass = self._low_pass_overflow_protect
else:
raise ValueError(
f"Unexpected value for `.low_pass_mode`: {self.low_pass_mode}. Expected {[str(e) for e in LowPassMode]}"
)
# load the initialization of the filter
M[0, :] = self.M_lowpass_state
# - Perform the low-pass filtering
for t in range(num_frames):
M[t + 1, :] = low_pass(E[t, :], M[t, :])
# - Trim the first entry (initial state)
M = M[1:, :]
# take the limited number of counter bits into account
# we should make sure that the controller does not allow count-back to zero
# i.e., it keeps the value of the counter at its maximum
M = np.clip(M, 0, 2**self.bits_lowpass - 1)
self.M_lowpass_state = M[-1, :]
# use the value of E(t) at each frame t to produce a pseudo-random
# Poisson spike train by comparing E(t) with the LFSR output
# as the value of LFSR varies with global clock rate f_s, we have 'frame_dt*f_s'
# samples in each frame
# the timing of the output is in units of 'dt'
# Number of global clock cycles within a frame period
cycles_per_frame = int(np.ceil(self.frame_dt / self.dt))
# initialize the IAF_state for further inspection
if record:
IAF_state_saved = [[] for _ in range(self.size_in)]
# record output spikes and their channels
output_spike_times = [[] for _ in range(self.size_in)]
output_spike_channels = [[] for _ in range(self.size_in)]
# perform operation per channel
for ch in range(self.size_in):
# for each channel
# copy the input spike signal and zero-pad it at the end
# we have "cycles_per_frame" of global clock cycles and we make sure that
# the length of the input signal is an integer multiple of this
# due to return-to-zero pulse shape we need to add a zero
input_copy = np.zeros((num_frames * cycles_per_frame, 2))
input_copy[: input_spike.shape[0], 0] = input_spike[:, ch]
# now we need to expand each pulse by a factor p_local/2
input_copy = input_copy.repeat(
self.p_local / 2
) # (p_local/2*num_frames*cycles_per_frame) * 1
# and reshape the pulses into frames of size p_local*cycles_per_frame
# S_local --> local spike generator
S_local = input_copy.reshape(
-1, self.p_local * cycles_per_frame
) # num_frames * (p_local*cycles_per_frame)
# Note: after this step everything is just similar to the implementation with LFSR
# apply IAF with threshold M(t) at each frame t (each row of S_local)
# due to surplus from frame t-> t+1, we need to do this frame by frame
# output_spike_ch = np.zeros(S_local.shape, dtype="int")
# initialize the state of the corresponding counter
# res_from_previous_frame = self.IAF_counter[ch]
for t in range(num_frames):
# find the largest integer less than the floating-value threshold M(t)
# this way of thresholding works because the IAF is implemented by counter
# so, we need to set the value of threshold to be an integer value
# some care is needed when M(t)<1 because then IAF fires everytime a
# spike comes from the local generator
# we solve this by simply adding the threshold by 1
threshold = M[t, ch] + 1
# compute the cumulative number of firings starting from residual and take mode
IAF_state = (
np.concatenate(([self.IAF_counter[ch]], np.cumsum(S_local[t, :])))
% threshold
)
# IAF_state = (
# self.IAF_counter[ch] + np.cumsum(S_local[t, :])
# ) % threshold
# save if needed
if record:
IAF_state_saved[ch].append(np.copy(IAF_state[1:]))
# to find the firing times, we need to find those times for which IAF_state[t]-IAF_state[t+1]<0
# +1 is needed because of the delay we added
firing_times_in_frame = np.argwhere(
(IAF_state[1:] - IAF_state[0:-1]) < 0
).reshape(-1)
# register these firing times
# output_spike_ch[t, firing_times_in_frame] = 1
output_spike_times[ch].append(
firing_times_in_frame * (self.dt / self.p_local) + t * self.frame_dt
)
# Save the IAF state for the next frame
self.IAF_counter[ch] = np.copy(IAF_state[-1])
# res_from_previous_frame = output_spike_ch[t, -1]
# register the state of the IAF counter
# self.IAF_counter[ch] = res_from_previous_frame
# unwrap the spikes and copy it in the output_spike for the channel
# since we are not worried about modification: no need for copy -> ravel
# output_spike[:, ch] = output_spike_ch.ravel()
if record:
IAF_state_saved[ch] = np.concatenate(IAF_state_saved[ch])
# - Build a channels list for the spikes for this channel
output_spike_times[ch] = np.concatenate(output_spike_times[ch])
output_spike_channels[ch] = ch * np.ones(
len(output_spike_times[ch]), dtype="int"
)
# - Sort times and channels
sorted_indices = np.argsort(np.concatenate(output_spike_times))
output_spike_times = np.concatenate(output_spike_times)[sorted_indices]
output_spike_channels = np.concatenate(output_spike_channels)[sorted_indices]
# - Build output spike raster via TSEvent
output_spike = TSEvent(
output_spike_times,
output_spike_channels,
t_start=0.0,
t_stop=num_frames * self.frame_dt,
num_channels=self.size_in,
).raster(self.dt, add_events=True)
# - Generate state record dictionary
record_dict = (
{
"E": E,
"M": M,
"IAF_state": np.array(IAF_state_saved).T,
}
if record
else {}
)
return output_spike, self.state(), record_dict