Source code for devices.dynapse.simulation.dynapsim

"""
Low level DynapSE-2 simulator neuron model implementation
Solves the characteristic equations to simulate the circuits with optimizable parameters

References

[1] E. Chicca, F. Stefanini, C. Bartolozzi and G. Indiveri,
    "Neuromorphic Electronic Circuits for Building Autonomous Cognitive Systems,"
    in Proceedings of the IEEE, vol. 102, no. 9, pp. 1367-1388, Sept. 2014,
    doi: 10.1109/JPROC.2014.2313954.

[2] C. Bartolozzi and G. Indiveri, “Synaptic dynamics in analog vlsi,” Neural
    Comput., vol. 19, no. 10, p. 2581-2603, Oct. 2007. [Online]. Available:
    https://doi.org/10.1162/neco.2007.19.10.2581

[3] P. Livi and G. Indiveri, “A current-mode conductance-based silicon neuron for
    address-event neuromorphic systems,” in 2009 IEEE International Symposium on
    Circuits and Systems, May 2009, pp. 2898-2901

[4] Dynap-SE1 Neuromorphic Chip Simulator for NICE Workshop 2021
    https://code.ini.uzh.ch/yigit/NICE-workshop-2021

[5] Course: Neuromorphic Engineering 1
    Tobi Delbruck, Shih-Chii Liu, Giacomo Indiveri
    https://tube.switch.ch/channels/88df64b6

[6] Course: 21FS INI508 Neuromorphic Intelligence
    Giacomo Indiveri
    https://tube.switch.ch/switchcast/uzh.ch/series/5ee1d666-25d2-4c4d-aeb9-4b754b880345?order=newest-first
"""

from __future__ import annotations

from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax
from jax import random as rand
from jax.lax import scan
from jax.tree_util import Partial
from jax import numpy as jnp

import sys
import numpy as np

from rockpool.devices.dynapse.lookup import (
    default_layout,
    default_weights,
    default_currents,
)
from rockpool.devices.dynapse.typehints import DynapSimRecord, DynapSimState
from rockpool.devices.dynapse.mapping import DynapseNeurons

from rockpool.typehints import FloatVector
from rockpool.nn.modules.jax.jax_module import JaxModule
from rockpool.nn.modules.native.linear import kaiming
from rockpool.parameters import Parameter, State, SimulationParameter
from rockpool.graph import GraphHolder, LinearWeights, as_GraphHolder
from rockpool.transform.mismatch import mismatch_generator

from .surrogate import step_pwl
from .mismatch_prototype import frozen_mismatch_prototype

__all__ = ["DynapSim"]


