nn.modules.LIFJax
- class nn.modules.LIFJax(*args, **kwargs)
Bases: JaxModule
A leaky integrate-and-fire spiking neuron model, with a Jax backend
This module implements the update equations:
\[
\begin{aligned}
I_{syn} &\mathrel{+}= S_{in}(t) + S_{rec} \cdot W_{rec} \\
I_{syn} &\mathrel{*}= \exp(-dt / \tau_{syn}) \\
V_{mem} &\mathrel{*}= \exp(-dt / \tau_{mem}) \\
V_{mem} &\mathrel{+}= I_{syn} + b + \sigma \zeta(t)
\end{aligned}
\]
where \(S_{in}(t)\) is a vector containing 1 (or a weighted spike) for each input channel that emits a spike at time \(t\); \(S_{rec}(t)\) is a vector containing 1 for each neuron that emitted a spike in the last time-step; \(W_{rec}\) is a recurrent weight matrix, if recurrent weights are used; \(b\) is an optional \(N\)-vector of bias currents, one per neuron (default 0); \(\sigma\zeta(t)\) is a Wiener noise process with standard deviation \(\sigma\) after 1 s; and \(\tau_{mem}\) and \(\tau_{syn}\) are the membrane and synaptic time constants, respectively.
- On spiking:
When the membrane potential for neuron \(j\), \(V_{mem, j}\), exceeds the threshold voltage \(V_{thr}\), the neuron emits a spike. The spiking neuron subtracts its own threshold on reset:
\[
\begin{aligned}
& V_{mem, j} > V_{thr} \rightarrow S_{rec,j} = 1 \\
& V_{mem, j} = V_{mem, j} - V_{thr}
\end{aligned}
\]
Neurons therefore share a common resting potential of \(0\), have individual firing thresholds, and perform subtractive reset of \(-V_{thr}\).
Attributes overview
class_name – Class name of self
full_name – The full name of this module (class plus module name)
name – The name of this module, or an empty string if None
shape – The shape of this module
size – (DEPRECATED) The output size of this module
size_in – The input size of this module
size_out – The output size of this module
spiking_input – If True, this module receives spiking input.
spiking_output – If True, this module sends spiking output.
n_synapses – (int) Number of input synapses per neuron
w_rec – (Tensor) Recurrent weights (Nout, Nin)
tau_mem – (np.ndarray) Membrane time constants (Nout,) or ()
tau_syn – (np.ndarray) Synaptic time constants (Nout,) or ()
bias – (np.ndarray) Neuron bias currents (Nout,) or ()
threshold – (np.ndarray) Firing threshold for each neuron (Nout,) or ()
dt – (float) Simulation time-step in seconds
noise_std – (float) Noise injected on each neuron membrane per time-step
spikes – (np.ndarray) Spiking state of each neuron (Nout,)
isyn – (np.ndarray) Synaptic current of each neuron (Nout, Nsyn)
vmem – (np.ndarray) Membrane voltage of each neuron (Nout,)
max_spikes_per_dt – (float) Maximum number of events that can be produced in each time-step
Methods overview
__init__(shape[, tau_mem, tau_syn, bias, ...]) – Instantiate an LIF module
as_graph() – Convert this module to a computational graph
attributes_named(name) – Search for attributes of this or submodules by name
evolve(input_data[, record]) – Evolve this module over input_data, an input array of shape (T, Nin)
modules() – Return a dictionary of all sub-modules of this module
parameters([family]) – Return a nested dictionary of module and submodule Parameters
reset_parameters() – Reset all parameters in this module
reset_state() – Reset the state of this module
set_attributes(new_attributes) – Assign new attributes to this module and submodules
simulation_parameters([family]) – Return a nested dictionary of module and submodule SimulationParameters
state([family]) – Return a nested dictionary of module and submodule States
timed([output_num, dt, add_events]) – Convert this module to a TimedModule
tree_flatten() – Flatten this module tree for Jax
tree_unflatten(aux_data, children) – Unflatten a tree of modules from Jax to Rockpool
- __init__(shape: Tuple | int, tau_mem: float | numpy.ndarray | torch.Tensor | jax._src.numpy.lax_numpy.array | None = None, tau_syn: float | numpy.ndarray | torch.Tensor | jax._src.numpy.lax_numpy.array | None = None, bias: float | numpy.ndarray | torch.Tensor | jax._src.numpy.lax_numpy.array | None = None, w_rec: float | numpy.ndarray | torch.Tensor | jax._src.numpy.lax_numpy.array | None = None, has_rec: bool = False, weight_init_func: Callable[[Tuple], jax.Array] | None = <function kaiming>, threshold: float | numpy.ndarray | torch.Tensor | jax._src.numpy.lax_numpy.array | None = None, noise_std: float = 0.0, max_spikes_per_dt: float | rockpool.parameters.ParameterBase = 65536.0, dt: float = 0.001, rng_key: Any | None = None, spiking_input: bool = False, spiking_output: bool = True, *args, **kwargs)
Instantiate an LIF module
- Parameters:
shape (tuple) – Either a single dimension (Nout,), which defines a feed-forward layer of LIF neurons with equal amounts of synapses and neurons, or two dimensions (Nin, Nout), which defines a layer of Nin synapses and Nout LIF neurons.
tau_mem (Optional[np.ndarray]) – An optional array with concrete initialisation data for the membrane time constants. If not provided, 20 ms will be used by default.
tau_syn (Optional[np.ndarray]) – An optional array with concrete initialisation data for the synaptic time constants. If not provided, 20 ms will be used by default.
bias (Optional[np.ndarray]) – An optional array with concrete initialisation data for the neuron bias currents. If not provided, 0.0 will be used by default.
w_rec (Optional[np.ndarray]) – If the module is initialised in recurrent mode, you can provide a concrete initialisation for the recurrent weights, which must be a square matrix with shape (Nout, Nin).
has_rec (bool) – If True, the module provides a recurrent weight matrix. Default: False, no recurrent connectivity.
weight_init_func (Optional[Callable[[Tuple], np.ndarray]]) – The initialisation function to use when generating weights. Default: None (Kaiming initialisation).
threshold (FloatVector) – An optional array specifying the firing threshold of each neuron. If not provided, 1.0 will be used by default.
noise_std (float) – The standard deviation, after 1 s, of the noise added to membrane state variables. Default: 0.0 (no noise).
max_spikes_per_dt (float) – The maximum number of events that will be produced in a single time-step. Default: 2**16.
dt (float) – The time step for the forward-Euler ODE solver. Default: 1 ms.
rng_key (Optional[Any]) – The Jax RNG seed to use on initialisation. By default, a new seed is generated.
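As an illustration of what these arguments imply, the sketch below resolves a shape argument into input and output sizes and turns the documented defaults (20 ms time constants, threshold 1.0, dt of 1 ms) into the per-step decay factors \(\exp(-dt/\tau)\) used in the update equations. It is a minimal NumPy sketch, not Rockpool's initialisation code: the helper resolve_lif_config is hypothetical, and scaling noise_std by \(\sqrt{dt}\) per step is an assumption derived from defining it as the standard deviation of a Wiener process after 1 s.

```python
import numpy as np


def resolve_lif_config(shape, tau_mem=None, tau_syn=None, threshold=None,
                       dt=1e-3, noise_std=0.0):
    """Hypothetical helper: resolve LIF arguments into per-step quantities.

    Not part of Rockpool; it only mirrors the defaults documented for __init__.
    """
    # An int or a 1-tuple (Nout,) means one synapse per neuron, so Nin == Nout
    if isinstance(shape, int):
        shape = (shape,)
    if len(shape) == 1:
        n_in = n_out = shape[0]
    elif len(shape) == 2:
        n_in, n_out = shape
    else:
        raise ValueError("shape must be (Nout,) or (Nin, Nout)")

    # Documented defaults: 20 ms time constants and a threshold of 1.0
    tau_mem = np.full(n_out, 20e-3) if tau_mem is None else np.asarray(tau_mem, dtype=float)
    tau_syn = np.full(n_out, 20e-3) if tau_syn is None else np.asarray(tau_syn, dtype=float)
    threshold = np.ones(n_out) if threshold is None else np.asarray(threshold, dtype=float)

    return {
        "size_in": n_in,
        "size_out": n_out,
        # Per-step multiplicative decay factors, exp(-dt / tau)
        "decay_mem": np.exp(-dt / tau_mem),
        "decay_syn": np.exp(-dt / tau_syn),
        "threshold": threshold,
        # Assumption: a Wiener process with std sigma after 1 s has std sigma * sqrt(dt) after dt
        "noise_std_per_step": noise_std * np.sqrt(dt),
    }


cfg = resolve_lif_config((8, 4), noise_std=0.1)
print(cfg["size_in"], cfg["size_out"])  # 8 4
print(cfg["decay_mem"][0])  # approximately 0.9512, i.e. exp(-0.05)
```

With the default 20 ms time constants and a 1 ms step, each state variable keeps about 95% of its value from one step to the next.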
- _abc_impl = <_abc._abc_data object>
- _auto_batch(data: Array, states: Tuple = (), target_shapes: Tuple | None = None) → Tuple[Array, Tuple[Array]]
Automatically replicate states over batches and verify input dimensions
- Usage:
>>> data, (state0, state1, state2) = self._auto_batch(data, (self.state0, self.state1, self.state2))
This will verify that data has the correct final dimension (i.e. self.size_in). If data has only two dimensions (T, Nin), then it will be augmented to (1, T, Nin). The individual states will be replicated out from shape (a, b, c, ...) to (n_batches, a, b, c, ...) and returned.
- Parameters:
data (np.ndarray) – Input data tensor. Either (batches, T, Nin) or (T, Nin)
states (Tuple) – Tuple of state variables. Each will be replicated out over batches by prepending a batch dimension
- Returns:
(np.ndarray, Tuple[np.ndarray]) data, states
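The batching behaviour described above can be sketched in NumPy as follows. This is a hypothetical stand-alone re-implementation for illustration, not the method itself: auto_batch takes size_in as an explicit argument instead of reading self.size_in, and the target_shapes argument is not modelled.

```python
import numpy as np


def auto_batch(data, states=(), size_in=None):
    """Hypothetical sketch of the documented _auto_batch behaviour."""
    data = np.asarray(data)
    # (T, Nin) -> (1, T, Nin): add a leading batch dimension
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    if data.ndim != 3:
        raise ValueError("data must have shape (T, Nin) or (batches, T, Nin)")
    # Verify the final (channel) dimension
    if size_in is not None and data.shape[-1] != size_in:
        raise ValueError(f"expected final dimension {size_in}, got {data.shape[-1]}")

    n_batches = data.shape[0]
    # Replicate each state from (a, b, ...) to (n_batches, a, b, ...)
    batched = tuple(
        np.broadcast_to(np.asarray(s), (n_batches,) + np.shape(s)).copy()
        for s in states
    )
    return data, batched


vmem, isyn = np.zeros(4), np.zeros((4, 1))
data, (vmem_b, isyn_b) = auto_batch(np.ones((3, 100, 4)), (vmem, isyn), size_in=4)
print(data.shape, vmem_b.shape, isyn_b.shape)  # (3, 100, 4) (3, 4) (3, 4, 1)
```

The .copy() after np.broadcast_to gives each batch its own writable state array rather than a read-only broadcast view.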
- _force_set_attributes
(bool) If True, do not sanity-check attributes when setting.
- _get_attribute_family(type_name: str, family: Tuple | List | str | None = None) → dict
Search for attributes of this module and submodules that match a given family
This method can be used to conveniently get all weights for a network; or all time constants; or any other family of parameters. Parameter families are defined simply by a string: "weights" for weights; "taus" for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months' time).
- Parameters:
type_name (str) – The class of parameters to search for. Must be one of ["Parameter", "SimulationParameter", "State"] or another future subclass of ParameterBase
family (Union[str, Tuple[str]]) – A string or list or tuple of strings, that define one or more attribute families to search for
- Returns:
A nested dictionary of attributes that match the provided type_name and family
- Return type:
dict
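To make the family mechanism concrete, the following sketch filters a nested module description by attribute class and family. The plain-dict layout, with each attribute stored as a (type_name, family, value) tuple, and the function get_attribute_family are hypothetical stand-ins for Rockpool's attribute registry; the sketch also drops submodules with no matching attributes, which may differ from the real method.

```python
from typing import Any, Dict, Optional, Sequence, Union


def get_attribute_family(module: Dict[str, Any], type_name: str,
                         family: Optional[Union[str, Sequence[str]]] = None) -> dict:
    """Collect attributes of class type_name (and optionally family) from a module tree.

    module is a hypothetical plain-dict stand-in for a module:
    {"attributes": {name: (type_name, family, value)}, "modules": {name: submodule}}
    """
    # Normalise family to a tuple of strings, or None to accept any family
    if isinstance(family, str):
        family = (family,)
    elif family is not None:
        family = tuple(family)

    result = {}
    for name, (attr_type, attr_family, value) in module.get("attributes", {}).items():
        if attr_type == type_name and (family is None or attr_family in family):
            result[name] = value

    # Recurse into submodules; this sketch omits submodules with no matches
    for sub_name, submodule in module.get("modules", {}).items():
        sub_result = get_attribute_family(submodule, type_name, family)
        if sub_result:
            result[sub_name] = sub_result
    return result


lif = {"attributes": {
    "w_rec": ("Parameter", "weights", "W"),
    "tau_mem": ("Parameter", "taus", 0.02),
    "vmem": ("State", None, 0.0),
}}
net = {"attributes": {}, "modules": {"lif": lif}}
print(get_attribute_family(net, "Parameter", "weights"))  # {'lif': {'w_rec': 'W'}}
```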
- _get_attribute_registry() → Tuple[Dict, Dict]
Return or initialise the attribute registry for this module
- Returns:
registered_attributes, registered_modules
- Return type:
(tuple)
- _has_registered_attribute(name: str) → bool
Check if the module has a registered attribute
- Parameters:
name (str) – The name of the attribute to check
- Returns:
True if the attribute name is in the attribute registry, False otherwise.
- Return type:
bool
- _in_Module_init
(bool) If it exists and is True, indicates that the module is in the __init__ chain.
- _name: str | None
Name of this module, if assigned
- _register_attribute(name: str, val: ParameterBase)
Record an attribute in the attribute registry
- Parameters:
name (str) – The name of the attribute to register
val (ParameterBase) – The ParameterBase subclass object to register, e.g. Parameter, SimulationParameter or State.
- _register_module(name: str, mod)
Add a submodule to the module registry
- Parameters:
name (str) – The name of the submodule, extracted from the assigned attribute name
mod (JaxModule) – The submodule to register
- _reset_attribute(name: str) → ModuleBase
Reset an attribute to its initialisation value
- Parameters:
name (str) – The name of the attribute to reset
- Returns:
self, for compatibility with the functional API
- Return type:
self (Module)
- _shape
The shape of this module
- _spiking_input: bool
Whether this module receives spiking input
- _spiking_output: bool
Whether this module produces spiking output
- _submodulenames: List[str]
Registry of sub-module names
- _wrap_recorded_state(state_dict: dict, t_start: float = 0.0) → dict
Convert a recorded dictionary to a TimeSeries representation
This method is optional, and is provided to make the timed() conversion to a TimedModule work better. You should override this method in your custom Module, to wrap each element of your recorded state dictionary as a TimeSeries.
- Parameters:
state_dict (dict) – A recorded state dictionary as returned by evolve()
t_start (float) – The initial time of the recorded state, to use as the starting point of the time series
- Returns:
The mapped recorded state dictionary, wrapped as TimeSeries objects
- Return type:
Dict[str, TimeSeries]
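The kind of wrapping such an override might perform is sketched below. SimpleTimeSeries is a hypothetical stand-in for Rockpool's TimeSeries classes, and the sketch assumes that each recorded array has time on its first axis, sampled every dt seconds starting at t_start.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleTimeSeries:
    """Hypothetical stand-in for a Rockpool TimeSeries: sample times plus samples."""
    times: np.ndarray
    samples: np.ndarray


def wrap_recorded_state(state_dict, dt, t_start=0.0):
    """Wrap each recorded array (time on the first axis) as a SimpleTimeSeries."""
    wrapped = {}
    for name, data in state_dict.items():
        data = np.asarray(data)
        # Sample k is assumed to be taken at t_start + k * dt
        times = t_start + np.arange(data.shape[0]) * dt
        wrapped[name] = SimpleTimeSeries(times=times, samples=data)
    return wrapped


recorded = {"vmem": np.zeros((4, 2)), "isyn": np.zeros((4, 2))}
ts = wrap_recorded_state(recorded, dt=1e-3, t_start=0.5)
print(ts["vmem"].times)
```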
- as_graph() → GraphModuleBase
Convert this module to a computational graph
- Returns:
The computational graph corresponding to this module
- Return type:
GraphModuleBase
- Raises:
NotImplementedError – If as_graph() is not implemented for this subclass
- attributes_named(name: Tuple[str] | List[str] | str) → dict
Search for attributes of this or submodules by name
- Parameters:
name (Union[str, Tuple[str]]) – The name of the attribute to search for
- Returns:
A nested dictionary of attributes that match name
- Return type:
dict
- bias: P_ndarray
(np.ndarray) Neuron bias currents (Nout,) or ()
- property class_name: str
Class name of self
- Type:
str
- dt: P_float
(float) Simulation time-step in seconds
- evolve(input_data: Array, record: bool = False) → Tuple[Array, dict, dict]
- Parameters:
input_data (np.ndarray) – Input array of shape (T, Nin) to evolve over
record (bool) – If True, also record the internal state variables during evolution and return them in record_state. Default: False
- Returns:
output, new_state, record_state
output is an array with shape (T, Nout) containing the output data produced by this module. new_state is a dictionary containing the updated module state following evolution. record_state will be a dictionary containing the recorded state variables for this evolution, if the record argument is True.
- Return type:
(np.ndarray, dict, dict)
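A minimal NumPy sketch of what evolve computes, following the update and reset equations given in the class description, is shown below. It is not the Jax implementation used by LIFJax: the function lif_evolve is hypothetical, it assumes a single batch and one synapse per neuron (Nin == Nout), it emits at most one spike per neuron per step (max_spikes_per_dt is not modelled), and scaling the noise by \(\sqrt{dt}\) per step is an assumption.

```python
import numpy as np


def lif_evolve(input_data, w_rec=None, tau_mem=0.02, tau_syn=0.02, bias=0.0,
               threshold=1.0, noise_std=0.0, dt=1e-3, record=False, rng=None):
    """Hypothetical NumPy sketch of the documented LIF update and reset equations.

    Assumes a single batch and one synapse per neuron, so input_data is (T, N).
    """
    input_data = np.asarray(input_data, dtype=float)
    T, n = input_data.shape
    rng = np.random.default_rng() if rng is None else rng
    decay_syn = np.exp(-dt / np.asarray(tau_syn, dtype=float))
    decay_mem = np.exp(-dt / np.asarray(tau_mem, dtype=float))

    isyn, vmem, spikes = np.zeros(n), np.zeros(n), np.zeros(n)
    output = np.zeros((T, n))
    rec_vmem, rec_isyn = np.zeros((T, n)), np.zeros((T, n))

    for t in range(T):
        # I_syn += S_in(t) + S_rec . W_rec (recurrent term only when w_rec is given)
        isyn = isyn + input_data[t]
        if w_rec is not None:
            isyn = isyn + spikes @ w_rec
        # I_syn *= exp(-dt / tau_syn); V_mem *= exp(-dt / tau_mem)
        isyn = isyn * decay_syn
        vmem = vmem * decay_mem
        # V_mem += I_syn + b + sigma * zeta(t); per-step noise std sigma * sqrt(dt) is assumed
        vmem = vmem + isyn + bias + noise_std * np.sqrt(dt) * rng.standard_normal(n)
        # Spike where V_mem > V_thr, then subtract the threshold (at most one spike per step)
        spikes = (vmem > threshold).astype(float)
        vmem = vmem - spikes * threshold
        output[t] = spikes
        rec_vmem[t], rec_isyn[t] = vmem, isyn

    new_state = {"vmem": vmem, "isyn": isyn, "spikes": spikes}
    record_state = {"vmem": rec_vmem, "isyn": rec_isyn} if record else {}
    return output, new_state, record_state


# A single input event of weight 2.0 on neuron 0 at the first time-step
inp = np.zeros((5, 2))
inp[0, 0] = 2.0
out, state, rec = lif_evolve(inp, record=True)
print(out.shape, sorted(rec))  # (5, 2) ['isyn', 'vmem']
```

Returning an empty dictionary for record_state when record is False is a choice made for this sketch.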
- property full_name: str
The full name of this module (class plus module name)
- Type:
str
- isyn: P_ndarray
(np.ndarray) Synaptic current of each neuron (Nout, Nsyn)
- max_spikes_per_dt: P_float
(float) Maximum number of events that can be produced in each time-step
- modules() → Dict
Return a dictionary of all sub-modules of this module
- Returns:
A dictionary containing all sub-modules. Each item will be named with the sub-module name.
- Return type:
dict
- n_synapses
(int) Number of input synapses per neuron
- property name: str
The name of this module, or an empty string if None
- Type:
str
- noise_std: P_float
(float) Noise injected on each neuron membrane per time-step
- parameters(family: Tuple | List | str | None = None) → Dict
Return a nested dictionary of module and submodule Parameters
Use this method to inspect the Parameters from this and all submodules. The optional argument family allows you to search for Parameters in a particular family, for example "weights" for all weights of this module and nested submodules.
Although the family argument is an arbitrary string, reasonable choices are "weights", "taus" for time constants, "biases" for biases, and so on.
Examples
Obtain a dictionary of all Parameters for this module (including submodules):
>>> mod.parameters()
dict{ ... }
Obtain a dictionary of Parameters from a particular family:
>>> mod.parameters("weights")
dict{ ... }
- Parameters:
family (str) – The family of Parameters to search for. Default: None; return all parameters.
- Returns:
A nested dictionary of Parameters of this module and all submodules
- Return type:
dict
- reset_parameters()
Reset all parameters in this module
- Returns:
The updated module is returned for compatibility with the functional API
- Return type:
- reset_state() → JaxModule
Reset the state of this module
- Returns:
The updated module is returned for compatibility with the functional API
- Return type:
JaxModule
- set_attributes(new_attributes: Iterable | MutableMapping | Mapping) → JaxModule
Assign new attributes to this module and submodules
- Parameters:
new_attributes (Tree) – The tree of new attributes to assign to this module tree
- Return type:
JaxModule
- property shape: tuple
The shape of this module
- Type:
tuple
- simulation_parameters(family: Tuple | List | str | None = None) → Dict
Return a nested dictionary of module and submodule SimulationParameters
Use this method to inspect the SimulationParameters from this and all submodules. The optional argument family allows you to search for SimulationParameters in a particular family.
Examples
Obtain a dictionary of all SimulationParameters for this module (including submodules):
>>> mod.simulation_parameters()
dict{ ... }
- Parameters:
family (str) – The family of SimulationParameters to search for. Default: None; return all SimulationParameter attributes.
- Returns:
A nested dictionary of SimulationParameters of this module and all submodules
- Return type:
dict
- property size: int
(DEPRECATED) The output size of this module
- Type:
int
- property size_in: int
The input size of this module
- Type:
int
- property size_out: int
The output size of this module
- Type:
int
- spikes: P_ndarray
(np.ndarray) Spiking state of each neuron (Nout,)
- property spiking_input: bool
If True, this module receives spiking input. If False, this module expects continuous input.
- Type:
bool
- property spiking_output
If True, this module sends spiking output. If False, this module sends continuous output.
- Type:
bool
- state(family: Tuple | List | str | None = None) → Dict
Return a nested dictionary of module and submodule States
Use this method to inspect the States from this and all submodules. The optional argument family allows you to search for States in a particular family.
Examples
Obtain a dictionary of all States for this module (including submodules):
>>> mod.state()
dict{ ... }
- Parameters:
family (str) – The family of States to search for. Default: None; return all State attributes.
- Returns:
A nested dictionary of States of this module and all submodules
- Return type:
dict
- tau_mem: P_ndarray
(np.ndarray) Membrane time constants (Nout,) or ()
- tau_syn: P_ndarray
(np.ndarray) Synaptic time constants (Nout,) or ()
- threshold: P_ndarray
(np.ndarray) Firing threshold for each neuron (Nout,) or ()
- timed(output_num: int = 0, dt: float | None = None, add_events: bool = False)
Convert this module to a TimedModule
- Parameters:
output_num (int) – Specify which output of the module to take, if the module returns multiple output series. Default: 0, take the first (or only) output.
dt (float) – Used to provide a time-step for this module, if the module does not already have one. If self already defines a time-step, then self.dt will be used. Default: None
add_events (bool) – Iff True, the TimedModule will add events occurring on a single timestep on input and output. Default: False, don't add time steps.
- Returns:
TimedModule: A timed module that wraps this module
- tree_flatten() → Tuple[tuple, tuple]
Flatten this module tree for Jax
- classmethod tree_unflatten(aux_data, children)
Unflatten a tree of modules from Jax to Rockpool
- vmem: P_ndarray
(np.ndarray) Membrane voltage of each neuron (Nout,)
- w_rec: P_ndarray
(Tensor) Recurrent weights (Nout, Nin)