nn.modules.LIFJax

class nn.modules.LIFJax(*args, **kwargs)

Bases: rockpool.nn.modules.jax.jax_module.JaxModule

A leaky integrate-and-fire spiking neuron model, with a Jax backend

This module implements the update equations:

\[ \begin{align}\begin{aligned}I_{syn} += S_{in}(t) + S_{rec} \cdot W_{rec}\\I_{syn} *= \exp(-dt / \tau_{syn})\\V_{mem} *= \exp(-dt / \tau_{mem})\\V_{mem} += I_{syn} + b + \sigma \zeta(t)\end{aligned}\end{align} \]

where \(S_{in}(t)\) is a vector containing 1 (or a weighted spike) for each input channel that emits a spike at time \(t\); \(S_{rec}(t)\) is a vector containing 1 for each neuron that emitted a spike in the last time step; \(W_{rec}\) is the recurrent weight matrix, if recurrent weights are used; \(b\) is an optional vector of \(N\) bias currents, one per neuron (default 0.0); \(\sigma\zeta(t)\) is a Wiener noise process with standard deviation \(\sigma\) after 1 s; and \(\tau_{mem}\) and \(\tau_{syn}\) are the membrane and synaptic time constants, respectively.
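The update equations above can be sketched for a single time step as follows. This is a minimal NumPy illustration under stated assumptions, not Rockpool's implementation: the helper `lif_step` is hypothetical, it assumes one synapse per neuron (Nin == Nout) and a recurrent weight matrix of shape (N, N), and it draws the noise increment with standard deviation \(\sigma\sqrt{dt}\) per step, which is what a Wiener process with standard deviation \(\sigma\) after 1 s implies. The arithmetic maps directly onto `jax.numpy`; only the random number generation would need `jax.random`.

```python
import numpy as np


def lif_step(isyn, vmem, s_in, s_rec, w_rec, bias, tau_syn, tau_mem, dt,
             noise_std=0.0, rng=None):
    """Hypothetical helper: one sub-threshold LIF update, term by term.

    Assumes one synapse per neuron, so every vector has shape (N,).
    """
    # I_syn += S_in(t) + S_rec . W_rec  (recurrent spikes projected through W_rec)
    isyn = isyn + s_in + s_rec @ w_rec
    # I_syn *= exp(-dt / tau_syn)
    isyn = isyn * np.exp(-dt / tau_syn)
    # V_mem *= exp(-dt / tau_mem)
    vmem = vmem * np.exp(-dt / tau_mem)
    # V_mem += I_syn + b + sigma * zeta(t); a Wiener increment over dt has std sigma * sqrt(dt)
    noise = 0.0
    if noise_std > 0.0:
        rng = np.random.default_rng() if rng is None else rng
        noise = noise_std * np.sqrt(dt) * rng.standard_normal(vmem.shape)
    vmem = vmem + isyn + bias + noise
    return isyn, vmem


# One input event on channel 0, default-like time constants (20 ms) and dt = 1 ms
n = 3
isyn, vmem = lif_step(
    isyn=np.zeros(n), vmem=np.zeros(n),
    s_in=np.array([1.0, 0.0, 0.0]), s_rec=np.zeros(n), w_rec=np.zeros((n, n)),
    bias=0.0, tau_syn=0.02, tau_mem=0.02, dt=1e-3,
)
# vmem[0] equals exp(-0.05), about 0.9512
```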

On spiking

When the membrane potential \(V_{mem, j}\) of neuron \(j\) exceeds the threshold voltage \(V_{thr}\), the neuron emits a spike. On reset, the spiking neuron subtracts its own threshold from its membrane potential.

\[ \begin{align}\begin{aligned}V_{mem, j} > V_{thr} \rightarrow S_{rec,j} = 1\\V_{mem, j} = V_{mem, j} - V_{thr}\end{aligned}\end{align} \]

Neurons therefore share a common resting potential of 0, have individual firing thresholds, and perform a subtractive reset by \(V_{thr}\).
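The threshold test and subtractive reset can be sketched in the same way. The helper `spike_and_reset` is hypothetical and implements the rule exactly as written above (strict `>` comparison, at most one spike per neuron per step); it deliberately ignores `max_spikes_per_dt`, so it illustrates the equations rather than Rockpool's own spiking function.

```python
import numpy as np


def spike_and_reset(vmem, threshold):
    """Hypothetical helper: S_rec,j = 1 where V_mem,j > V_thr,j, then subtract V_thr,j."""
    spikes = (vmem > threshold).astype(vmem.dtype)
    # Subtractive reset: only neurons that spiked lose one threshold's worth of potential
    vmem = vmem - spikes * threshold
    return spikes, vmem


# Per-neuron thresholds broadcast against the membrane vector
spikes, vmem = spike_and_reset(np.array([0.5, 1.2, 2.5]), np.array([1.0, 1.0, 2.0]))
# spikes == [0, 1, 1]; vmem is approximately [0.5, 0.2, 0.5]
```

Because the reset subtracts the threshold instead of clamping the potential to 0, any excess potential carries over: a neuron at 2.5 with a threshold of 1.0 would still sit above threshold (1.5) after one reset.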

__init__(shape: typing.Union[typing.Tuple, int], tau_mem: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, tau_syn: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, bias: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, w_rec: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, has_rec: bool = False, weight_init_func: typing.Optional[typing.Callable[[typing.Tuple], jax._src.numpy.lax_numpy.ndarray]] = <function kaiming>, threshold: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, noise_std: float = 0.0, max_spikes_per_dt: typing.Union[int, rockpool.parameters.ParameterBase] = inf, dt: float = 0.001, rng_key: typing.Optional[typing.Any] = None, spiking_input: bool = False, spiking_output: bool = True, *args, **kwargs)

Instantiate an LIF module

Parameters
  • shape (tuple) – Either a single dimension (Nout,), which defines a feed-forward layer of LIF neurons with equal numbers of synapses and neurons, or two dimensions (Nin, Nout), which defines a layer of Nin synapses and Nout LIF neurons.

  • tau_mem (Optional[np.ndarray]) – An optional array with concrete initialisation data for the membrane time constants. If not provided, 20ms will be used by default.

  • tau_syn (Optional[np.ndarray]) – An optional array with concrete initialisation data for the synaptic time constants. If not provided, 20ms will be used by default.

  • bias (Optional[np.ndarray]) – An optional array with concrete initialisation data for the neuron bias currents. If not provided, 0.0 will be used by default.

  • w_rec (Optional[np.ndarray]) – If the module is initialised in recurrent mode (has_rec = True), you can provide a concrete initialisation for the recurrent weights, which must be a matrix with shape (Nout, Nin). This matrix is square when Nin == Nout.

  • has_rec (bool) – If True, module provides a recurrent weight matrix. Default: False, no recurrent connectivity.

  • weight_init_func (Optional[Callable[[Tuple], np.ndarray]]) – The initialisation function to use when generating weights. Default: Kaiming initialisation (the kaiming function).

  • threshold (FloatVector) – An optional array specifying the firing threshold of each neuron. If not provided, 1. will be used by default.

  • noise_std (float) – The standard deviation, after 1 s, of the noise added to membrane state variables. Default: 0.0 (no noise).

  • max_spikes_per_dt (int) – The maximum number of events that will be produced in a single time-step. Default: np.inf; do not clamp spiking.

  • dt (float) – The time step for the forward-Euler ODE solver. Default: 1ms

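  • rng_key (Optional[Any]) – The Jax RNG seed to use on initialisation. By default, a new seed is generated.

  • spiking_input (bool) – If True, this module receives spiking input. If False, this module expects continuous input. Default: False.

  • spiking_output (bool) – If True, this module sends spiking output. If False, this module sends continuous output. Default: True.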
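The shape convention described above can be summarised in a short sketch. The helper `resolve_lif_shape` is hypothetical; it assumes that in the two-dimensional case Nin is an integer multiple of Nout, so that each neuron receives n_synapses = Nin / Nout synapses, which is consistent with `isyn` having shape (Nout, Nsyn). The last line computes the per-step decay factor implied by the default time constants (20 ms) and time step (1 ms).

```python
import math


def resolve_lif_shape(shape):
    """Hypothetical helper: map an LIF `shape` argument to (Nin, Nout, n_synapses)."""
    if isinstance(shape, int):
        shape = (shape,)
    if len(shape) == 1:
        # (Nout,): feed-forward layer with one synapse per neuron
        n_in = n_out = shape[0]
    elif len(shape) == 2:
        n_in, n_out = shape
    else:
        raise ValueError("`shape` must have one or two elements: (Nout,) or (Nin, Nout)")
    # Assumption of this sketch: synapses are divided evenly between neurons
    if n_in % n_out != 0:
        raise ValueError("Nin must be an integer multiple of Nout")
    return n_in, n_out, n_in // n_out


print(resolve_lif_shape(4))       # (4, 4, 1)
print(resolve_lif_shape((8, 4)))  # (8, 4, 2)

# Per-step decay applied to I_syn and V_mem with the defaults tau = 20 ms, dt = 1 ms
decay = math.exp(-0.001 / 0.020)  # about 0.951
```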

Attributes overview

class_name

Class name of self

full_name

The full name of this module (class plus module name)

name

The name of this module, or an empty string if None

shape

The shape of this module

size

(DEPRECATED) The output size of this module

size_in

The input size of this module

size_out

The output size of this module

spiking_input

If True, this module receives spiking input.

spiking_output

If True, this module sends spiking output.

n_synapses

(int) Number of input synapses per neuron

w_rec

(np.ndarray) Recurrent weights (Nout, Nin)

tau_mem

(np.ndarray) Membrane time constants (Nout,) or ()

tau_syn

(np.ndarray) Synaptic time constants (Nout,) or ()

bias

(np.ndarray) Neuron bias currents (Nout,) or ()

threshold

(np.ndarray) Firing threshold for each neuron (Nout,) or ()

dt

(float) Simulation time-step in seconds

noise_std

(float) Standard deviation of the noise injected on each neuron membrane, measured after 1 s

spikes

(np.ndarray) Spiking state of each neuron (Nout,)

isyn

(np.ndarray) Synaptic current of each neuron (Nout, Nsyn)

vmem

(np.ndarray) Membrane voltage of each neuron (Nout,)

max_spikes_per_dt

(int) Maximum number of events that can be produced in each time-step

Methods overview

__init__(shape[, tau_mem, tau_syn, bias, ...])

Instantiate an LIF module

as_graph()

Convert this module to a computational graph

attributes_named(name)

Search for attributes of this module or its submodules by name

evolve(input_data[, record])

Evolve this module over an input array of shape (T, Nin)

modules()

Return a dictionary of all sub-modules of this module

parameters([family])

Return a nested dictionary of module and submodule Parameters

reset_parameters()

Reset all parameters in this module

reset_state()

Reset the state of this module

set_attributes(new_attributes)

Assign new attributes to this module and submodules

simulation_parameters([family])

Return a nested dictionary of module and submodule SimulationParameters

state([family])

Return a nested dictionary of module and submodule States

timed([output_num, dt, add_events])

Convert this module to a TimedModule

tree_flatten()

Flatten this module tree for Jax

tree_unflatten(aux_data, children)

Unflatten a tree of modules from Jax to Rockpool


_abc_impl = <_abc_data object>
_auto_batch(data: jax._src.numpy.lax_numpy.ndarray, states: typing.Tuple = (), target_shapes: typing.Optional[typing.Tuple] = None) -> (<class 'jax._src.numpy.lax_numpy.ndarray'>, typing.Tuple[jax._src.numpy.lax_numpy.ndarray])

Automatically replicate states over batches and verify input dimensions

Usage:
>>> data, (state0, state1, state2) = self._auto_batch(data, (self.state0, self.state1, self.state2))

This will verify that data has the correct final dimension (i.e. self.size_in). If data has only two dimensions (T, Nin), then it will be augmented to (1, T, Nin). The individual states will be replicated out from shape (a, b, c, ...) to (n_batches, a, b, c, ...) and returned.

Parameters
  • data (np.ndarray) – Input data tensor. Either (batches, T, Nin) or (T, Nin)

  • states (Tuple) – Tuple of state variables. Each will be replicated out over batches by prepending a batch dimension

Returns

(np.ndarray, Tuple[np.ndarray]) data, states
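The batching behaviour described for `_auto_batch` can be sketched with NumPy broadcasting. This is a hypothetical stand-alone function, not the method itself: it takes the expected input size as an explicit `size_in` argument in place of `self.size_in`, and raising `ValueError` on a mismatch is an assumption about how a failed check is reported.

```python
import numpy as np


def auto_batch_sketch(data, states=(), size_in=None):
    """Hypothetical re-implementation of the documented _auto_batch behaviour."""
    data = np.asarray(data)
    # Verify the final dimension against the expected input size
    if size_in is not None and data.shape[-1] != size_in:
        raise ValueError(f"Expected final dimension {size_in}, got {data.shape[-1]}")
    # (T, Nin) -> (1, T, Nin)
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    n_batches = data.shape[0]
    # Replicate each state from (a, b, ...) to (n_batches, a, b, ...);
    # copy() turns the read-only broadcast view into a writeable array
    batched = tuple(
        np.broadcast_to(np.asarray(s), (n_batches,) + np.shape(s)).copy() for s in states
    )
    return data, batched


data, (vmem, isyn) = auto_batch_sketch(
    np.zeros((3, 100, 2)), (np.zeros(2), np.zeros((2, 1))), size_in=2
)
# data.shape == (3, 100, 2); vmem.shape == (3, 2); isyn.shape == (3, 2, 1)
```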

_force_set_attributes

(bool) If True, do not sanity-check attributes when setting.

_get_attribute_family(type_name: str, family: Optional[Union[Tuple, List, str]] = None) dict

Search for attributes of this module and submodules that match a given family

This method can be used to conveniently get all weights for a network, or all time constants, or any other family of parameters. Parameter families are defined simply by a string: "weights" for weights, "taus" for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months’ time).

Parameters
  • type_name (str) – The class of parameters to search for. Must be one of ["Parameter", "SimulationParameter", "State"] or another future subclass of ParameterBase

  • family (Union[str, Tuple[str]]) – A string, or a list or tuple of strings, defining one or more attribute families to search for

Returns

A nested dictionary of attributes that match the provided type_name and family

Return type

dict
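The family mechanism can be illustrated with a plain nested dictionary. In this hypothetical sketch each attribute is stored as a `(type_name, family, value)` tuple and submodules are nested dictionaries; Rockpool's actual attribute registry is structured differently, so this only shows the filtering idea: keep attributes whose type matches and whose family is among those requested, and drop submodules with no matches.

```python
def get_attribute_family(tree, type_name, family=None):
    """Hypothetical helper: filter a nested attribute tree by type and family."""
    if isinstance(family, str):
        family = (family,)
    result = {}
    for name, item in tree.items():
        if isinstance(item, dict):
            # Recurse into submodules, keeping them only if something matched
            sub = get_attribute_family(item, type_name, family)
            if sub:
                result[name] = sub
        else:
            item_type, item_family, value = item
            if item_type == type_name and (family is None or item_family in family):
                result[name] = value
    return result


tree = {
    "tau_mem": ("Parameter", "taus", 0.02),
    "w_rec": ("Parameter", "weights", [[0.0]]),
    "vmem": ("State", None, [0.0]),
    "sub": {"w_out": ("Parameter", "weights", [[1.0]])},
}
print(get_attribute_family(tree, "Parameter", "weights"))
# {'w_rec': [[0.0]], 'sub': {'w_out': [[1.0]]}}
```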

_get_attribute_registry() Tuple[Dict, Dict]

Return or initialise the attribute registry for this module

Returns

registered_attributes, registered_modules

Return type

(tuple)

_has_registered_attribute(name: str) bool

Check if the module has a registered attribute

Parameters

name (str) – The name of the attribute to check

Returns

True if the attribute name is in the attribute registry, False otherwise.

Return type

bool

_in_Module_init

(bool) If this attribute exists and is True, the module is in the __init__ chain.

_name: Optional[str]

Name of this module, if assigned

_register_attribute(name: str, val: rockpool.parameters.ParameterBase)

Record an attribute in the attribute registry

Parameters
  • name (str) – The name of the attribute to register

  • val (ParameterBase) – The ParameterBase subclass object to register. e.g. Parameter, SimulationParameter or State.

_register_module(name: str, mod)

Add a submodule to the module registry

Parameters
  • name (str) – The name of the submodule, extracted from the assigned attribute name

  • mod (JaxModule) – The submodule to register

_reset_attribute(name: str) rockpool.nn.modules.module.ModuleBase

Reset an attribute to its initialisation value

Parameters

name (str) – The name of the attribute to reset

Returns

For compatibility with the functional API

Return type

self (Module)

_rockpool_pytree_registry = []

The internal registry of registered JaxModules

_shape

The shape of this module

_spiking_input: bool

Whether this module receives spiking input

_spiking_output: bool

Whether this module produces spiking output

_submodulenames: List[str]

Registry of sub-module names

_wrap_recorded_state(state_dict: dict, t_start: float = 0.0) dict

Convert a recorded dictionary to a TimeSeries representation

This method is optional, and is provided to make the timed() conversion to a TimedModule work better. Override this method in your custom Module to wrap each element of your recorded state dictionary as a TimeSeries.

Parameters
  • state_dict (dict) – A recorded state dictionary as returned by evolve()

  • t_start (float) – The initial time of the recorded state, to use as the starting point of the time series

Returns

The mapped recorded state dictionary, wrapped as TimeSeries objects

Return type

Dict[str, TimeSeries]
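The idea behind `_wrap_recorded_state` can be sketched without Rockpool, assuming each recorded array has time along its first axis and samples spaced by `dt`. Rockpool's `TimeSeries` classes are not used here: the hypothetical `SimpleSeries` named tuple stands in for them and only stores a time base and the samples, and the function name `wrap_recorded_state` is likewise illustrative.

```python
from typing import Dict, NamedTuple

import numpy as np


class SimpleSeries(NamedTuple):
    """Hypothetical stand-in for a TimeSeries: sample times plus samples."""
    times: np.ndarray
    samples: np.ndarray


def wrap_recorded_state(state_dict: Dict[str, np.ndarray], dt: float,
                        t_start: float = 0.0) -> Dict[str, SimpleSeries]:
    """Attach a clocked time base to every recorded array (time on axis 0)."""
    wrapped = {}
    for name, samples in state_dict.items():
        samples = np.asarray(samples)
        # Sample k is taken at t_start + k * dt
        times = t_start + dt * np.arange(samples.shape[0])
        wrapped[name] = SimpleSeries(times, samples)
    return wrapped


rec = wrap_recorded_state({"vmem": np.zeros((5, 3))}, dt=1e-3, t_start=0.5)
# rec["vmem"].times is approximately [0.5, 0.501, 0.502, 0.503, 0.504]
```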

as_graph() rockpool.graph.graph_base.GraphModuleBase

Convert this module to a computational graph

Returns

The computational graph corresponding to this module

Return type

GraphModuleBase

Raises

NotImplementedError – If as_graph() is not implemented for this subclass

attributes_named(name: Union[Tuple[str], List[str], str]) dict

Search for attributes of this module or its submodules by name

Parameters

name (Union[str, Tuple[str]]) – The name of the attribute to search for

Returns

A nested dictionary of attributes that match name

Return type

dict

bias: P_ndarray

(np.ndarray) Neuron bias currents (Nout,) or ()

property class_name: str

Class name of self

Type

str

dt: P_float

(float) Simulation time-step in seconds

evolve(input_data: jax._src.numpy.lax_numpy.ndarray, record: bool = False) Tuple[jax._src.numpy.lax_numpy.ndarray, dict, dict]
Parameters
  • input_data (np.ndarray) – Input array of shape (T, Nin) to evolve over

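  • record (bool) – If True, record the internal state variables during evolution and return them in record_state. Default: False.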

Returns

output, new_state, record_state: output is an array with shape (T, Nout) containing the output data produced by this module; new_state is a dictionary containing the updated module state following evolution; record_state is a dictionary containing the recorded state variables for this evolution, if the record argument is True.

Return type

(np.ndarray, dict, dict)
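The `(output, new_state, record_state)` contract of `evolve()` can be sketched as a loop over time steps. This is a hypothetical, simplified NumPy version, not Rockpool's Jax implementation: it assumes a single synapse per neuron, unbatched input of shape (T, N), no noise, and the single-spike subtractive reset from the class description, and it ignores `max_spikes_per_dt`. The state keys `isyn`, `vmem` and `spikes` mirror the attribute names documented on this page.

```python
import numpy as np


def evolve_sketch(input_data, state, w_rec, bias=0.0, tau_syn=0.02, tau_mem=0.02,
                  threshold=1.0, dt=1e-3, record=False):
    """Hypothetical evolve(): returns (output, new_state, record_state)."""
    isyn, vmem, spikes = state["isyn"], state["vmem"], state["spikes"]
    syn_decay, mem_decay = np.exp(-dt / tau_syn), np.exp(-dt / tau_mem)
    outputs, rec_isyn, rec_vmem = [], [], []
    for s_in in np.asarray(input_data):
        # Add external and recurrent input to the synapses, then decay
        isyn = (isyn + s_in + spikes @ w_rec) * syn_decay
        # Decay the membrane, then integrate synaptic current and bias
        vmem = vmem * mem_decay + isyn + bias
        # Threshold crossing and subtractive reset
        spikes = (vmem > threshold).astype(float)
        vmem = vmem - spikes * threshold
        outputs.append(spikes)
        if record:
            rec_isyn.append(isyn)
            rec_vmem.append(vmem)
    new_state = {"isyn": isyn, "vmem": vmem, "spikes": spikes}
    record_state = {"isyn": np.stack(rec_isyn), "vmem": np.stack(rec_vmem)} if record else {}
    return np.stack(outputs), new_state, record_state


n = 2
state0 = {"isyn": np.zeros(n), "vmem": np.zeros(n), "spikes": np.zeros(n)}
inp = np.zeros((10, n))
inp[0, 0] = 2.0  # one strongly weighted input event on channel 0 at the first step
out, new_state, rec = evolve_sketch(inp, state0, w_rec=np.zeros((n, n)), record=True)
# out.shape == (10, 2); rec["vmem"].shape == (10, 2)
```

Returning the final state separately from the output is what lets a caller chain calls: passing `new_state` back in continues the simulation from where the previous call stopped.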

property full_name: str

The full name of this module (class plus module name)

Type

str

isyn: P_ndarray

(np.ndarray) Synaptic current of each neuron (Nout, Nsyn)

max_spikes_per_dt: P_int

(int) Maximum number of events that can be produced in each time-step

modules() Dict

Return a dictionary of all sub-modules of this module

Returns

A dictionary containing all sub-modules. Each item will be named with the sub-module name.

Return type

dict

n_synapses

(int) Number of input synapses per neuron

property name: str

The name of this module, or an empty string if None

Type

str

noise_std: P_float

(float) Standard deviation of the noise injected on each neuron membrane, measured after 1 s

parameters(family: Optional[Union[Tuple, List, str]] = None) Dict

Return a nested dictionary of module and submodule Parameters

Use this method to inspect the Parameters from this and all submodules. The optional argument family allows you to search for Parameters in a particular family, for example "weights" for all weights of this module and nested submodules.

Although the family argument is an arbitrary string, reasonable choices are "weights", "taus" for time constants, "biases" for biases…

Examples

Obtain a dictionary of all Parameters for this module (including submodules):

>>> mod.parameters()
dict{ ... }

Obtain a dictionary of Parameters from a particular family:

>>> mod.parameters("weights")
dict{ ... }
Parameters

family (str) – The family of Parameters to search for. Default: None; return all parameters.

Returns

A nested dictionary of Parameters of this module and all submodules

Return type

dict

reset_parameters()

Reset all parameters in this module

Returns

The updated module is returned for compatibility with the functional API

Return type

Module

reset_state() rockpool.nn.modules.jax.jax_module.JaxModule

Reset the state of this module

Returns

The updated module is returned for compatibility with the functional API

Return type

Module

set_attributes(new_attributes: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]) rockpool.nn.modules.jax.jax_module.JaxModule

Assign new attributes to this module and submodules

Parameters

new_attributes (Tree) – The tree of new attributes to assign to this module tree

Return type

JaxModule

property shape: tuple

The shape of this module

Type

tuple

simulation_parameters(family: Optional[Union[Tuple, List, str]] = None) Dict

Return a nested dictionary of module and submodule SimulationParameters

Use this method to inspect the SimulationParameters from this and all submodules. The optional argument family allows you to search for SimulationParameters in a particular family.

Examples

Obtain a dictionary of all SimulationParameters for this module (including submodules):

>>> mod.simulation_parameters()
dict{ ... }
Parameters

family (str) – The family of SimulationParameters to search for. Default: None; return all SimulationParameter attributes.

Returns

A nested dictionary of SimulationParameters of this module and all submodules

Return type

dict

property size: int

(DEPRECATED) The output size of this module

Type

int

property size_in: int

The input size of this module

Type

int

property size_out: int

The output size of this module

Type

int

spikes: P_ndarray

(np.ndarray) Spiking state of each neuron (Nout,)

property spiking_input: bool

If True, this module receives spiking input. If False, this module expects continuous input.

Type

bool

property spiking_output: bool

If True, this module sends spiking output. If False, this module sends continuous output.

Type

bool

state(family: Optional[Union[Tuple, List, str]] = None) Dict

Return a nested dictionary of module and submodule States

Use this method to inspect the States from this and all submodules. The optional argument family allows you to search for States in a particular family.

Examples

Obtain a dictionary of all States for this module (including submodules):

>>> mod.state()
dict{ ... }
Parameters

family (str) – The family of States to search for. Default: None; return all State attributes.

Returns

A nested dictionary of States of this module and all submodules

Return type

dict

tau_mem: P_ndarray

(np.ndarray) Membrane time constants (Nout,) or ()

tau_syn: P_ndarray

(np.ndarray) Synaptic time constants (Nout,) or ()

threshold: P_ndarray

(np.ndarray) Firing threshold for each neuron (Nout,) or ()

timed(output_num: int = 0, dt: Optional[float] = None, add_events: bool = False)

Convert this module to a TimedModule

Parameters
  • output_num (int) – Specify which output of the module to take, if the module returns multiple output series. Default: 0, take the first (or only) output.

  • dt (float) – Used to provide a time-step for this module, if the module does not already have one. If self already defines a time-step, then self.dt will be used. Default: None

  • add_events (bool) – If True, the TimedModule will add (sum) events that occur within a single time step, on both input and output. Default: False, do not add events.

Returns

A timed module that wraps this module

Return type

TimedModule

tree_flatten() Tuple[tuple, tuple]

Flatten this module tree for Jax

classmethod tree_unflatten(aux_data, children)

Unflatten a tree of modules from Jax to Rockpool

vmem: P_ndarray

(np.ndarray) Membrane voltage of each neuron (Nout,)

w_rec: P_ndarray

(np.ndarray) Recurrent weights (Nout, Nin)