class DynapSim(JaxModule):
    """
    DynapSim solves dynamical chip equations for the DPI neuron and synapse models.
    Receives configuration as bias currents and solves membrane and synapse dynamics using ``jax`` backend.
    One block has

        * 1 synapse receiving spikes from the other circuits
        * 1 recurrent synapse for spike frequency adaptation (**AHP**)
        * 1 membrane evaluating the state and deciding fire or not

    For all the synapses, the ``DPI Synapse`` update equations below are solved in parallel.

    :DPI Synapse:

    .. math ::

        I_{syn}(t_1) = \\begin{cases} I_{syn}(t_0) \\cdot exp \\left( \\dfrac{-dt}{\\tau} \\right) &\\text{in any case} \\\\ \\\\ I_{syn}(t_1) + \\dfrac{I_{th} I_{w}}{I_{\\tau}} \\cdot \\left( 1 - exp \\left( \\dfrac{-t_{pulse}}{\\tau} \\right) \\right) &\\text{if a spike arrives} \\end{cases}

    Where

    .. math ::

        \\tau = \\dfrac{C U_{T}}{\\kappa I_{\\tau}}

    For the membrane update, the forward Euler solution below is applied.

    :Membrane:

    .. math ::

        dI_{mem} &= \\dfrac{I_{mem}}{\\tau \\left( I_{mem} + I_{th} \\right) } \\cdot \\left( I_{mem_{\\infty}} + f(I_{mem}) - I_{mem} \\left( 1 + \\dfrac{I_{ahp}}{I_{\\tau}} \\right) \\right) \\cdot dt \\\\\\\\
        I_{mem}(t_1) &= I_{mem}(t_0) + dI_{mem}

    Where

    .. math ::

        I_{mem_{\\infty}} &= \\dfrac{I_{th}}{I_{\\tau}} \\left( I_{in} - I_{ahp} - I_{\\tau}\\right) \\\\\\\\
        f(I_{mem}) &= \\dfrac{I_{a}}{I_{\\tau}} \\left(I_{mem} + I_{th} \\right ) \\\\\\\\
        I_{a} &= \\dfrac{I_{a_{gain}}}{1+ exp\\left(-\\dfrac{I_{mem}+I_{a_{th}}}{I_{a_{norm}}}\\right)} \\\\\\\\

    :On spiking:

    When the membrane potential for neuron :math:`j`, :math:`I_{mem, j}` exceeds the threshold current :math:`I_{spkthr}`, then the neuron emits a spike.

    .. math ::

        I_{mem, j} > I_{spkthr} \\rightarrow S_{j} &= 1 \\\\
        I_{mem, j} &= I_{reset} \\\\

    .. seealso ::

        For detailed explanations of the equations and the usage :ref:`/devices/DynapSE/neuron-model.ipynb`
    """

    def __init__(
        self,
        shape: Union[Tuple[int], int],
        Idc: FloatVector = default_currents["Idc"],
        If_nmda: FloatVector = default_currents["If_nmda"],
        Igain_ahp: FloatVector = default_currents["Igain_ahp"],
        Igain_mem: FloatVector = default_currents["Igain_mem"],
        Igain_syn: FloatVector = default_currents["Igain_ampa"],
        Ipulse_ahp: FloatVector = default_currents["Ipulse_ahp"],
        Ipulse: FloatVector = default_currents["Ipulse"],
        Iref: FloatVector = default_currents["Iref"],
        Ispkthr: FloatVector = default_currents["Ispkthr"],
        Itau_ahp: FloatVector = default_currents["Itau_ahp"],
        Itau_mem: FloatVector = default_currents["Itau_mem"],
        Itau_syn: FloatVector = default_currents["Itau_ampa"],
        Iw_ahp: FloatVector = default_currents["Iw_ahp"],
        C_ahp: FloatVector = default_layout["C_ahp"],
        C_syn: FloatVector = default_layout["C_ampa"],
        C_pulse_ahp: FloatVector = default_layout["C_pulse_ahp"],
        C_pulse: FloatVector = default_layout["C_pulse"],
        C_ref: FloatVector = default_layout["C_ref"],
        C_mem: FloatVector = default_layout["C_mem"],
        Io: FloatVector = default_layout["Io"],
        kappa_n: FloatVector = default_layout["kappa_n"],
        kappa_p: FloatVector = default_layout["kappa_p"],
        Ut: FloatVector = default_layout["Ut"],
        Vth: FloatVector = default_layout["Vth"],
        Iscale: FloatVector = default_weights["Iscale"],
        w_rec: Optional[FloatVector] = None,
        has_rec: bool = False,
        weight_init_func: Optional[Callable[[Tuple], FloatVector]] = kaiming,
        dt: float = 1e-3,
        percent_mismatch: Optional[float] = None,
        rng_key: Optional[FloatVector] = None,
        spiking_input: bool = False,
        spiking_output: bool = True,
        *args,
        **kwargs,
    ) -> None:
        """
        __init__ constructs a DynapSim object

        :param shape: Either a single dimension ``N``, which defines a feed-forward layer of DynapSE AdExpIF neurons, or two dimensions ``(N, N)``, which defines a recurrent layer of DynapSE AdExpIF neurons.
        :type shape: Tuple[int]
        :param Idc: Constant DC current injected to membrane in Amperes with shape (Nrec,)
        :type Idc: FloatVector, optional
        :param If_nmda: NMDA gate soft cut-off current setting the NMDA gating voltage in Amperes with shape (Nrec,)
        :type If_nmda: FloatVector, optional
        :param Igain_ahp: gain bias current of the spike frequency adaptation block in Amperes with shape (Nrec,)
        :type Igain_ahp: FloatVector, optional
        :param Igain_mem: gain bias current for neuron membrane in Amperes with shape (Nrec,)
        :type Igain_mem: FloatVector, optional
        :param Igain_syn: gain bias current of synaptic gates (AMPA, GABA, NMDA, SHUNT) combined in Amperes with shape (Nrec,)
        :type Igain_syn: FloatVector, optional
        :param Ipulse_ahp: bias current setting the pulse width for spike frequency adaptation block ``t_pulse_ahp`` in Amperes with shape (Nrec,)
        :type Ipulse_ahp: FloatVector, optional
        :param Ipulse: bias current setting the pulse width for neuron membrane ``t_pulse`` in Amperes with shape (Nrec,)
        :type Ipulse: FloatVector, optional
        :param Iref: bias current setting the refractory period ``t_ref`` in Amperes with shape (Nrec,)
        :type Iref: FloatVector, optional
        :param Ispkthr: spiking threshold current, neuron spikes if :math:`I_{mem} > I_{spkthr}` in Amperes with shape (Nrec,)
        :type Ispkthr: FloatVector, optional
        :param Itau_ahp: Spike frequency adaptation leakage current setting the time constant ``tau_ahp`` in Amperes with shape (Nrec,)
        :type Itau_ahp: FloatVector, optional
        :param Itau_mem: Neuron membrane leakage current setting the time constant ``tau_mem`` in Amperes with shape (Nrec,)
        :type Itau_mem: FloatVector, optional
        :param Itau_syn: (AMPA, GABA, NMDA, SHUNT) synapses combined leakage current setting the time constant ``tau_syn`` in Amperes with shape (Nrec,)
        :type Itau_syn: FloatVector, optional
        :param Iw_ahp: spike frequency adaptation weight current of the neurons of the core in Amperes with shape (Nrec,)
        :type Iw_ahp: FloatVector, optional
        :param C_ahp: AHP synapse capacitance in Farads with shape (Nrec,)
        :type C_ahp: FloatVector, optional
        :param C_syn: synaptic capacitance in Farads with shape (Nrec,)
        :type C_syn: FloatVector, optional
        :param C_pulse_ahp: spike frequency adaptation circuit pulse-width creation sub-circuit capacitance in Farads with shape (Nrec,)
        :type C_pulse_ahp: FloatVector, optional
        :param C_pulse: pulse-width creation sub-circuit capacitance in Farads with shape (Nrec,)
        :type C_pulse: FloatVector, optional
        :param C_ref: refractory period sub-circuit capacitance in Farads with shape (Nrec,)
        :type C_ref: FloatVector, optional
        :param C_mem: neuron membrane capacitance in Farads with shape (Nrec,)
        :type C_mem: FloatVector, optional
        :param Io: Dark current in Amperes that flows through the transistors even at the idle state with shape (Nrec,)
        :type Io: FloatVector, optional
        :param kappa_n: Subthreshold slope factor (n-type transistor) with shape (Nrec,)
        :type kappa_n: FloatVector, optional
        :param kappa_p: Subthreshold slope factor (p-type transistor) with shape (Nrec,)
        :type kappa_p: FloatVector, optional
        :param Ut: Thermal voltage in Volts with shape (Nrec,)
        :type Ut: FloatVector, optional
        :param Vth: The cut-off Vgs potential of the transistors in Volts (not type specific) with shape (Nrec,)
        :type Vth: FloatVector, optional
        :param Iscale: weight scaling current of the neurons of the core in Amperes
        :type Iscale: FloatVector, optional
        :param w_rec: If the module is initialised in recurrent mode, one can provide a concrete initialisation for the recurrent weights, which must be a square matrix with shape ``(Nrec, Nrec)``. If the model is not initialised in recurrent mode, then you may not provide ``w_rec``, defaults to None
        :type w_rec: Optional[FloatVector], optional
        :param has_rec: When ``True`` the module provides a trainable recurrent weight matrix. ``False``, module is feed-forward, defaults to False
        :type has_rec: bool, optional
        :param weight_init_func: The initialisation function to use when generating weights, gets the shape and returns the initial weights, defaults to kaiming
        :type weight_init_func: Optional[Callable[[Tuple], FloatVector]], optional
        :param dt: The time step for the forward-Euler ODE solver, defaults to 1e-3
        :type dt: float, optional
        :param percent_mismatch: Gaussian parameter mismatch percentage (check `transform.mismatch_generator` implementation), defaults to None
        :type percent_mismatch: Optional[float], optional
        :param rng_key: The Jax RNG seed to use on initialisation. By default, a new seed is generated, defaults to None
        :type rng_key: Optional[FloatVector], optional
        :param spiking_input: Whether this module receives spiking input, defaults to False
        :type spiking_input: bool, optional
        :param spiking_output: Whether this module produces spiking output, defaults to True
        :type spiking_output: bool, optional
        :raises ValueError: `shape` must be a one- or two-element tuple `(Nin, Nout)`
        :raises ValueError: Multapses are not currently supported in DynapSim pipeline!
        """

        # - Check shape argument
        n_shape_elements = np.size(shape)
        if n_shape_elements == 1:
            # A single size N is expanded to the square shape (N, N)
            n_neurons = np.array(shape).item()
            shape = (n_neurons, n_neurons)
        elif n_shape_elements != 2:
            # Rejects both empty shapes and shapes with more than two elements
            raise ValueError(
                "`shape` must be a one- or two-element tuple `(Nin, Nout)`."
            )

        super().__init__(
            *args,
            shape=shape,
            spiking_input=spiking_input,
            spiking_output=spiking_output,
            **kwargs,
        )

        if self.size_in != self.size_out:
            raise ValueError(
                "Multapses are not currently supported in DynapSim pipeline!"
            )

        # - Seed RNG
        if rng_key is None:
            # Explicit 64-bit dtype so the bound is valid regardless of the
            # platform's default integer width
            rng_key = rand.PRNGKey(np.random.randint(0, 2**63, dtype=np.int64))

        def _cast_f32(value: Any) -> jax.Array:
            """Cast any array-like to a float32 jax array"""
            return jnp.array(value, dtype=jnp.float32)

        ### --- States --- ####

        def _state(init_func: Callable) -> State:
            """Build a per-neuron ``State`` with the given initialisation function"""
            return State(
                init_func=init_func,
                shape=(self.size_out,),
                permit_reshape=False,
                cast_fn=_cast_f32,
            )

        def _io_state() -> State:
            """Build a per-neuron ``State`` initialised to the dark current ``Io``"""
            # Fill the reversed shape and transpose back, as the original code does
            return _state(
                lambda s: jnp.full(tuple(reversed(s)), Io, jnp.float32).T
            )

        def _zero_state() -> State:
            """Build a per-neuron ``State`` initialised to zeros"""
            return _state(lambda s: jnp.zeros(s, dtype=jnp.float32))

        ## Data
        self.iahp = _io_state()
        """Spike frequency adaptation current states of the neurons in Amperes with shape (Nrec,)"""

        self.iampa = _io_state()
        """Fast excitatory AMPA synapse current states of the neurons in Amperes with shape (Nrec,)"""

        self.igaba = _io_state()
        """Slow inhibitory adaptation current states of the neurons in Amperes with shape (Nrec,)"""

        self.imem = _io_state()
        """Membrane current states of the neurons in Amperes with shape (Nrec,)"""

        self.inmda = _io_state()
        """Slow excitatory synapse current states of the neurons in Amperes with shape (Nrec,)"""

        self.ishunt = _io_state()
        """Fast inhibitory shunting synapse current states of the neurons in Amperes with shape (Nrec,)"""

        self.spikes = _zero_state()
        """Logical spiking raster for each neuron at the last simulation time-step with shape (Nrec,)"""

        self.timer_ref = _zero_state()
        """timer to keep the time from the spike generation until the refractory period ends"""

        self.vmem = _zero_state()
        """Membrane potential states of the neurons in Volts with shape (Nrec,)"""

        ### --- Parameters --- ###

        # Special handler for wrec
        # `has_rec` may arrive as a jax Tracer, which must not be evaluated as a bool
        if isinstance(has_rec, jax.core.Tracer) or has_rec:
            self.w_rec = Parameter(
                data=w_rec,
                family="weights",
                init_func=weight_init_func,
                shape=(self.size_out, self.size_in),
                permit_reshape=False,
                cast_fn=_cast_f32,
            )
        else:
            # Do not let it break the pipeline
            self.w_rec = SimulationParameter(
                data=jnp.zeros((self.size_out, self.size_in), dtype=jnp.float32),
                family="weights",
            )

        # --- Simulation Parameters --- #

        def _simparam(value: Any) -> SimulationParameter:
            """Build a per-neuron ``SimulationParameter``; scalars are broadcast to (size_out,)"""
            if isinstance(
                value, (np.ndarray, jnp.ndarray, jax.Array, jax.core.Tracer)
            ):
                data = value
            else:
                data = jnp.full((self.size_out,), value)

            return SimulationParameter(
                data=data,
                shape=(self.size_out,),
                permit_reshape=False,
                cast_fn=_cast_f32,
            )

        # -- #
        self.Idc = _simparam(Idc)
        """Constant DC current injected to membrane in Amperes with shape (Nrec,)"""

        self.If_nmda = _simparam(If_nmda)
        """NMDA gate soft cut-off current setting the NMDA gating voltage in Amperes with shape (Nrec,)"""

        self.Igain_ahp = _simparam(Igain_ahp)
        """gain bias current of the spike frequency adaptation block in Amperes with shape (Nrec,)"""

        self.Igain_mem = _simparam(Igain_mem)
        """gain bias current for neuron membrane in Amperes with shape (Nrec,)"""

        self.Igain_syn = _simparam(Igain_syn)
        """gain bias current of synaptic gates (AMPA, GABA, NMDA, SHUNT) combined in Amperes with shape (Nrec,)"""

        self.Ipulse_ahp = _simparam(Ipulse_ahp)
        """bias current setting the pulse width for spike frequency adaptation block ``t_pulse_ahp`` in Amperes with shape (Nrec,)"""

        self.Ipulse = _simparam(Ipulse)
        """bias current setting the pulse width for neuron membrane ``t_pulse`` in Amperes with shape (Nrec,)"""

        self.Iref = _simparam(Iref)
        """bias current setting the refractory period ``t_ref`` in Amperes with shape (Nrec,)"""

        self.Ispkthr = _simparam(Ispkthr)
        """spiking threshold current, neuron spikes if :math:`I_{mem} > I_{spkthr}` in Amperes with shape (Nrec,)"""

        self.Itau_ahp = _simparam(Itau_ahp)
        """Spike frequency adaptation leakage current setting the time constant ``tau_ahp`` in Amperes with shape (Nrec,)"""

        self.Itau_mem = _simparam(Itau_mem)
        """Neuron membrane leakage current setting the time constant ``tau_mem`` in Amperes with shape (Nrec,)"""

        self.Itau_syn = _simparam(Itau_syn)
        """(AMPA, GABA, NMDA, SHUNT) synapses combined leakage current setting the time constant ``tau_syn`` in Amperes with shape (Nrec,)"""

        self.Iw_ahp = _simparam(Iw_ahp)
        """spike frequency adaptation weight current of the neurons of the core in Amperes with shape (Nrec,)"""

        # -- #
        self.C_ahp = _simparam(C_ahp)
        """AHP synapse capacitance in Farads with shape (Nrec,)"""

        self.C_syn = _simparam(C_syn)
        """synaptic capacitance in Farads with shape (Nrec,)"""

        self.C_pulse_ahp = _simparam(C_pulse_ahp)
        """spike frequency adaptation circuit pulse-width creation sub-circuit capacitance in Farads with shape (Nrec,)"""

        self.C_pulse = _simparam(C_pulse)
        """pulse-width creation sub-circuit capacitance in Farads with shape (Nrec,)"""

        self.C_ref = _simparam(C_ref)
        """refractory period sub-circuit capacitance in Farads with shape (Nrec,)"""

        self.C_mem = _simparam(C_mem)
        """neuron membrane capacitance in Farads with shape (Nrec,)"""

        self.Io = _simparam(Io)
        """Dark current in Amperes that flows through the transistors even at the idle state with shape (Nrec,)"""

        self.kappa_n = _simparam(kappa_n)
        """Subthreshold slope factor (n-type transistor) with shape (Nrec,)"""

        self.kappa_p = _simparam(kappa_p)
        """Subthreshold slope factor (p-type transistor) with shape (Nrec,)"""

        self.Ut = _simparam(Ut)
        """Thermal voltage in Volts with shape (Nrec,)"""

        self.Vth = _simparam(Vth)
        """The cut-off Vgs potential of the transistors in Volts (not type specific) with shape (Nrec,)"""

        # -- #
        self.Iscale = SimulationParameter(
            np.array(Iscale, dtype=np.float32), shape=(1,)
        )
        """weight scaling current of the neurons of the core in Amperes"""

        self.dt = SimulationParameter(np.array(dt, dtype=np.float32), shape=(1,))
        """The time step for the forward-Euler ODE solver"""

        self.rng_key = State(rng_key, init_func=lambda _: rng_key)
        """The Jax RNG seed to use on initialisation. By default, a new seed is generated"""

        # One time mismatch
        if percent_mismatch is not None:
            rng_key, _ = rand.split(rng_key)
            prototype = frozen_mismatch_prototype(self)
            regenerate_mismatch = mismatch_generator(
                prototype=prototype, percent_deviation=percent_mismatch
            )
            new_params = regenerate_mismatch(self, rng_key=rng_key)

            # Overwrite the nominal values with their deviated counterparts
            for key, value in new_params.items():
                setattr(self, key, value)

        # - Define additional arguments required during initialisation
        # `Partial` requires a callable, so a missing init function is stored as None
        self._init_args = {
            "has_rec": has_rec,
            "weight_init_func": (
                Partial(weight_init_func) if weight_init_func is not None else None
            ),
        }

    @classmethod
    def from_graph(
        cls, se: DynapseNeurons, weights: Optional[LinearWeights] = None
    ) -> DynapSim:
        """
        from_graph constructs a ``DynapSim`` object from a computational graph

        :param se: the reference computational graph to restore the computational module
        :type se: DynapseNeurons
        :param weights: additional weights graph if one wants to impose recurrent weights, defaults to None
        :type weights: Optional[LinearWeights], optional
        :raises ValueError: Recurrent weight layer biases cannot be defined!
        :return: a ``DynapSim`` object
        :rtype: DynapSim
        """
        if not isinstance(se, DynapseNeurons):
            se = DynapseNeurons._convert_from(se)

        has_weights = weights is not None
        if has_weights and weights.biases is not None:
            raise ValueError("Recurrent weight layer biases cannot be defined!")

        # Every entry reported by the graph is forwarded to the constructor as an array
        kwargs = {k: np.array(v) for k, v in se.get_full().items()}

        return cls(
            shape=(len(se.input_nodes), len(se.output_nodes)),
            Iscale=se.Iscale,
            w_rec=np.array(weights.weights) if has_weights else None,
            has_rec=has_weights,
            dt=se.dt,
            **kwargs,
        )

    def evolve(
        self, input_data: FloatVector, record: bool = True
    ) -> Tuple[jax.Array, Dict[str, jax.Array], Dict[str, jax.Array]]:
        """
        evolve implements raw rockpool JAX evolution function for a DynapSim module.
        The function solves the dynamical equations introduced at the ``DynapSim`` module definition

        :param input_data: Input array to evolve over. It is passed through ``self._auto_batch``, then mapped over its leading axis and scanned over the following (time) axis. Each time slice is added to the per-neuron recurrent input, so it presumably has ``Nrec`` entries per time step (TODO confirm against ``_auto_batch``)
        :type input_data: FloatVector
        :param record: accepted for interface compatibility; it is not read by this implementation and the record dictionary is always returned, defaults to True
        :type record: bool, optional
        :return: spikes_ts, states, record_dict
            :spikes_ts: is an array containing the output data (spike raster) produced by the module at each time step.
            :states: is a dictionary containing the updated module state following evolution.
            :record_dict: is a dictionary containing the recorded state variables (``iahp``, ``imem``, ``isyn``, ``spikes``, ``vmem``) during the evolution at each time step
        :rtype: Tuple[jax.Array, Dict[str, jax.Array], Dict[str, jax.Array]]
        """

        kappa = (self.kappa_n + self.kappa_p) / 2

        # --- Time constant computation utils --- #
        def _pulse_width(ipw, C):
            """Pulse width obtained from a bias current and a capacitance"""
            return (self.Vth * C) / ipw

        def _tau(itau, C):
            """Time constant obtained from a leakage current and a capacitance"""
            return ((self.Ut / kappa) * C.T).T / itau

        def tau_mem(itau):
            """Membrane time constant for the given leakage current"""
            return _tau(itau, self.C_mem)

        # --- Stateless Parameters --- #
        t_ref = _pulse_width(self.Iref, self.C_ref)
        t_pulse = _pulse_width(self.Ipulse, self.C_pulse)
        t_pulse_ahp = _pulse_width(self.Ipulse_ahp, self.C_pulse_ahp)

        ## --- Synapse --- ## Nrec
        # Bias currents are clipped from below at the dark current Io
        Itau_syn_clip = jnp.clip(self.Itau_syn, self.Io)
        Igain_syn_clip = jnp.clip(self.Igain_syn, self.Io)
        tau_syn = _tau(Itau_syn_clip, self.C_syn)

        ## --- Spike frequency adaptation --- ## Nrec
        Itau_ahp_clip = jnp.clip(self.Itau_ahp, self.Io)
        Igain_ahp_clip = jnp.clip(self.Igain_ahp, self.Io)
        tau_ahp = _tau(Itau_ahp_clip, self.C_ahp)

        ## -- Membrane -- ## Nrec
        Itau_mem_clip = jnp.clip(self.Itau_mem, self.Io)
        Igain_mem_clip = jnp.clip(self.Igain_mem, self.Io)

        # --- Time-step invariant factors (computed once, outside the scanned step) --- #
        ## Exponential charge, discharge factors of the synapse
        f_charge = 1.0 - jnp.exp(-t_pulse / tau_syn.T).T
        f_discharge = jnp.exp(-self.dt / tau_syn)

        ## Exponential charge, discharge factors of the AHP block
        f_charge_ahp = 1.0 - jnp.exp(-t_pulse_ahp / tau_ahp)
        f_discharge_ahp = jnp.exp(-self.dt / tau_ahp)

        ## Exponent coefficient of the positive feedback
        _kappa_2 = jnp.power(kappa, 2.0)
        _kappa_prime = _kappa_2 / (kappa + 1.0)

        # Handle Batches
        # NOTE(review): the synaptic current state is read from `self.iampa` here but is
        # returned below under the key "isyn" -- confirm that the updated value is
        # written back to the module state as intended.
        initial_state = (
            self.iahp,
            self.imem,
            self.iampa,
            self.rng_key,
            self.spikes,
            self.timer_ref,
            self.vmem,
        )
        input_data, initial_state = self._auto_batch(input_data, initial_state)

        def forward(
            state: DynapSimState, ws_input: jax.Array
        ) -> Tuple[DynapSimState, DynapSimRecord]:
            """
            forward implements single time-step neuron and synapse dynamics

            :param state: (iahp, imem, isyn, rng_key, spikes, timer_ref, vmem)
                iahp: Spike frequency adaptation currents of each neuron
                imem: Membrane currents of each neuron
                isyn: combined synapse currents of each neuron
                rng_key: The Jax RNG key, passed through unchanged
                spikes: Logical spike raster for each neuron
                timer_ref: Refractory timer of each neuron
                vmem: Membrane voltages of each neuron
            :type state: DynapSimState
            :param ws_input: weighted input spikes of the time step
            :type ws_input: jax.Array
            :return: state, record
                state: Updated state at end of the forward steps
                record: (iahp, imem, isyn, spikes, vmem) values of the time step
            :rtype: Tuple[DynapSimState, DynapSimRecord]
            """
            (
                iahp,
                imem,
                isyn,
                rng_key,
                spikes,
                timer_ref,
                vmem,
            ) = state

            # ---------------------------------- #
            # --- Forward step: DPI SYNAPSES --- #
            # ---------------------------------- #

            ## Real time weight is 0 if no spike, w_rec if spike event occurs
            ws_rec = jnp.dot(self.w_rec.T, spikes).T
            Iws = (ws_rec + ws_input) * self.Iscale

            # isyn_inf is the current that a synapse current would reach with a sufficiently long pulse
            isyn_inf = (Igain_syn_clip / Itau_syn_clip) * Iws
            isyn_inf = jnp.clip(isyn_inf, self.Io)

            ## DISCHARGE in any case
            isyn = f_discharge * isyn

            ## CHARGE if spike occurs -- UNDERSAMPLED -- dt >> t_pulse
            isyn += f_charge * isyn_inf

            # ------------------------------------------------------ #
            # --- Forward step: AHP : Spike Frequency Adaptation --- #
            # ------------------------------------------------------ #

            Iws_ahp = self.Iw_ahp * spikes  # 0 if no spike, Iw_ahp if spike
            iahp_inf = (Igain_ahp_clip / Itau_ahp_clip) * Iws_ahp

            ## DISCHARGE in any case
            iahp = f_discharge_ahp * iahp

            ## CHARGE if spike occurs -- UNDERSAMPLED -- dt >> t_pulse
            iahp += f_charge_ahp * iahp_inf
            iahp = jnp.clip(iahp, self.Io)

            # ------------------------------ #
            # --- Forward step: MEMBRANE --- #
            # ------------------------------ #

            ## Feedback
            f_feedback = jnp.exp(_kappa_prime * (vmem / self.Ut))

            ## Leakage
            Ileak = Itau_mem_clip + iahp

            ## Injection
            Iin = isyn - Ileak + self.Idc
            # No injection while the refractory timer is non-zero
            Iin *= jnp.logical_not(timer_ref.astype(bool)).astype(jnp.float32)
            Iin = jnp.clip(Iin, self.Io)

            ## Steady state current
            imem_inf = (Igain_mem_clip / Itau_mem_clip) * (Iin - Ileak)

            ## Positive feedback
            Ifb = self.Io * f_feedback
            f_imem = ((Ifb) / (Ileak)) * (imem + Igain_mem_clip)

            ## Forward Euler Update
            del_imem = (imem / (tau_mem(Ileak) * (imem + Igain_mem_clip))) * (
                imem_inf + f_imem - (imem * (1.0 + (iahp / Itau_mem_clip)))
            )
            imem = imem + del_imem * self.dt
            imem = jnp.clip(imem, self.Io)

            ## Membrane Potential
            vmem = (self.Ut / kappa) * jnp.log(imem / self.Io)

            # ------------------------------ #
            # --- Spike Generation Logic --- #
            # ------------------------------ #

            ## Detect next spikes (with custom gradient)
            spikes = step_pwl(imem, self.Ispkthr, self.Io)

            ## Reset imem depending on spiking activity
            bool_spikes = jnp.clip(spikes, 0, 1)
            imem = (1.0 - bool_spikes) * imem + bool_spikes * self.Io

            ## Set the refractory timer
            timer_ref -= self.dt
            timer_ref = jnp.clip(timer_ref, 0.0)
            timer_ref = (1.0 - bool_spikes) * timer_ref + bool_spikes * t_ref

            # ------------------------------ #
            # ----------- Output ----------- #
            # ------------------------------ #

            # ! IMPORTANT ! : SHOULD BE IN THE SAME ORDER WITH THE `initial_state` tuple
            state = (
                iahp,
                imem,
                isyn,
                rng_key,
                spikes,
                timer_ref,
                vmem,
            )
            record_ts = (iahp, imem, isyn, spikes, vmem)

            return state, record_ts

        # --- Evolve over spiking inputs --- #

        ## Map over batches
        @jax.vmap
        def scan_time(state, data):
            return scan(forward, state, data)

        ## Scan
        state, record_ts = scan_time(initial_state, input_data)

        # --- Output --- #
        states = {
            "iahp": state[0],
            "imem": state[1],
            "isyn": state[2],
            "rng_key": state[3],
            "spikes": state[4],
            "timer_ref": state[5],
            "vmem": state[6],
        }

        record_dict = {
            "iahp": record_ts[0],
            "imem": record_ts[1],
            "isyn": record_ts[2],
            "spikes": record_ts[3],
            "vmem": record_ts[4],
        }

        return record_ts[3], states, record_dict

    def as_graph(self) -> GraphHolder:
        """
        as_graph returns a computational graph for the simulated Dynap-SE neurons

        :return: a ``GraphHolder`` object wrapping the DynapseNeurons graph.
        :rtype: GraphHolder
        """

        # Get simulated current parameters as flat python lists
        kwargs = {
            name: np.array(getattr(self, name)).flatten().tolist()
            for name in DynapseNeurons.current_attrs()
        }

        # Generate the main computational graph
        neurons = DynapseNeurons._factory(
            size_in=self.size_in,
            size_out=self.size_out,
            name=f"{type(self).__name__}_{self.name}_{id(self)}",
            computational_module=self,
            Iscale=float(np.array(self.Iscale).mean()),
            dt=self.dt,
            **kwargs,
        )

        # - Include recurrent weights if present
        if np.array(self.w_rec).any():
            # - Weights are connected over the existing input and output nodes
            # The instance is not used afterwards; it presumably registers itself on
            # the nodes it is given at construction (TODO confirm in rockpool.graph)
            LinearWeights(
                neurons.output_nodes,
                neurons.input_nodes,
                f"{type(self).__name__}_recurrent_{self.name}_{id(self)}",
                self,
                self.w_rec,
            )

        return as_GraphHolder(neurons)