nn.modules.RateJax
- class nn.modules.RateJax(*args, **kwargs)
Bases: rockpool.nn.modules.jax.jax_module.JaxModule
Encapsulates a population of rate neurons, supporting feed-forward and recurrent modules, with a Jax backend.
Examples
Instantiate a feed-forward module with 8 neurons:
>>> import numpy as np
>>> from rockpool.nn.modules import RateJax
>>> mod = RateJax(8)
>>> mod
RateEulerJax 'None' with shape (8,)
Instantiate a recurrent module with 12 neurons:
>>> mod_rec = RateJax(12, has_rec=True)
>>> mod_rec
RateEulerJax 'None' with shape (12,)
Instantiate a feed-forward module with defined time constants (all non-zero, since each time constant divides the state update):
>>> mod = RateJax(7, tau=(np.arange(7) + 1) * 10e-3)
>>> mod
RateEulerJax 'None' with shape (7,)
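This module implements the update equations:
\[ \begin{aligned}
\dot{X} &= -X + i(t) + W_{rec} H(X, \theta) + \mathrm{bias} + \sigma \zeta_t \\
X &\leftarrow X + \dot{X} \cdot dt / \tau \\
H(x, \theta) &= \mathrm{relu}(x, \theta) = (x - \theta) \cdot \left((x - \theta) > 0\right)
\end{aligned} \]
where \(\theta\) is the unit threshold, \(\sigma\) is the noise standard deviation, \(\zeta_t\) is normally-distributed noise and \(dt\) is the Euler time step.
Attributes overview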
- class_name: Class name of self
- full_name: The full name of this module (class plus module name)
- name: The name of this module, or an empty string if None
- shape: The shape of this module
- size: (DEPRECATED) The output size of this module
- size_in: The input size of this module
- size_out: The output size of this module
- spiking_input: If True, this module receives spiking input.
- spiking_output: If True, this module sends spiking output.
- rng_key: The Jax PRNG key for this module
- x: A vector (N,) of the internal state of each unit
- w_rec: The recurrent weight matrix (N, N) for this module
- tau: The vector (N,) of time constants \(\tau\) for each unit
- bias: The vector (N,) of bias currents for each unit
- threshold: (Tensor) Unit thresholds, shape (Nout,) or ()
- dt: The Euler solver time step for this module
- noise_std: The std. dev. \(\sigma\) of noise added to internal neuron states at each time step
- act_fn: (Callable) Activation function
Methods overview
- __init__(shape[, tau, bias, threshold, ...]): Instantiate a non-spiking rate module, either feed-forward or recurrent.
- as_graph(): Convert this module to a computational graph
- attributes_named(name): Search for attributes of this module or its submodules by name
- evolve(input_data[, record]): Evolve the state of this module over input data
- modules(): Return a dictionary of all sub-modules of this module
- parameters([family]): Return a nested dictionary of module and submodule Parameters
- reset_parameters(): Reset all parameters in this module
- reset_state(): Reset the state of this module
- set_attributes(new_attributes): Assign new attributes to this module and submodules
- simulation_parameters([family]): Return a nested dictionary of module and submodule SimulationParameters
- state([family]): Return a nested dictionary of module and submodule States
- timed([output_num, dt, add_events]): Convert this module to a TimedModule
- tree_flatten(): Flatten this module tree for Jax
- tree_unflatten(aux_data, children): Unflatten a tree of modules from Jax to Rockpool
- __init__(shape: typing.Union[int, typing.Tuple[jax._src.numpy.lax_numpy.ndarray]], tau: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, bias: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, threshold: typing.Optional[typing.Union[float, numpy.ndarray, torch.Tensor]] = None, w_rec: typing.Optional[jax._src.numpy.lax_numpy.ndarray] = None, weight_init_func: typing.Callable = <function unit_eigs>, has_rec: bool = False, activation_func: typing.Union[str, typing.Callable] = <function H_ReLU>, noise_std: float = 0.0, dt: float = 0.001, rng_key: typing.Optional[int] = None, *args: list, **kwargs: dict)
Instantiate a non-spiking rate module, either feed-forward or recurrent.
- Parameters
shape (Tuple[np.ndarray]) - A tuple containing the shape of this module. If one dimension (N,) is provided, it defines the number of neurons in a feed-forward layer. If two dimensions are provided, a recurrent layer is defined; in that case the two dimensions must be identical, (N, N).
tau (float) - A scalar or vector defining the initialisation time constants for the module. If a vector is provided, it must match the output size of the module. Default: 20 ms
bias (float) - A scalar or vector defining the initialisation bias values for the module. If a vector is provided, it must match the output size of the module. Default: 0.
w_rec (np.ndarray) - An optional matrix defining the initialisation recurrent weights for the module. Default: Normal / sqrt(N)
has_rec (bool) - Iff True, the module operates in recurrent mode. Default: False, operate in feed-forward mode.
weight_init_func (Callable) - A function used to initialise the recurrent weights, if used. Default: unit_eigs(); initialise such that recurrent feedback has eigenvalues distributed within the unit circle.
activation_func (Union[str, Callable]) - The activation function of the neurons. This can be provided as a string ['ReLU', 'sigmoid', 'tanh'], or as a function that accepts a vector of neural states and returns the vector of output activations. This function must use jax.numpy math functions, not numpy math functions. Default: 'ReLU'.
dt (float) - The Euler solver time step. Default: 1e-3
noise_std (float) - The std. dev. of normally-distributed noise added to the neural state at each time step. Default: 0.
rng_key (Any) - A Jax PRNG key to initialise the module with. Default: not provided; the module PRNG is initialised with a random seed.
*args - Additional positional arguments
**kwargs - Additional keyword arguments
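The default recurrent initialisation is described as Normal / sqrt(N), chosen so that recurrent feedback has eigenvalues within the unit circle. The NumPy sketch below illustrates that idea only; unit_eigs_sketch is a hypothetical helper, not Rockpool's unit_eigs, whose exact implementation (including the RNG it uses) may differ.

```python
import numpy as np


def unit_eigs_sketch(n: int, seed: int = 0) -> np.ndarray:
    """Hypothetical stand-in for the default recurrent-weight initialiser.

    Draws an (n, n) matrix of standard-normal samples and scales it by
    1 / sqrt(n), so that its eigenvalues fall approximately inside the
    unit circle for large n.
    """
    rng = np.random.default_rng(seed)
    return rng.standard_normal((n, n)) / np.sqrt(n)


w_rec = unit_eigs_sketch(200)
# - Largest eigenvalue magnitude of the recurrent matrix
spectral_radius = float(np.max(np.abs(np.linalg.eigvals(w_rec))))
print(w_rec.shape, round(spectral_radius, 2))
```

For a Gaussian matrix scaled by 1/sqrt(N), the circular law predicts eigenvalues spread roughly uniformly over the unit disk, so the printed spectral radius should be close to 1 for large N.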
- _abc_impl = <_abc_data object>
- _auto_batch(data: jax._src.numpy.lax_numpy.ndarray, states: Tuple = (), target_shapes: Optional[Tuple] = None) -> Tuple[jax._src.numpy.lax_numpy.ndarray, Tuple[jax._src.numpy.lax_numpy.ndarray]]
Automatically replicate states over batches and verify input dimensions
- Usage:
>>> data, (state0, state1, state2) = self._auto_batch(data, (self.state0, self.state1, self.state2))
This will verify that data has the correct final dimension (i.e. self.size_in). If data has only two dimensions (T, Nin), then it will be augmented to (1, T, Nin). The individual states will be replicated out from shape (a, b, c, ...) to (n_batches, a, b, c, ...) and returned.
- Parameters
data (np.ndarray) - Input data tensor. Either (batches, T, Nin) or (T, Nin)
states (Tuple) - Tuple of state variables. Each will be replicated out over batches by prepending a batch dimension
- Returns
(np.ndarray, Tuple[np.ndarray]) data, states
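The batching behaviour described above can be sketched in plain NumPy. auto_batch_sketch is a hypothetical function, not Rockpool's _auto_batch; it ignores target_shapes and assumes states are simply broadcast along a new leading batch axis.

```python
import numpy as np


def auto_batch_sketch(data, states=(), size_in=None):
    """Hypothetical NumPy re-creation of the documented batching rules."""
    data = np.asarray(data)
    # - Promote (T, Nin) input to (1, T, Nin)
    if data.ndim == 2:
        data = data[np.newaxis, :, :]
    if data.ndim != 3:
        raise ValueError("Input data must have shape (T, Nin) or (batches, T, Nin)")
    # - Verify the final (input channel) dimension
    if size_in is not None and data.shape[-1] != size_in:
        raise ValueError(f"Expected {size_in} input channels, got {data.shape[-1]}")
    n_batches = data.shape[0]
    # - Prepend a batch dimension to each state: (a, b, ...) -> (n_batches, a, b, ...)
    batched_states = tuple(
        np.broadcast_to(np.asarray(s), (n_batches,) + np.shape(s)).copy()
        for s in states
    )
    return data, batched_states


data, (x0,) = auto_batch_sketch(np.zeros((100, 8)), (np.zeros(8),), size_in=8)
print(data.shape, x0.shape)  # (1, 100, 8) (1, 8)
```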
- _force_set_attributes
(bool) If True, do not sanity-check attributes when setting.
- _get_attribute_family(type_name: str, family: Optional[Union[Tuple, List, str]] = None) -> dict
Search for attributes of this module and submodules that match a given family
This method can be used to conveniently get all weights for a network, all time constants, or any other family of parameters. Parameter families are defined simply by a string: "weights" for weights, "taus" for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months' time).
- Parameters
type_name (str) - The class of parameters to search for. Must be one of ["Parameter", "SimulationParameter", "State"] or another future subclass of ParameterBase
family (Union[str, Tuple[str]]) - A string, or a list or tuple of strings, that defines one or more attribute families to search for
- Returns
A nested dictionary of attributes that match the provided type_name and family
- Return type
dict
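The family search can be illustrated with a small recursive filter. The registry layout is not documented on this page, so AttrSketch, ModuleSketch and get_attribute_family below are hypothetical stand-ins that show only the filtering idea: keep attributes of the requested class whose family matches, and recurse into submodules.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union


@dataclass
class AttrSketch:
    """Hypothetical registry entry for one registered attribute."""
    type_name: str          # "Parameter", "SimulationParameter" or "State"
    family: Optional[str]   # e.g. "weights", "taus"; None if unassigned
    value: Any


@dataclass
class ModuleSketch:
    """Hypothetical module: its own attributes plus named submodules."""
    attrs: Dict[str, AttrSketch] = field(default_factory=dict)
    modules: Dict[str, "ModuleSketch"] = field(default_factory=dict)


def get_attribute_family(module: ModuleSketch, type_name: str,
                         family: Optional[Union[str, Tuple[str, ...]]] = None) -> dict:
    """Collect matching attributes into a nested dict, recursing into submodules."""
    if isinstance(family, str):
        family = (family,)
    result = {}
    for name, attr in module.attrs.items():
        # - Keep only attributes of the requested class and (optionally) family
        if attr.type_name == type_name and (family is None or attr.family in family):
            result[name] = attr.value
    for sub_name, sub in module.modules.items():
        sub_result = get_attribute_family(sub, type_name, family)
        if sub_result:  # this sketch omits submodules with no matches
            result[sub_name] = sub_result
    return result


rate = ModuleSketch(attrs={
    "w_rec": AttrSketch("Parameter", "weights", [[0.0]]),
    "tau": AttrSketch("Parameter", "taus", [0.02]),
    "x": AttrSketch("State", None, [0.0]),
})
net = ModuleSketch(attrs={"w_in": AttrSketch("Parameter", "weights", [[1.0]])},
                   modules={"rate": rate})

print(get_attribute_family(net, "Parameter", "weights"))
# {'w_in': [[1.0]], 'rate': {'w_rec': [[0.0]]}}
```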
- _get_attribute_registry() -> Tuple[Dict, Dict]
Return or initialise the attribute registry for this module
- Returns
registered_attributes, registered_modules
- Return type
(tuple)
- _has_registered_attribute(name: str) -> bool
Check if the module has a registered attribute
- Parameters
name (str) - The name of the attribute to check
- Returns
True if the attribute name is in the attribute registry, False otherwise.
- Return type
bool
- _in_Module_init
(bool) If it exists and is True, indicates that the module is in the __init__ chain.
- _name: Optional[str]
Name of this module, if assigned
- _register_attribute(name: str, val: rockpool.parameters.ParameterBase)
Record an attribute in the attribute registry
- Parameters
name (str) - The name of the attribute to register
val (ParameterBase) - The ParameterBase subclass object to register, e.g. Parameter, SimulationParameter or State.
- _register_module(name: str, mod)
Add a submodule to the module registry
- Parameters
name (str) - The name of the submodule, extracted from the assigned attribute name
mod (JaxModule) - The submodule to register
- _reset_attribute(name: str) -> rockpool.nn.modules.module.ModuleBase
Reset an attribute to its initialisation value
- Parameters
name (str) - The name of the attribute to reset
- Returns
The updated module (self), for compatibility with the functional API
- Return type
self (Module)
- _shape
The shape of this module
- _spiking_input: bool
Whether this module receives spiking input
- _spiking_output: bool
Whether this module produces spiking output
- _submodulenames: List[str]
Registry of sub-module names
- _wrap_recorded_state(recorded_dict: dict, t_start: float) -> Dict[str, rockpool.timeseries.TimeSeries]
Convert a recorded dictionary to a TimeSeries representation
This method is optional, and is provided to make the timed() conversion to a TimedModule work better. You should override this method in your custom Module to wrap each element of your recorded state dictionary as a TimeSeries.
- Parameters
recorded_dict (dict) - A recorded state dictionary as returned by evolve()
t_start (float) - The initial time of the recorded state, to use as the starting point of the time series
- Returns
The mapped recorded state dictionary, wrapped as TimeSeries objects
- Return type
Dict[str, TimeSeries]
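The core of such an override is attaching a time base, starting at t_start, to each recorded array. The sketch below is hypothetical: it returns (times, samples) tuples instead of rockpool.timeseries.TimeSeries objects, takes dt as an explicit argument, and assumes time runs along the first axis of each recorded array.

```python
import numpy as np


def wrap_recorded_state_sketch(recorded_dict: dict, t_start: float, dt: float) -> dict:
    """Attach a time base to each recorded array (hypothetical stand-in)."""
    wrapped = {}
    for key, samples in recorded_dict.items():
        samples = np.asarray(samples)
        # - One time stamp per recorded step, starting at t_start
        times = t_start + np.arange(samples.shape[0]) * dt
        wrapped[key] = (times, samples)
    return wrapped


wrapped = wrap_recorded_state_sketch({"x": np.zeros((5, 3))}, t_start=1.0, dt=0.1)
print(wrapped["x"][0])
```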
- act_fn: P_Callable
(Callable) Activation function
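The thresholded ReLU from the update equations, and plausible counterparts for the 'sigmoid' and 'tanh' string options, can be written as plain array functions. This is a sketch under assumptions: it uses NumPy for portability (a function passed to RateJax must use jax.numpy; swapping the import works for these calls), and the way the threshold enters the sigmoid and tanh variants, as well as the ACTIVATIONS lookup, are hypothetical rather than taken from Rockpool.

```python
import numpy as np  # a function passed to RateJax should use jax.numpy instead


def relu_threshold(x, threshold=0.0):
    """H(x, theta): equivalent to (x - theta) * ((x - theta) > 0)."""
    return np.maximum(x - threshold, 0.0)


def sigmoid_threshold(x, threshold=0.0):
    """Assumed form: logistic sigmoid of (x - theta)."""
    return 1.0 / (1.0 + np.exp(-(x - threshold)))


def tanh_threshold(x, threshold=0.0):
    """Assumed form: tanh of (x - theta)."""
    return np.tanh(x - threshold)


# - Hypothetical lookup mirroring the string options 'ReLU', 'sigmoid' and 'tanh'
ACTIVATIONS = {"ReLU": relu_threshold, "sigmoid": sigmoid_threshold, "tanh": tanh_threshold}

x = np.array([-1.0, 0.5, 2.0])
print(ACTIVATIONS["ReLU"](x, threshold=1.0))  # [0. 0. 1.]
```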
- as_graph() -> rockpool.graph.graph_base.GraphModuleBase
Convert this module to a computational graph
- Returns
The computational graph corresponding to this module
- Return type
GraphModuleBase
- Raises
NotImplementedError - If as_graph() is not implemented for this subclass
- attributes_named(name: Union[Tuple[str], List[str], str]) -> dict
Search for attributes of this module or its submodules by name
- Parameters
name (Union[str, Tuple[str]]) - The name of the attribute to search for
- Returns
A nested dictionary of attributes that match name
dict
- bias: P_ndarray
The vector (N,) of bias currents for each unit
- property class_name: str
Class name of self
- Type
str
- dt: P_float
The Euler solver time step for this module
- evolve(input_data: jax._src.numpy.lax_numpy.ndarray, record: bool = False)
Evolve the state of this module over input data
- Parameters
input_data - The input data with shape (T, size_in) to evolve with
record (bool) - If True, the module should record internal state during evolution and return the record. If False, no recording is required. Default: False.
- Returns
- (output, new_state, record)
output (np.ndarray): The output response of this module with shape (T, size_out)
new_state (dict): A dictionary containing the updated state of this and all submodules after evolution
record (dict): A dictionary containing recorded state of this and all submodules, if requested using the record argument
- Return type
tuple
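The update equations in the class description can be simulated over a (T, N) input with a forward-Euler loop. The NumPy sketch below is not Rockpool's Jax implementation; evolve_sketch is hypothetical, and it assumes the thresholded-ReLU activation, a row-vector recurrent input act @ w_rec, and that the module output is H(X, theta) of the updated state. It returns values in the (output, new_state, record) order documented above.

```python
import numpy as np


def evolve_sketch(input_data, tau, bias=0.0, threshold=0.0, w_rec=None,
                  noise_std=0.0, dt=1e-3, x0=None, record=False, seed=0):
    """Forward-Euler simulation of the rate update equations for one (T, N) input."""
    input_data = np.asarray(input_data, dtype=float)
    T, N = input_data.shape
    x = np.zeros(N) if x0 is None else np.array(x0, dtype=float)
    rng = np.random.default_rng(seed)
    outputs = np.zeros((T, N))
    recorded_x = np.zeros((T, N))

    for t in range(T):
        # - Activation H(X, theta) of the current state
        act = np.maximum(x - threshold, 0.0)
        # - Recurrent input; assumes the row-vector convention act @ w_rec
        rec_input = act @ w_rec if w_rec is not None else 0.0
        # - Gaussian noise term sigma * zeta_t
        noise = noise_std * rng.standard_normal(N)
        # - dX = -X + i(t) + W_rec H(X) + bias + noise, then X <- X + dX * dt / tau
        dx = -x + input_data[t] + rec_input + bias + noise
        x = x + dx * dt / tau
        recorded_x[t] = x
        # - Assumption: the module output is the activation of the updated state
        outputs[t] = np.maximum(x - threshold, 0.0)

    record_dict = {"x": recorded_x} if record else {}
    return outputs, {"x": x}, record_dict


out, new_state, rec = evolve_sketch(np.zeros((1000, 3)), tau=0.01, bias=1.0, record=True)
print(out.shape, np.round(new_state["x"], 3))
```

With zero input and a constant bias, each unit relaxes exponentially towards the bias value, stepping by a fraction dt / tau of the remaining distance per time step.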
- property full_name: str
The full name of this module (class plus module name)
- Type
str
- modules() -> Dict
Return a dictionary of all sub-modules of this module
- Returns
A dictionary containing all sub-modules. Each item will be named with the sub-module name.
- Return type
dict
- property name: str
The name of this module, or an empty string if None
- Type
str
- noise_std: P_float
The std. dev. \(\sigma\) of noise added to internal neuron states at each time step
- parameters(family: Optional[Union[Tuple, List, str]] = None) -> Dict
Return a nested dictionary of module and submodule Parameters
Use this method to inspect the Parameters from this and all submodules. The optional argument family allows you to search for Parameters in a particular family, for example "weights" for all weights of this module and nested submodules.
Although the family argument is an arbitrary string, reasonable choices are "weights", "taus" for time constants, "biases" for biases, etc.
Examples
Obtain a dictionary of all Parameters for this module (including submodules):
>>> mod.parameters()
dict{ ... }
Obtain a dictionary of Parameters from a particular family:
>>> mod.parameters("weights")
dict{ ... }
- Parameters
family (str) - The family of Parameters to search for. Default: None; return all parameters.
- Returns
A nested dictionary of Parameters of this module and all submodules
- Return type
dict
- reset_parameters()
Reset all parameters in this module
- Returns
The updated module is returned for compatibility with the functional API
- reset_state() -> rockpool.nn.modules.jax.jax_module.JaxModule
Reset the state of this module
- Returns
The updated module is returned for compatibility with the functional API
- Return type
JaxModule
- set_attributes(new_attributes: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]) -> rockpool.nn.modules.jax.jax_module.JaxModule
Assign new attributes to this module and submodules
- Parameters
new_attributes (Tree) - The tree of new attributes to assign to this module tree
- Return type
JaxModule
- property shape: tuple
The shape of this module
- Type
tuple
- simulation_parameters(family: Optional[Union[Tuple, List, str]] = None) -> Dict
Return a nested dictionary of module and submodule SimulationParameters
Use this method to inspect the SimulationParameters from this and all submodules. The optional argument family allows you to search for SimulationParameters in a particular family.
Examples
Obtain a dictionary of all SimulationParameters for this module (including submodules):
>>> mod.simulation_parameters()
dict{ ... }
- Parameters
family (str) - The family of SimulationParameters to search for. Default: None; return all SimulationParameter attributes.
- Returns
A nested dictionary of SimulationParameters of this module and all submodules
- Return type
dict
- property size: int
(DEPRECATED) The output size of this module
- Type
int
- property size_in: int
The input size of this module
- Type
int
- property size_out: int
The output size of this module
- Type
int
- property spiking_input: bool
If True, this module receives spiking input. If False, this module expects continuous input.
- Type
bool
- property spiking_output
If True, this module sends spiking output. If False, this module sends continuous output.
- Type
bool
- state(family: Optional[Union[Tuple, List, str]] = None) -> Dict
Return a nested dictionary of module and submodule States
Use this method to inspect the States from this and all submodules. The optional argument family allows you to search for States in a particular family.
Examples
Obtain a dictionary of all States for this module (including submodules):
>>> mod.state()
dict{ ... }
- Parameters
family (str) - The family of States to search for. Default: None; return all State attributes.
- Returns
A nested dictionary of States of this module and all submodules
- Return type
dict
- tau: P_ndarray
The vector (N,) of time constants \(\tau\) for each unit
- threshold: P_ndarray
(Tensor) Unit thresholds, shape (Nout,) or ()
- timed(output_num: int = 0, dt: Optional[float] = None, add_events: bool = False)
Convert this module to a TimedModule
- Parameters
output_num (int) - Specify which output of the module to take, if the module returns multiple output series. Default: 0, take the first (or only) output.
dt (float) - Used to provide a time step for this module, if the module does not already have one. If self already defines a time step, then self.dt will be used. Default: None
add_events (bool) - Iff True, the TimedModule will add events occurring on a single time step on input and output. Default: False, don't add events.
- Returns
A timed module that wraps this module
- Return type
TimedModule
- tree_flatten() -> Tuple[tuple, tuple]
Flatten this module tree for Jax
- classmethod tree_unflatten(aux_data, children)
Unflatten a tree of modules from Jax to Rockpool
- w_rec: P_ndarray
The recurrent weight matrix (N, N) for this module
- x: P_ndarray
A vector (N,) of the internal state of each unit