"""
Implements the SynNet architecture for temporal signal processing.

The SynNet architecture is described in Bos & Muir 2022 [https://arxiv.org/abs/2208.12991] and Bos & Muir 2024 [https://arxiv.org/abs/2406.15112].
"""
from rockpool.nn.modules import TorchModule
from rockpool.nn.modules import LinearTorch, LIFTorch, TimeStepDropout, LIFExodus
from rockpool.parameters import Constant
from rockpool.nn.modules.torch.lif_torch import tau_to_bitshift, bitshift_to_tau
from rockpool.nn.modules.torch.lif_torch import PeriodicExponential
from rockpool.nn.combinators.sequential import TorchSequential
import torch
from copy import copy
from typing import List, Type, Union, Optional, Dict
# Threshold assigned to the readout neurons when ``output == "vmem"`` (see ``SynNet.__init__``).
# NOTE(review): presumably chosen large enough that the readout neurons never spike — confirm.
THRESHOLD_OUT = 2**15 - 1
# Names exported by ``from <module> import *``
__all__ = ["SynNet"]
class SynNet(TorchModule):
    """
    Define a ``SynNet`` architecture network

    This class wraps a ``SynNet`` network, as defined in [1, 2].
    This is a feedforward spiking network architecture, with a range of synaptic time constants in each layer.
    By default the time constants are constant (not trainable). This is modifiable with the ``train_time_constants`` argument.

    [1] Bos & Muir 2022. "Sub-mW Neuromorphic SNN audio processing applications with Rockpool and Xylo." ESSCIRC2022. https://arxiv.org/abs/2208.12991
    [2] Bos & Muir 2024. "Micro-power spoken keyword spotting on Xylo Audio 2." arXiv. https://arxiv.org/abs/2406.15112
    """

    def __init__(
        self,
        n_classes: int,
        n_channels: int,
        size_hidden_layers: List = [60],
        time_constants_per_layer: List = [10],
        tau_syn_base: float = 2e-3,
        tau_mem: float = 2e-3,
        tau_syn_out: float = 2e-3,
        quantize_time_constants: bool = True,
        train_time_constants: bool = False,
        threshold: float = 1.0,
        threshold_out: Optional[Union[float, List[float]]] = None,
        train_threshold: bool = False,
        neuron_model: Type = LIFTorch,
        max_spikes_per_dt: int = 31,
        max_spikes_per_dt_out: int = 1,
        p_dropout: float = 0.0,
        dt: float = 1e-3,
        output: str = "spikes",
        neuron_kwargs: Optional[Dict] = None,
        *args,
        **kwargs,
    ):
        """
        Define a ``SynNet`` architecture network

        Args:
            n_classes (int): number of output classes
            n_channels (int): number of input channels
            size_hidden_layers (List[int]): list of number of neurons per layer, list has length ``number_of_layers``. Default: ``[60]``
            time_constants_per_layer (List[int]): list of number of synaptic time constants per layer, list has length ``number_of_layers``. Default: ``[10]``
            tau_syn_base (float): smallest synaptic time constants of hidden neurons in seconds. Default: 2ms, ``2e-3``
            tau_syn_out (float): synaptic time constants of output neurons in seconds. Default: 2ms, ``2e-3``
            tau_mem (float): membrane time constant of all neurons in seconds. Default: 2ms, ``2e-3``
            quantize_time_constants (bool): If ``True``, initial time constants will be rounded to values compatible with Xylo deployment. Default: ``True``
            train_time_constants (bool): If ``True``, time constants of the hidden layers will be trainable. Default: ``False``, do not train time constants
            threshold (float): threshold of hidden neurons. Default: ``1.0``
            threshold_out (Union[List[float], float]): thresholds of readout neurons, can only be set if output is spikes. Default: ``None``, use the same value as ``threshold``
            train_threshold (bool): If ``True``, the threshold will be trainable. If ``False``, the threshold will be constant. Default: ``False``.
            neuron_model (Type): neuron model used for all neurons. Default: :py:class:`LIFTorch`
            max_spikes_per_dt (int): maximum number of spikes per time step of all neurons apart from output neurons. Default: ``31``
            max_spikes_per_dt_out (int): maximum number of spikes per time step of output neurons. Default: ``1``
            p_dropout (float): probability that each time step from each neuron is dropped during training. Default: ``0.0``
            dt (float): time step for simulation in seconds. Currently the values of the time step and the time constants are not independent, thus it should be chosen carefully to allow for interpretable time constants. Default: 1ms, ``1e-3``
            output (str): specification of output variable, one of ``['spikes', 'vmem']``. Default: ``spikes``
            neuron_kwargs (Optional[Dict]): If supplied, keyword arguments from this dictionary will be passed to ``neuron_model`` on instantiation. If :py:class:`LIFTorch` or :py:class:`LIFExodus` are used, these are merged on top of the defaults ``{'spike_generation_fn': PeriodicExponential, 'max_spikes_per_dt': max_spikes_per_dt}``, and the output layer always uses ``max_spikes_per_dt_out``. For any other neuron model, only ``neuron_kwargs`` is passed.

        Raises:
            ValueError: if ``size_hidden_layers`` and ``time_constants_per_layer`` differ in length
            ValueError: if ``tau_syn_base`` is not larger than ``dt``
            ValueError: if ``output`` is not one of ``['spikes', 'vmem']``
            ValueError: if ``threshold_out`` is supplied while ``output`` is ``'vmem'``
        """
        # - Initialise superclass
        super().__init__(
            shape=(n_channels, n_classes),
            spiking_input=True,
            spiking_output=output == "spikes",
            *args,
            **kwargs,
        )

        # - Validate arguments
        if len(size_hidden_layers) != len(time_constants_per_layer):
            raise ValueError(
                "lists for hidden layer sizes and number of time constants per layer need to have the same length"
            )

        if tau_syn_base <= dt:
            raise ValueError(
                "the base synaptic time constant tau_syn_base needs to be larger than the time step dt"
            )

        if output not in ["spikes", "vmem"]:
            # - Build a single message string (multiple positional args would render as a tuple)
            raise ValueError(f"output variable {output} not defined")

        if output == "vmem" and threshold_out is not None:
            raise ValueError(
                "threshold of readout neurons is not applied if output is vmem (membrane potential)"
            )

        # - Select the output threshold
        if output == "vmem":
            threshold_out = THRESHOLD_OUT
        elif threshold_out is None:
            threshold_out = threshold

        self.output: str = output
        """ str: The output generated by this network. One of ``['spikes', 'vmem']`` """

        self.dt: float = dt
        """ float: The time step used by this network on initialisation """

        # - Round time constants to the values they will take when deploying to Xylo
        if quantize_time_constants:
            tau_mem = self._quantize_tau(dt, torch.tensor(tau_mem))
            tau_syn_out = self._quantize_tau(dt, torch.tensor(tau_syn_out))

        # - Make ``tau_mem`` constant, if it should not be trainable
        tau_mem = tau_mem if train_time_constants else Constant(tau_mem)

        # - Calculate how often the set of time constants is repeated within each layer
        #   (number of neurons divided by number of time constants, rounded up)
        tau_repetitions = [
            int(n_hidden / n_tau) + min(1, n_hidden % n_tau)
            for n_hidden, n_tau in zip(size_hidden_layers, time_constants_per_layer)
        ]

        # - Define an empty Sequential network, to add each layer to
        self.seq = TorchSequential()
        """ Sequential: The network itself, as a ``Sequential`` Module """

        # - Generate neuron arguments
        #   The spike-generation defaults are only applied to the LIF models;
        #   any other neuron model receives only the user-supplied ``neuron_kwargs``.
        if neuron_model in [LIFTorch, LIFExodus]:
            hidden_neuron_kwargs = {
                "spike_generation_fn": PeriodicExponential,
                "max_spikes_per_dt": max_spikes_per_dt,
            }
            out_neuron_overrides = {"max_spikes_per_dt": max_spikes_per_dt_out}
        else:
            hidden_neuron_kwargs = {}
            out_neuron_overrides = {}

        if neuron_kwargs is not None:
            hidden_neuron_kwargs.update(neuron_kwargs)

        # - Output neurons share the hidden arguments, apart from the spike limit
        out_neuron_kwargs = copy(hidden_neuron_kwargs)
        out_neuron_kwargs.update(out_neuron_overrides)

        # - Generate each set of weights and neurons in turn
        n_channels_in = n_channels
        lif_names = []

        for i, (n_hidden, n_tau) in enumerate(
            zip(size_hidden_layers, time_constants_per_layer)
        ):
            # - Generate time constants: geometric series in ``tau_syn_base / dt``, tiled over the layer
            tau_series = torch.tensor(
                [(tau_syn_base / dt) ** j * dt for j in range(1, n_tau + 1)]
            )
            tau_syn_hidden = torch.hstack([tau_series] * tau_repetitions[i])

            # - If the size of the layer is not a multiple of the number of time constants,
            #   the surplus entries at the end of the tiled list are dropped
            tau_syn_hidden = tau_syn_hidden[:n_hidden]

            # - Round time constants to the values they will take when deploying to Xylo
            if quantize_time_constants:
                tau_syn_hidden = torch.tensor(
                    [self._quantize_tau(dt, tau_syn) for tau_syn in tau_syn_hidden]
                )

            # - Generate a linear weight module
            lyr_weights = LinearTorch(shape=(n_channels_in, n_hidden), has_bias=False)
            n_channels_in = n_hidden

            # - Normalise weights by time constant and add to network
            with torch.no_grad():
                lyr_weights.weight.data = lyr_weights.weight.data * dt / tau_syn_hidden
            self.seq.append(lyr_weights, f"{i}_linear")

            # - Add the neuron layer to the network
            neuron_name = f"{i}_neurons"
            self.seq.append(
                neuron_model(
                    shape=(n_hidden, n_hidden),
                    tau_mem=tau_mem,
                    tau_syn=(
                        tau_syn_hidden
                        if train_time_constants
                        else Constant(tau_syn_hidden)
                    ),
                    bias=Constant(0.0),
                    threshold=threshold if train_threshold else Constant(threshold),
                    dt=dt,
                    **hidden_neuron_kwargs,
                ),
                neuron_name,
            )
            lif_names.append(neuron_name)

            # - Incorporate a dropout layer, if requested
            if p_dropout > 0.0:
                self.seq.append(
                    TimeStepDropout(shape=(n_hidden), p=p_dropout), f"{i}_dropout"
                )

        # - Add the output weight layer
        #   ``n_channels_in`` is the size of the last hidden layer, or ``n_channels``
        #   if no hidden layers were requested
        lyr_weights = LinearTorch(shape=(n_channels_in, n_classes), has_bias=False)
        with torch.no_grad():
            lyr_weights.weight.data = lyr_weights.weight.data * dt / tau_syn_out
        self.seq.append(lyr_weights, "out_linear")

        # - Add the output neuron layer; its parameters are always constant
        self.seq.append(
            neuron_model(
                shape=(n_classes, n_classes),
                tau_mem=Constant(tau_mem),
                tau_syn=Constant(tau_syn_out),
                bias=Constant(0.0),
                threshold=Constant(threshold_out),
                dt=dt,
                **out_neuron_kwargs,
            ),
            "out_neurons",
        )
        lif_names.append("out_neurons")

        # - Record names of neuron and output layers
        self.lif_names: List[str] = lif_names
        """ List[str]: A list of the neuron models present in this network, in evolution order. """

        self.label_last_LIF: str = lif_names[-1]
        """ str: The name of the readout neuron layer in this network. """

        # - Dictionary for recording state
        self._record: bool = False
        """ bool: If ``True``, record the state trace during evolution """

        self._record_dict: dict = {}
        """ dict: The internal set of recorded state traces, if requested """

    @staticmethod
    def _quantize_tau(dt: float, tau: torch.Tensor) -> float:
        """
        Round a time constant via its nearest integer bitshift representation

        Args:
            dt (float): simulation time step in seconds
            tau (torch.Tensor): time constant to round, as a tensor

        Returns:
            float: the time constant obtained by converting ``tau`` to a bitshift, rounding to an integer, and converting back
        """
        bitshift = torch.round(tau_to_bitshift(dt, tau)[0]).int()
        return bitshift_to_tau(dt, bitshift)[0].item()

    def forward(self, data: torch.Tensor):
        """
        Evolve the wrapped ``Sequential`` network over ``data``

        Args:
            data (torch.Tensor): input data passed to the network

        Returns:
            The output of the final layer, or the recorded ``"vmem"`` trace of the readout layer if ``self.output == "vmem"``
        """
        # - The "vmem" output is read from the record dictionary, so recording
        #   must be enabled for it even if ``forward`` is called directly
        record = self._record or self.output == "vmem"

        # - Evolve the Sequential network
        out, _, record_dict = self.seq(data, record=record)

        # - If "vmem" output is requested, use this instead of spiking in the final layer
        if self.output == "vmem":
            out = record_dict[self.label_last_LIF]["vmem"]

        # - Modify the record dictionary to store the output
        # NOTE(review): this substring match overwrites EVERY entry whose key contains
        # "output" with the final network output; if the record dictionary also holds
        # per-layer output entries they are replaced too — confirm this is intended
        if self._record:
            for key in record_dict:
                if "output" in key:
                    record_dict[key] = out

        # - Keep a copy of the record dictionary
        self._record_dict = record_dict if self._record else {}

        # - Return the model output
        return out

    def evolve(self, input_data, record: bool = False):
        """
        Evolve the network over ``input_data``

        Args:
            input_data: input passed on to the superclass ``evolve``
            record (bool): If ``True``, return the recorded state traces. Default: ``False``

        Returns:
            tuple: ``(output, new_state, record_dict)``; ``record_dict`` is empty unless ``record`` is ``True``
        """
        # - Store "record" state; recording is forced for "vmem" output,
        #   since ``forward`` reads the membrane potential from the record dictionary
        self._record = record or self.output == "vmem"

        # - Evolve network
        output, new_state, record_dict = super().evolve(input_data, record=self._record)

        # - Only hand the recording dictionary back if the caller asked for it
        record_dict = record_dict if record else {}

        # - Return
        return output, new_state, record_dict

    def as_graph(self):
        """
        Return the computational graph of the wrapped ``Sequential`` network
        """
        # - Return the graph from the ``Sequential`` network
        return self.seq.as_graph()