"""
A Rockpool module that simulates the Xylo SNN hardware, using XyloSim as the back-end. Requires the ``xylosim`` package.
"""
# - Rockpool imports
from rockpool.nn.modules.module import Module
from rockpool.parameters import Parameter, State, SimulationParameter
from rockpool import TSContinuous, TSEvent
from rockpool.utilities.backend_management import backend_available
from xylosim.v1 import XyloLayer
# - Numpy
import numpy as np
# - Typing
from typing import Optional, Union, Any, Dict
# - Type alias for a Xylo configuration object (e.g. a ``samna.xylo.XyloConfiguration``)
XyloConfiguration = Union[
    Dict,
    Any,
]
# - Public names exported by this module
__all__ = [
    "XyloSim",
]
class XyloSim(Module):
    """
    A :py:class:`.Module` simulating a digital SNN on Xylo, using XyloSim as a back-end.

    You should use the factory methods `.from_config` and `.from_specification` to build a concrete `.XyloSim` module.

    See Also:
        See the tutorials :ref:`/devices/xylo-overview.ipynb` and :ref:`/devices/torch-training-spiking-for-xylo.ipynb` for a high-level overview of building and deploying networks for Xylo.
    """

    __create_key = object()
    """ Private key to ensure factory creation """

    output_mode = "Spike"
    """ (str) Default readout mode. Each instance stores its own value in the constructor """

    # - Readout modes accepted by the constructor and handled by :py:meth:`.evolve`
    _OUTPUT_MODES = ("Isyn", "Vmem", "Spike")

    def __init__(
        self,
        create_key,
        config: XyloConfiguration,
        shape: tuple = (16, 1000, 8),
        dt: float = 1e-3,
        output_mode: str = "Spike",
        *args,
        **kwargs,
    ):
        """
        Private constructor for :py:class:`.XyloSim`

        Warnings:
            Use the factory methods :py:meth:`.XyloSim.from_config` and :py:meth:`.XyloSim.from_specification` to construct a :py:class:`.XyloSim` module.

        Args:
            create_key: Private key proving that a factory method is calling this constructor
            config (XyloConfiguration): Configuration to store in the ``config`` parameter
            shape (tuple): Module shape passed to :py:class:`.Module`. Default: ``(16, 1000, 8)``
            dt (float): Simulation time-step in seconds. Default: 1 ms
            output_mode (str): One of ``"Spike"``, ``"Vmem"`` or ``"Isyn"``. Selects which output-layer signal :py:meth:`.evolve` returns. Default: ``"Spike"``

        Raises:
            NotImplementedError: If not called from a factory method
            ValueError: If ``output_mode`` is not supported
        """
        # - Check that we are creating the object using a factory function
        if create_key is not XyloSim.__create_key:
            raise NotImplementedError(
                "XyloSim may only be instantiated using factory methods `from_config` or `from_specification`."
            )

        # - Validate the readout mode up front (an `assert` would be stripped under `python -O`)
        if output_mode not in XyloSim._OUTPUT_MODES:
            raise ValueError(f"{output_mode} is not supported.")

        # - Initialise the superclass
        super().__init__(
            *args,
            shape=shape,
            spiking_input=True,
            spiking_output=True,
            **kwargs,
        )

        # - Store the configuration
        self.config: Union[XyloConfiguration, Parameter] = Parameter(
            shape=(), init_func=lambda _: config
        )
        """ (XyloConfiguration) Configuration of the Xylo module """

        # - Store the dt
        self.dt: Union[float, SimulationParameter] = SimulationParameter(dt)
        """ (float) Simulation time-step for this module """

        # - Empty attribute for the Xylo layer; populated by `from_config`
        self._xylo_layer: Optional[XyloLayer] = None
        """ (XyloLayer) Handle to a XyloSim object """

        # - Readout mode (stored per instance)
        self.output_mode = output_mode

    @staticmethod
    def _weights_to_synapses(weights, syn2_weights=None) -> list:
        """
        Convert weight matrices into per-presynaptic-neuron lists of ``XyloSynapse`` objects

        Args:
            weights: Weights onto synapse 0, indexed ``[pre][post]``
            syn2_weights: Optional weights onto synapse 1, indexed ``[pre][post]``. Skipped if ``None``

        Returns:
            list: One list per presynaptic neuron, holding a ``XyloSynapse`` for each non-zero weight
        """
        from xylosim.v1 import XyloSynapse

        synapses = []
        for pre, w_pre in enumerate(weights):
            # - Synapse 0 connections (non-zero weights only)
            pre_syns = [XyloSynapse(post, 0, w_pre[post]) for post in np.where(w_pre)[0]]

            # - Synapse 1 connections, if enabled
            if syn2_weights is not None:
                w2_pre = syn2_weights[pre]
                pre_syns.extend(
                    XyloSynapse(post, 1, w2_pre[post]) for post in np.where(w2_pre)[0]
                )

            synapses.append(pre_syns)

        return synapses

    @classmethod
    def from_config(
        cls, config: XyloConfiguration, dt: float = 1e-3, output_mode: str = "Spike"
    ) -> "XyloSim":
        """
        Create a XyloSim based layer to simulate the Xylo hardware, from a configuration

        Args:
            config (XyloConfiguration): ``samna.xylo.XyloConfiguration`` object to specify all parameters. See samna documentation for details.
            dt (float): Timestep for simulation, in seconds. Default: 1 ms
            output_mode (str): One of ``"Spike"``, ``"Vmem"`` or ``"Isyn"``. Selects which output-layer signal :py:meth:`.evolve` returns. Default: ``"Spike"``

        Returns:
            :py:class:`.XyloSim`: A module wrapping a ``XyloLayer`` built from ``config``

        Raises:
            ValueError: If ``output_mode`` is not supported
        """
        from types import SimpleNamespace

        # - Instantiate the class. The readout mode is stored on the instance, never on the class
        # NOTE(review): `shape` is not passed, so the module keeps the default (16, 1000, 8)
        # regardless of the configured network size. Confirm this is intended.
        mod = cls(
            create_key=cls.__create_key,
            config=config,
            dt=dt,
            output_mode=output_mode,
        )

        # - Storage object for the extracted configuration
        params = SimpleNamespace()

        # - Convert weights to XyloSynapse objects. Synapse 1 is used only if enabled
        params.synapses_in = cls._weights_to_synapses(
            config.input.weights,
            config.input.syn2_weights if config.synapse2_enable else None,
        )
        params.synapses_rec = cls._weights_to_synapses(
            config.reservoir.weights,
            config.reservoir.syn2_weights if config.synapse2_enable else None,
        )
        params.synapses_out = cls._weights_to_synapses(config.readout.weights)

        # - Configure reservoir neurons
        params.threshold = []
        params.dash_syn = []
        params.dash_mem = []
        params.aliases = []
        for neuron in config.reservoir.neurons:
            # NOTE(review): this is a truthiness test, so an alias target of 0 is treated
            # as "no alias". Confirm whether neuron 0 can be a valid alias target.
            params.aliases.append([neuron.alias_target] if neuron.alias_target else [])
            params.threshold.append(neuron.threshold)
            params.dash_mem.append(neuron.v_mem_decay)
            params.dash_syn.append([neuron.i_syn_decay, neuron.i_syn2_decay])

        # - Configure readout neurons (single synapse only)
        params.threshold_out = []
        params.dash_syn_out = []
        params.dash_mem_out = []
        for neuron in config.readout.neurons:
            params.threshold_out.append(neuron.threshold)
            params.dash_mem_out.append(neuron.v_mem_decay)
            params.dash_syn_out.append([neuron.i_syn_decay])

        params.weight_shift_inp = config.input.weight_bit_shift
        params.weight_shift_rec = config.reservoir.weight_bit_shift
        params.weight_shift_out = config.readout.weight_bit_shift

        # - Instantiate a Xylo Simulation layer
        mod._xylo_layer = XyloLayer(
            synapses_in=params.synapses_in,
            synapses_rec=params.synapses_rec,
            synapses_out=params.synapses_out,
            aliases=params.aliases,
            threshold=params.threshold,
            threshold_out=params.threshold_out,
            weight_shift_inp=params.weight_shift_inp,
            weight_shift_rec=params.weight_shift_rec,
            weight_shift_out=params.weight_shift_out,
            dash_mem=params.dash_mem,
            dash_mem_out=params.dash_mem_out,
            dash_syns=params.dash_syn,
            dash_syns_out=params.dash_syn_out,
            name="XyloSim_XyloLayer",
        )

        # - Store parameters and return
        mod._xylo_sim_params = params
        return mod

    @classmethod
    def from_specification(
        cls,
        weights_in: np.ndarray,
        weights_out: np.ndarray,
        weights_rec: Optional[np.ndarray] = None,
        dash_mem: Optional[np.ndarray] = None,
        dash_mem_out: Optional[np.ndarray] = None,
        dash_syn: Optional[np.ndarray] = None,
        dash_syn_2: Optional[np.ndarray] = None,
        dash_syn_out: Optional[np.ndarray] = None,
        threshold: Optional[np.ndarray] = None,
        threshold_out: Optional[np.ndarray] = None,
        weight_shift_in: int = 0,
        weight_shift_rec: int = 0,
        weight_shift_out: int = 0,
        aliases: Optional[list] = None,
        dt: float = 1e-3,
        verify_config: bool = True,
        output_mode: str = "Spike",
    ) -> "XyloSim":
        """
        Instantiate a :py:class:`.XyloSim` module from a full set of parameters

        Args:
            weights_in (np.ndarray): An int8 matrix ``(Nin, Nhidden, 2)``, specifying input to hidden neuron connections. The final dimension specifies the inputs to the two available synapses of the hidden neurons.
            weights_out (np.ndarray): An int8 matrix ``(Nhidden, Nout)``, specifying hidden to output connections.
            weights_rec (Optional[np.ndarray]): An int8 matrix ``(Nhidden, Nhidden, 2)``, specifying recurrent connections within the hidden population. The final dimension specifies the input to the two available synapses on each hidden neuron. Default: ``0``, no recurrent connections.
            dash_mem (Optional[np.ndarray]): An int8 matrix ``(Nhidden)``, specifying the bitshift decay value for each hidden neuron membrane potential. Default: ``1``.
            dash_mem_out (Optional[np.ndarray]): An int8 matrix ``(Nout)``, specifying the bitshift decay value for each output neuron membrane potential. Default: ``1``.
            dash_syn (Optional[np.ndarray]): An int8 matrix ``(Nhidden)``, specifying the bitshift decay value for each hidden neuron synaptic current number 1. Default: ``1``.
            dash_syn_2 (Optional[np.ndarray]): An int8 matrix ``(Nhidden)``, specifying the bitshift decay value for each hidden neuron synaptic current number 2. Default: ``1``.
            dash_syn_out (Optional[np.ndarray]): An int8 matrix ``(Nout)``, specifying the bitshift decay value for each output neuron synaptic current. Default: ``1``.
            threshold (Optional[np.ndarray]): An int8 matrix ``(Nhidden)``, specifying the firing threshold for each hidden neuron. Default: ``0``.
            threshold_out (Optional[np.ndarray]): An int8 matrix ``(Nout)``, specifying the firing threshold for each output neuron. Default: ``0``.
            weight_shift_in (int): An integer number of bits to left-shift the input weight matrix
            weight_shift_rec (int): An integer number of bits to left-shift the hidden weight matrix
            weight_shift_out (int): An integer number of bits to left-shift the output weight matrix
            aliases (Optional[list]): Hidden-neuron alias specification, passed unchanged to ``config_from_specification``. Presumably one list of alias target indices per hidden neuron; confirm against that function.
            dt (float): Simulation time step in seconds. Default: 1 ms
            verify_config (bool): Check for a valid configuration before applying it. Default ``True``.
            output_mode (str): One of ``"Spike"``, ``"Vmem"`` or ``"Isyn"``. Selects which output-layer signal :py:meth:`.evolve` returns. Default: ``"Spike"``

        Returns:
            :py:class:`.XyloSim`: A :py:class:`.Module` that emulates the Xylo hardware.

        Raises:
            ModuleNotFoundError: If ``samna`` is not installed
            ValueError: If ``verify_config`` is ``True`` and the configuration is not valid, or if ``output_mode`` is not supported.
        """
        if not backend_available("samna"):
            raise ModuleNotFoundError(
                "`samna` not installed. `samna` is required to generate and validate a HW configuration for Xylo."
            )

        from rockpool.devices.xylo.syns61300 import config_from_specification

        # - Convert specification to xylo configuration
        config, is_valid, status = config_from_specification(
            weights_in=weights_in,
            weights_rec=weights_rec,
            weights_out=weights_out,
            dash_mem=dash_mem,
            dash_mem_out=dash_mem_out,
            dash_syn=dash_syn,
            dash_syn_2=dash_syn_2,
            dash_syn_out=dash_syn_out,
            threshold=threshold,
            threshold_out=threshold_out,
            weight_shift_in=weight_shift_in,
            weight_shift_rec=weight_shift_rec,
            weight_shift_out=weight_shift_out,
            aliases=aliases,
        )

        if verify_config and not is_valid:
            raise ValueError(f"Xylo configuration is not valid: {status}")

        # - Instantiate module from config; the readout mode is forwarded, not stored on the class
        return cls.from_config(config, dt=dt, output_mode=output_mode)

    def evolve(
        self,
        input_raster: np.ndarray = None,
        record: bool = False,
        *args,
        **kwargs,
    ):
        """
        Evolve the simulated Xylo network over an input raster

        Args:
            input_raster (np.ndarray): Input passed directly to ``XyloLayer.evolve``. Presumably a ``(T, Nin)`` raster of input events; TODO confirm.
            record (bool): If ``True``, also return recorded internal signals. Default: ``False``

        Returns:
            (np.ndarray, dict, dict): ``output, {}, recording``. ``output`` is the output spikes, output membrane potentials or output synaptic currents, selected by ``self.output_mode``. ``recording`` is empty unless ``record`` is ``True``.

        Raises:
            ValueError: If ``self.output_mode`` has been set to an unsupported value
        """
        # - Evolve using the xylo layer
        spikes_out = np.array(self._xylo_layer.evolve(input_raster))

        # - Select the requested output-layer signal
        if self.output_mode == "Spike":
            output = spikes_out
        elif self.output_mode == "Vmem":
            output = np.array(self._xylo_layer.rec_v_mem_out).T
        elif self.output_mode == "Isyn":
            output = np.array(self._xylo_layer.rec_i_syn_out).T
        else:
            # - Guard against `output_mode` being changed after construction
            raise ValueError(f"{self.output_mode} is not supported.")

        # - Build the recording dictionary
        if not record:
            recording = {}
        else:
            recording = {
                "Vmem": np.array(self._xylo_layer.rec_v_mem).T,
                "Isyn": np.array(self._xylo_layer.rec_i_syn).T,
                "Isyn2": np.array(self._xylo_layer.rec_i_syn2).T,
                "Spikes": np.array(self._xylo_layer.rec_recurrent_spikes),
                "Vmem_out": np.array(self._xylo_layer.rec_v_mem_out).T,
                "Isyn_out": np.array(self._xylo_layer.rec_i_syn_out).T,
            }

        # - Return output, state and recording dictionary
        return output, {}, recording

    def reset_state(self) -> "XyloSim":
        """Reset the state of this module by calling ``reset_all()`` on the underlying XyloSim layer."""
        self._xylo_layer.reset_all()
        return self

    def _wrap_recorded_state(self, state_dict: dict, t_start: float = 0.0) -> dict:
        """
        Convert a recording dictionary produced by :py:meth:`.evolve` into Rockpool time series

        Args:
            state_dict (dict): Recording dictionary with keys ``"Vmem"``, ``"Isyn"``, ``"Isyn2"``, ``"Spikes"``, ``"Vmem_out"`` and ``"Isyn_out"``
            t_start (float): Start time of the recording, in seconds. Default: ``0.0``

        Returns:
            dict: The same keys, mapped to :py:class:`.TSContinuous` objects (or a :py:class:`.TSEvent` object for ``"Spikes"``) clocked at ``self.dt``
        """
        args = {"dt": self.dt, "t_start": t_start}
        return {
            "Vmem": TSContinuous.from_clocked(
                state_dict["Vmem"], name="$V_{mem}$", **args
            ),
            "Isyn": TSContinuous.from_clocked(
                state_dict["Isyn"], name="$I_{syn}$", **args
            ),
            "Isyn2": TSContinuous.from_clocked(
                state_dict["Isyn2"], name="$I_{syn,2}$", **args
            ),
            "Spikes": TSEvent.from_raster(state_dict["Spikes"], name="Spikes", **args),
            "Vmem_out": TSContinuous.from_clocked(
                state_dict["Vmem_out"], name="$V_{mem,out}$", **args
            ),
            "Isyn_out": TSContinuous.from_clocked(
                state_dict["Isyn_out"], name="$I_{syn,out}$", **args
            ),
        }