Source code for devices.dynapse.dynapsim_net.from_config.getter

"""
Dynap-SE simulator generator from the samna config object
"""

from typing import Dict, Optional

import numpy as np

from rockpool.nn.modules import JaxModule
from rockpool.nn.modules import LinearJax
from rockpool.nn.combinators import JaxSequential

from rockpool.devices.dynapse.samna_alias import (
    Dynapse2Configuration,
    Dynapse2Destination,
)
from rockpool.devices.dynapse.parameters import DynapSimCore
from rockpool.devices.dynapse.lookup import default_weights
from rockpool.devices.dynapse.simulation import DynapSim

from .parameter import ParameterHandler

__all__ = ["dynapsim_net_from_config"]


def dynapsim_net_from_config(
    config: Dynapse2Configuration,
    input_channel_map: Optional[Dict[int, Dynapse2Destination]] = None,
    Iscale: float = default_weights["Iscale"],
    percent_mismatch: Optional[float] = None,
    dt: float = 1e-3,
    *args,
    **kwargs,
) -> JaxModule:
    """
    dynapsim_net_from_config builds a `DynapSim` simulation network out of a samna configuration object

    :param config: the samna configuration object that the parameter handler and the simulation cores are derived from
    :type config: Dynapse2Configuration
    :param input_channel_map: the mapping between input timeseries channels and their destinations. When given, it is used to re-insert the all-zero rows of the input weight matrix, so that the matrix regains its original dimensions, defaults to None
    :type input_channel_map: Optional[Dict[int, Dynapse2Destination]], optional
    :param Iscale: network weight scaling current, defaults to default_weights["Iscale"]
    :type Iscale: float, optional
    :param percent_mismatch: Gaussian parameter mismatch percentage forwarded to the `DynapSim` layer (check `transforms.mismatch_generator` implementation), defaults to None
    :type percent_mismatch: Optional[float], optional
    :param dt: the time step handed to the `DynapSim` layer, defaults to 1e-3
    :type dt: float, optional
    :return: a `JaxSequential` combinator chaining a `LinearJax` input layer and a `DynapSim` layer, or the bare `DynapSim` layer when the scaled input weight matrix has no non-zero entry
    :rtype: JaxModule
    """
    # Extra positional and keyword arguments are accepted but not used.

    # The handler drives the whole network reconstruction
    handler = ParameterHandler.from_config(config)

    # One simulation core per (chip, core) pair known to the handler
    cores = {
        (h, c): DynapSimCore.from_Dynapse2Core(config.chips[h].cores[c])
        for h, c in handler.core_list
    }

    # Per-neuron current parameters, in the order they are collected
    current_names = (
        "Idc",
        "If_nmda",
        "Igain_ahp",
        "Igain_mem",
        "Igain_syn",
        "Ipulse_ahp",
        "Ipulse",
        "Iref",
        "Ispkthr",
        "Itau_ahp",
        "Itau_mem",
        "Itau_syn",
        "Iw_ahp",
    )

    # These two are composed by the handler per neuron instead of being read
    # straight from the core object
    composed = {
        "Igain_syn": handler.compose_Igain_syn,
        "Itau_syn": handler.compose_Itau_syn,
    }

    currents = {name: [] for name in current_names}
    iw_traces = []

    # Walk the neurons and gather the currents of the core each one lives in
    for neuron_idx, (h, c) in enumerate(handler.core_map):
        core = cores[(h, c)]
        for name in current_names:
            if name in composed:
                value = composed[name](core, neuron_idx)
            else:
                value = getattr(core, name)
            currents[name].append(value)
        iw_traces.append([core.weight_bits.Iw])

    # Restored and scaled weight matrices, computed from the Iw traces
    w_in_scaled = handler.get_scaled_weights_in(iw_traces, Iscale)
    w_rec_scaled = handler.get_scaled_weights_rec(iw_traces, Iscale)

    # Input layer (external world -> hardware), only if any input weight is non-zero
    in_layer = None
    if w_in_scaled.any():
        if input_channel_map is not None:
            # Bring back the zero rows dropped from the configuration
            w_in_scaled = __restore_zero_rows(w_in_scaled, input_channel_map)
        in_layer = LinearJax(w_in_scaled.shape, w_in_scaled, has_bias=False)

    # Recurrent layer (hardware -> hardware)
    dynapsim_layer = DynapSim(
        shape=handler.n_rec,
        **{name: np.array(values) for name, values in currents.items()},
        w_rec=w_rec_scaled,
        has_rec=w_rec_scaled.any(),
        percent_mismatch=percent_mismatch,
        Iscale=Iscale,
        dt=dt,
    )

    # Without an input layer the simulator layer is the whole network
    if in_layer is None:
        return dynapsim_layer
    return JaxSequential(in_layer, dynapsim_layer)
### --- Private Section --- ###


def __restore_zero_rows(
    trimmed_weights: np.ndarray,
    channel_map: Dict[int, Dynapse2Destination],
) -> np.ndarray:
    """
    __restore_zero_rows restores the zero rows that occur in the simulation weights but are lost in the config object

    The configuration object does not store any connection information if all weights from one input channel to all neurons are zero.
    The input channel map, however, does. This function exploits the input channel map to restore the zero rows.

    :param trimmed_weights: the weight matrix containing only the rows with at least one non-zero entry
    :type trimmed_weights: np.ndarray
    :param channel_map: the mapping between input timeseries channels and the hardware destinations. Channels whose destination is falsy are treated as zero rows
    :type channel_map: Dict[int, Dynapse2Destination]
    :raises ValueError: if the number of rows in ``trimmed_weights`` differs from the number of channels with a truthy destination in ``channel_map``
    :return: the restored weight matrix with (possibly) zero rows, one row per entry of ``channel_map``
    :rtype: np.ndarray
    """
    # Channels with a truthy destination are the rows that survived trimming
    non_zero_idx = [key for key, val in channel_map.items() if val]

    # Without this check, a single-row matrix would be silently broadcast into
    # every selected row, and any other mismatch would surface as an opaque
    # NumPy broadcasting error.
    n_trimmed_rows = trimmed_weights.shape[0]
    if n_trimmed_rows != len(non_zero_idx):
        raise ValueError(
            f"trimmed_weights has {n_trimmed_rows} rows but channel_map defines "
            f"{len(non_zero_idx)} channels with a destination!"
        )

    # Start from an all-zero matrix with one row per input channel
    shape = (len(channel_map), trimmed_weights.shape[1])
    weights = np.zeros(shape)

    # Place the trimmed rows at the positions of their channels
    weights[non_zero_idx, :] = trimmed_weights

    return weights