Source code for nn.modules.jax.exp_syn_jax

"""
An exponential synapse layer, with a Jax backend.
"""

# - Rockpool imports
from rockpool.nn.modules.jax.jax_module import JaxModule
from rockpool.parameters import Parameter, State, SimulationParameter

# - Other imports
import jax
import jax.numpy as np
from jax.lax import scan

import numpy as onp

from typing import Union, Optional, Tuple

from rockpool import typehints as rt

# - Names exported by a star-import of this module
__all__ = ["ExpSynJax"]


class ExpSynJax(JaxModule):
    r"""
    Exponential synapse module with a Jax backend

    This module implements a layer of exponential synapses, operating under the update equations

    .. math::

        I_{syn} = I_{syn} + i(t)

        I_{syn} = I_{syn} * \exp(-dt / \tau)

        I_{syn} = I_{syn} + \sigma \zeta_t

    where :math:`i(t)` is the instantaneous input; :math:`\tau` is the vector ``(N,)`` of time constants for each synapse in seconds; :math:`dt` is the update time-step in seconds; :math:`\sigma` is the std. deviation after 1s of a Wiener process.

    .. note::
        In :py:meth:`evolve` the noise sample for each time-step is scaled by :math:`\sqrt{dt}` and added to the input *before* the decay step, so the injected noise is also multiplied by :math:`\exp(-dt / \tau)`.
    """

    def __init__(
        self,
        shape: Union[tuple, int],
        tau: Optional[rt.FloatVector] = None,
        noise_std: float = 0.0,
        dt: float = 1e-3,
        rng_key: Optional[rt.JaxRNGKey] = None,
        spiking_input: bool = True,
        spiking_output: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialise an exponential synapse module

        Args:
            shape (Union[int, tuple]): The number of synapses in this module ``(N,)``.
            tau (Optional[np.ndarray]): Concrete initialisation data for the time constants of the synapses, in seconds. Default: 10 ms individual for all synapses.
            noise_std (float): The std. dev after 1s of noise added independently to each synapse. Default: ``0.0``.
            dt (float): The timestep of this module, in seconds. Default: 1 ms.
            rng_key (Optional[JaxRNGKey]): A Jax PRNG key used to initialise the noise generator of this module. Default: ``None``, build a key from a randomly drawn seed.
            spiking_input (bool): Forwarded to the superclass initialiser. Default: ``True``.
            spiking_output (bool): Forwarded to the superclass initialiser. Default: ``False``.
        """

        # - Call super-class initialisation
        super().__init__(
            shape=shape,
            spiking_input=spiking_input,
            spiking_output=spiking_output,
            *args,
            **kwargs,
        )

        # - Seed RNG
        if rng_key is None:
            rng_key = jax.random.PRNGKey(onp.random.randint(0, 2**63))

        # - Split once so the stored key differs from the key handed in by the caller
        _, rng_key = jax.random.split(np.array(rng_key, dtype=np.uint32))

        # - Initialise state
        self.rng_key: Union[np.ndarray, State] = State(
            rng_key, init_func=lambda _: rng_key
        )
        """ (np.ndarray) The Jax PRNG key for this module """

        # - Record parameters
        self.dt: Union[float, SimulationParameter] = SimulationParameter(dt)
        """ (float) Time step for this module """

        self.tau: rt.P_ndarray = Parameter(
            data=tau,
            shape=[(self.size_out,), ()],
            family="taus",
            # - Pass the shape through as-is; unpacking it fails for ``()``
            init_func=lambda s: np.ones(s) * 10e-3,
            cast_fn=np.array,
        )
        """ (np.ndarray) Time constant of each synapse """

        # - Initialise noise std dev
        self.noise_std: rt.P_float = SimulationParameter(noise_std, cast_fn=np.array)
        """ (float) Noise std. dev after 1 second """

        self.isyn: Union[np.array, State] = State(
            shape=self.size_out,
            init_func=np.zeros,
        )
        """ (np.ndarray) Synaptic current state for each synapse ``(N,)`` """

    def evolve(
        self,
        input_data: np.array,
        *args,
        **kwargs,
    ) -> Tuple[np.ndarray, dict, dict]:
        """
        Evolve the synapses over a time series of input

        Args:
            input_data (np.ndarray): Input with shape ``(T, N)`` or ``(B, T, N)``. A two-dimensional input is treated as a single batch.

        Returns:
            Tuple[np.ndarray, dict, dict]: ``(output, new_state, record)``. ``output`` is the synaptic current after each time-step, with shape ``(B, T, N)``. ``new_state`` holds ``"isyn"``, the final synaptic currents of the first batch, and ``"rng_key"``, the advanced PRNG key. ``record`` is an empty dictionary.
        """
        # - Get input shapes, add batch dimension if necessary
        if len(input_data.shape) == 2:
            input_data = np.expand_dims(input_data, 0)
        batches, num_timesteps, _ = input_data.shape

        # - Pre-compute synapse decay beta
        beta = np.exp(-self.dt / self.tau)
        noise_zeta = self.noise_std * np.sqrt(self.dt)

        # - Define synaptic dynamics
        def forward(Isyn, input_t):
            Isyn += input_t
            Isyn *= beta
            return Isyn, Isyn

        # - Generate noise trace; keep the advanced key so that the next call
        #   to `evolve` draws fresh noise instead of repeating this trace
        new_rng_key, subkey = jax.random.split(self.rng_key)
        noise_ts = noise_zeta * jax.random.normal(
            subkey, shape=(batches, num_timesteps, self.size_out)
        )

        # - Replicate states over batches
        isyn = np.broadcast_to(self.isyn, (batches, self.size_out))

        # - Map over batches
        @jax.vmap
        def scan_time(isyn, input_data):
            return scan(forward, isyn, input_data)

        # - Scan over the input
        isyn, output = scan_time(isyn, input_data + noise_ts)

        # - Return output data and state (only the first batch state is kept)
        return output, {"isyn": isyn[0], "rng_key": new_rng_key}, {}