# Source code for devices.dynapse.dynapsim_net.from_spec

"""
Obtain a DynapSim network from the spec output of the mapper
See also `rockpool.devices.dynapse.mapper`
"""

from __future__ import annotations
from typing import List, Optional

import numpy as np

from rockpool.nn.modules.module import ModuleBase
from rockpool.nn.modules import LinearJax
from rockpool.nn.combinators import JaxSequential

from rockpool.devices.dynapse.simulation.dynapsim import DynapSim
from rockpool.devices.dynapse.lookup import default_weights

from rockpool.typehints import FloatVector

# Public API of this module: only the network-construction function is exported
__all__ = ["dynapsim_net_from_spec"]


def dynapsim_net_from_spec(
    n_cluster: int,
    core_map: List[int],
    weights_in: Optional[FloatVector],
    weights_rec: Optional[FloatVector],
    # params
    Idc: List[FloatVector],
    If_nmda: List[FloatVector],
    Igain_ahp: List[FloatVector],
    Igain_mem: List[FloatVector],
    Igain_syn: List[FloatVector],
    Ipulse_ahp: List[FloatVector],
    Ipulse: List[FloatVector],
    Iref: List[FloatVector],
    Ispkthr: List[FloatVector],
    Itau_ahp: List[FloatVector],
    Itau_mem: List[FloatVector],
    Itau_syn: List[FloatVector],
    Iw_ahp: List[FloatVector],
    # definitions
    Iscale: float = default_weights["Iscale"],
    percent_mismatch: Optional[float] = None,
    dt: float = 1e-3,
    *args,
    **kwargs,
) -> ModuleBase:
    """
    dynapsim_net_from_spec gets a specification and creates a sequential dynapsim network consisting of a linear layer (virtual connections) and a recurrent layer (hardware connections)

    Every current parameter below is given per cluster (one entry per cluster,
    indexed ``0 .. n_cluster - 1``) and is expanded to a per-neuron array using
    ``core_map``: neuron ``k`` receives the value(s) of cluster ``core_map[k]``.
    Neurons whose core id is not in ``range(n_cluster)`` get ``0.0``.

    :param n_cluster: total number of clusters, neural cores allocated
    :type n_cluster: int
    :param core_map: core map (neuron_id : core_id) for in-device neurons; any integer sequence or array. Its length sets the number of recurrent neurons
    :type core_map: List[int]
    :param weights_in: quantized input weight matrix. If given, it is converted with ``np.array`` and wrapped in a bias-free ``LinearJax`` layer of the same shape; if None, no input layer is created
    :type weights_in: Optional[FloatVector]
    :param weights_rec: quantized recurrent weight matrix. If None, the ``DynapSim`` layer is created without recurrent connections (``has_rec=False``)
    :type weights_rec: Optional[FloatVector]
    :param Idc: a list of Constant DC current injected to membrane in Amperes
    :type Idc: List[FloatVector]
    :param If_nmda: a list of NMDA gate soft cut-off current setting the NMDA gating voltage in Amperes
    :type If_nmda: List[FloatVector]
    :param Igain_ahp: a list of gain bias current of the spike frequency adaptation block in Amperes
    :type Igain_ahp: List[FloatVector]
    :param Igain_mem: a list of gain bias current for neuron membrane in Amperes
    :type Igain_mem: List[FloatVector]
    :param Igain_syn: a list of gain bias current of synaptic gates (AMPA, GABA, NMDA, SHUNT) combined in Amperes
    :type Igain_syn: List[FloatVector]
    :param Ipulse_ahp: a list of bias current setting the pulse width for spike frequency adaptation block ``t_pulse_ahp`` in Amperes
    :type Ipulse_ahp: List[FloatVector]
    :param Ipulse: a list of bias current setting the pulse width for neuron membrane ``t_pulse`` in Amperes
    :type Ipulse: List[FloatVector]
    :param Iref: a list of bias current setting the refractory period ``t_ref`` in Amperes
    :type Iref: List[FloatVector]
    :param Ispkthr: a list of spiking threshold current, neuron spikes if :math:`I_{mem} > I_{spkthr}` in Amperes
    :type Ispkthr: List[FloatVector]
    :param Itau_ahp: a list of Spike frequency adaptation leakage current setting the time constant ``tau_ahp`` in Amperes
    :type Itau_ahp: List[FloatVector]
    :param Itau_mem: a list of Neuron membrane leakage current setting the time constant ``tau_mem`` in Amperes
    :type Itau_mem: List[FloatVector]
    :param Itau_syn: a list of (AMPA, GABA, NMDA, SHUNT) synapses combined leakage current setting the time constant ``tau_syn`` in Amperes
    :type Itau_syn: List[FloatVector]
    :param Iw_ahp: a list of spike frequency adaptation weight current of the neurons of the core in Amperes
    :type Iw_ahp: List[FloatVector]
    :param Iscale: network weight scaling current, defaults to default_weights["Iscale"]
    :type Iscale: float, optional
    :param percent_mismatch: Gaussian parameter mismatch percentage (check `transform.mismatch_generator` implementation), defaults to None
    :type percent_mismatch: Optional[float], optional
    :param dt: The time step for the forward-Euler ODE solver, defaults to 1e-3
    :type dt: float, optional
    :param args: accepted for call-site compatibility and ignored
    :param kwargs: accepted for call-site compatibility and ignored
    :raises IndexError: if a per-cluster parameter list has fewer than ``n_cluster`` entries
    :return: a ``JaxSequential`` combinator encapsulating a ``LinearJax`` layer and a ``DynapSim`` layer, or just the ``DynapSim`` layer when no input weights are defined
    :rtype: ModuleBase
    """
    # ``core_map == i`` below must be an element-wise comparison. A plain
    # Python list (as the type hint allows) would yield a single bool and
    # break ``np.place``, so normalise to an array first.
    core_map = np.asarray(core_map)

    # Expand every per-cluster current to a per-neuron array
    Idc_unc = _expand_clusters(Idc, core_map, n_cluster)
    If_nmda_unc = _expand_clusters(If_nmda, core_map, n_cluster)
    Igain_ahp_unc = _expand_clusters(Igain_ahp, core_map, n_cluster)
    Igain_mem_unc = _expand_clusters(Igain_mem, core_map, n_cluster)
    Igain_syn_unc = _expand_clusters(Igain_syn, core_map, n_cluster)
    Ipulse_ahp_unc = _expand_clusters(Ipulse_ahp, core_map, n_cluster)
    Ipulse_unc = _expand_clusters(Ipulse, core_map, n_cluster)
    Iref_unc = _expand_clusters(Iref, core_map, n_cluster)
    Ispkthr_unc = _expand_clusters(Ispkthr, core_map, n_cluster)
    Itau_ahp_unc = _expand_clusters(Itau_ahp, core_map, n_cluster)
    Itau_mem_unc = _expand_clusters(Itau_mem, core_map, n_cluster)
    Itau_syn_unc = _expand_clusters(Itau_syn, core_map, n_cluster)
    Iw_ahp_unc = _expand_clusters(Iw_ahp, core_map, n_cluster)

    weights_in = np.array(weights_in) if weights_in is not None else None
    weights_rec = np.array(weights_rec) if weights_rec is not None else None

    # Construct the layers
    in_layer = (
        LinearJax(weights_in.shape, weights_in, has_bias=False)
        if weights_in is not None
        else None
    )

    n_rec = len(core_map)
    dynapsim_layer = DynapSim(
        shape=(n_rec, n_rec),
        Idc=Idc_unc,
        If_nmda=If_nmda_unc,
        Igain_ahp=Igain_ahp_unc,
        Igain_mem=Igain_mem_unc,
        Igain_syn=Igain_syn_unc,
        Ipulse_ahp=Ipulse_ahp_unc,
        Ipulse=Ipulse_unc,
        Iref=Iref_unc,
        Ispkthr=Ispkthr_unc,
        Itau_ahp=Itau_ahp_unc,
        Itau_mem=Itau_mem_unc,
        Itau_syn=Itau_syn_unc,
        # Previously computed but never forwarded, silently discarding the
        # caller's adaptation-weight currents
        Iw_ahp=Iw_ahp_unc,
        has_rec=weights_rec is not None,
        w_rec=weights_rec,
        percent_mismatch=percent_mismatch,
        Iscale=Iscale,
        dt=dt,
    )

    # The resulting sequential module ! :tada:
    if in_layer is None:
        return dynapsim_layer
    return JaxSequential(in_layer, dynapsim_layer)


def _expand_clusters(
    values: List[FloatVector], core_map: np.ndarray, n_cluster: int
) -> np.ndarray:
    """
    _expand_clusters converts a list of per-cluster parameter values into a per-neuron float array

    Neuron ``k`` receives ``values[core_map[k]]``. If ``values[c]`` holds several
    entries, they are written in order into the neurons of cluster ``c``
    (``np.place`` repeats them if there are fewer values than neurons). Neurons
    whose core id is not in ``range(n_cluster)`` keep ``0.0``.

    :param values: per-cluster values, indexed ``0 .. n_cluster - 1``
    :type values: List[FloatVector]
    :param core_map: integer array mapping neuron id to core (cluster) id
    :type core_map: np.ndarray
    :param n_cluster: number of clusters to read from ``values``
    :type n_cluster: int
    :return: float array with the same shape as ``core_map``
    :rtype: np.ndarray
    """
    expanded = np.zeros_like(core_map, dtype=float)
    for cluster in range(n_cluster):
        np.place(expanded, core_map == cluster, values[cluster])
    return expanded