nn.modules.LinearJax
- class nn.modules.LinearJax(*args, **kwargs)[source]
Bases: LinearMixin, JaxModule
Encapsulates a linear weight matrix, with a Jax backend
Attributes overview
class_name: Class name of self
full_name: The full name of this module (class plus module name)
name: The name of this module, or an empty string if None
shape: The shape of this module
size: (DEPRECATED) The output size of this module
size_in: The input size of this module
size_out: The output size of this module
spiking_input: If True, this module receives spiking input
spiking_output: If True, this module sends spiking output
weight: Weight matrix of this module
bias: Bias vector of this module
Methods overview
__init__(shape[, weight, bias, has_bias, ...]): Encapsulate a linear weight matrix, with optional biases
as_graph(): Convert this module to a computational graph
attributes_named(name): Search for attributes of this or submodules by name
evolve(input_data[, record]): Evolve the state of this module over input data
from_graph(graph): Construct a LinearMixin object from a computational graph
modules(): Return a dictionary of all sub-modules of this module
parameters([family]): Return a nested dictionary of module and submodule Parameters
reset_parameters(): Reset all parameters in this module
reset_state(): Reset the state of this module
set_attributes(new_attributes): Assign new attributes to this module and submodules
simulation_parameters([family]): Return a nested dictionary of module and submodule SimulationParameters
state([family]): Return a nested dictionary of module and submodule States
timed([output_num, dt, add_events]): Convert this module to a TimedModule
tree_flatten(): Flatten this module tree for Jax
tree_unflatten(aux_data, children): Unflatten a tree of modules from Jax to Rockpool
- __init__(shape: tuple, weight=None, bias=None, has_bias: bool = False, weight_init_func: ~typing.Callable = <function kaiming>, bias_init_func: ~typing.Callable = <function uniform_sqrt>, *args, **kwargs)
Encapsulate a linear weight matrix, with optional biases
Linear wraps a single weight matrix, and passes data through by using the matrix as a set of weights. The shape of the matrix must be specified as a tuple (Nin, Nout).
Linear provides optional biases.
A weight initialisation function may be specified. By default the weights will use Kaiming initialisation (kaiming()).
A bias initialisation function may be specified, if biases are used. By default the biases will be initialised as uniform random over the range \((-\sqrt{1/N}, \sqrt{1/N})\).
Warning
Standard DNN libraries include a bias on linear layers by default. These are usually not used for SNNs, where the bias is configured on the spiking neuron module. Linear layers in Rockpool use a default of has_bias = False. You can force the presence of a bias on the linear layer with has_bias = True on initialisation.
Examples
Build a linear weight matrix with shape (3, 4), and no biases:
>>> Linear((3, 4))
Linear with shape (3, 4)
Build a linear weight matrix with shape (2, 5), which will be initialised with zeros:
>>> Linear((2, 5), weight_init_func = lambda s: np.zeros(s))
Linear with shape (2, 5)
Provide a concrete initialisation for the linear weights:
>>> Linear((2, 2), weight = np.array([[1, 2], [3, 4]]))
Linear with shape (2, 2)
Build a linear layer including biases:
>>> mod = Linear((2, 2), has_bias = True)
>>> mod.parameters()
{'weight': array([[ 0.56655314,  0.64411151],
                  [-1.43016068, -1.538719  ]]),
 'bias': array([-0.58513867, -0.32314069])}
- Parameters:
shape (tuple) – The desired shape of the weight matrix. Must have two entries (Nin, Nout)
weight_init_func (Callable) – The initialisation function to use for the weights. Default: Kaiming initialisation; uniform on the range \((-\sqrt{6/N_{in}}, \sqrt{6/N_{in}})\)
weight (Optional[np.array]) – A concrete weight matrix to assign to the weights on initialisation. weight.shape must match the shape argument
has_bias (bool) – A boolean flag indicating that this linear layer should have a bias parameter. Default: False, no bias parameter
bias_init_func (Callable) – The initialisation function to use for the biases. Default: uniform on the range \((-\sqrt{1/N}, \sqrt{1/N})\)
bias (Optional[np.array]) – A concrete bias vector to assign to the biases on initialisation. bias.shape must be (N,)
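The computation described above can be sketched in plain NumPy. This is an illustrative model of what a Linear layer computes (y = x @ W, plus an optional bias), not the Rockpool implementation; the helper names `kaiming_init` and `linear_forward` are hypothetical.

```python
import numpy as np

def kaiming_init(shape, rng=None):
    # Hypothetical sketch of the Kaiming-style uniform initialisation
    # described above: uniform over (-sqrt(6 / Nin), sqrt(6 / Nin))
    if rng is None:
        rng = np.random.default_rng(0)
    n_in = shape[0]
    bound = np.sqrt(6.0 / n_in)
    return rng.uniform(-bound, bound, shape)

def linear_forward(x, weight, bias=None):
    # A linear layer passes data through its weight matrix,
    # optionally adding a bias: y = x @ W (+ b)
    y = x @ weight
    if bias is not None:
        y = y + bias
    return y

w = kaiming_init((3, 4))     # weights for shape (Nin, Nout) = (3, 4)
x = np.ones((10, 3))         # input with shape (T, Nin)
y = linear_forward(x, w)     # output with shape (T, Nout)
```

Note that, as the warning above explains, a real SNN stack would usually leave the bias off and configure it on the spiking neuron module instead.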
- _abc_impl = <_abc._abc_data object>
- _auto_batch(data: Array, states: Tuple = (), target_shapes: Tuple | None = None) Tuple[Array, Tuple[Array]]
Automatically replicate states over batches and verify input dimensions
- Usage:
>>> data, (state0, state1, state2) = self._auto_batch(data, (self.state0, self.state1, self.state2))
This will verify that data has the correct final dimension (i.e. self.size_in). If data has only two dimensions (T, Nin), then it will be augmented to (1, T, Nin). The individual states will be replicated out from shape (a, b, c, ...) to (n_batches, a, b, c, ...) and returned.
- Parameters:
data (np.ndarray) – Input data tensor. Either (batches, T, Nin) or (T, Nin)
states (Tuple) – Tuple of state variables. Each will be replicated out over batches by prepending a batch dimension
- Returns:
(np.ndarray, Tuple[np.ndarray]) data, states
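The batching rule above can be sketched as follows. This is an illustrative NumPy model, not Rockpool's `_auto_batch` (which additionally verifies `size_in` and supports `target_shapes`); the name `auto_batch` is hypothetical.

```python
import numpy as np

def auto_batch(data, states=()):
    # Sketch of the behaviour described above: a 2D input (T, Nin)
    # gains a leading batch dimension, and each state is replicated
    # once per batch by prepending a batch dimension.
    if data.ndim == 2:
        data = data[np.newaxis, ...]          # (T, Nin) -> (1, T, Nin)
    n_batches = data.shape[0]
    states = tuple(
        np.broadcast_to(s, (n_batches,) + s.shape) for s in states
    )
    return data, states

data = np.zeros((50, 3))                      # un-batched input (T, Nin)
state = np.zeros((4,))                        # a single state vector
data_b, (state_b,) = auto_batch(data, (state,))
```

`np.broadcast_to` returns a read-only view, which is enough for a sketch; a real implementation might copy if the states are mutated in place.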
- static _dot(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: Array | ndarray | bool_ | number | bool | int | float | complex, *, precision: str | Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision] | None = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None) Array
Compute the dot product of two arrays.
JAX implementation of numpy.dot().
This differs from jax.numpy.matmul() in two respects:
- if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply(), while the result of matmul is an error.
- if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.
- Parameters:
a – first input array, of shape (..., N).
b – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.
precision – either None (default), which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST); or a tuple of two such values, indicating the precision of a and b.
preferred_element_type – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Returns:
array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.
See also
jax.numpy.matmul(): broadcasted batched matmul.
jax.lax.dot_general(): general batched matrix multiplication.
Examples
For scalar inputs, dot computes the element-wise product:
>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)
For vector or matrix inputs, dot computes the vector or matrix product:
>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)
For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:
>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
- _force_set_attributes
(bool) If True, do not sanity-check attributes when setting.
- _get_attribute_family(type_name: str, family: Tuple | List | str | None = None) dict
Search for attributes of this module and submodules that match a given family
This method can be used to conveniently get all weights for a network; or all time constants; or any other family of parameters. Parameter families are defined simply by a string: "weights" for weights; "taus" for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months' time).
- Parameters:
type_name (str) – The class of parameters to search for. Must be one of ["Parameter", "SimulationParameter", "State"] or another future subclass of ParameterBase
family (Union[str, Tuple[str]]) – A string, or a list or tuple of strings, that defines one or more attribute families to search for
- Returns:
A nested dictionary of attributes that match the provided type_name and family
- Return type:
dict
- _get_attribute_registry() Tuple[Dict, Dict]
Return or initialise the attribute registry for this module
- Returns:
registered_attributes, registered_modules
- Return type:
(tuple)
- _has_registered_attribute(name: str) bool
Check if the module has a registered attribute
- Parameters:
name (str) – The name of the attribute to check
- Returns:
True if the attribute name is in the attribute registry, False otherwise.
- Return type:
bool
- _in_Module_init
(bool) If this attribute exists and is True, it indicates that the module is in the __init__ chain.
- _name: str | None
Name of this module, if assigned
- _register_attribute(name: str, val: ParameterBase)
Record an attribute in the attribute registry
- Parameters:
name (str) – The name of the attribute to register
val (ParameterBase) – The ParameterBase subclass object to register, e.g. Parameter, SimulationParameter or State.
- _register_module(name: str, mod)
Add a submodule to the module registry
- Parameters:
name (str) – The name of the submodule, extracted from the assigned attribute name
mod (JaxModule) – The submodule to register
- _reset_attribute(name: str) ModuleBase
Reset an attribute to its initialisation value
- Parameters:
name (str) – The name of the attribute to reset
- Returns:
The module itself, for compatibility with the functional API
- Return type:
self (Module)
- _shape
The shape of this module
- _spiking_input: bool
Whether this module receives spiking input
- _spiking_output: bool
Whether this module produces spiking output
- _submodulenames: List[str]
Registry of sub-module names
- _wrap_recorded_state(recorded_dict: dict, t_start: float) Dict[str, TimeSeries]
Convert a recorded dictionary to a TimeSeries representation
This method is optional, and is provided to make the timed() conversion to a TimedModule work better. You should override this method in your custom Module, to wrap each element of your recorded state dictionary as a TimeSeries.
- Parameters:
recorded_dict (dict) – A recorded state dictionary, as returned by evolve()
t_start (float) – The initial time of the recorded state, to use as the starting point of the time series
- Returns:
The mapped recorded state dictionary, wrapped as TimeSeries objects
- Return type:
Dict[str, TimeSeries]
- as_graph() GraphModuleBase
Convert this module to a computational graph
- Returns:
The computational graph corresponding to this module
- Return type:
GraphModuleBase
- Raises:
NotImplementedError – If as_graph() is not implemented for this subclass
- attributes_named(name: Tuple[str] | List[str] | str) dict
Search for attributes of this or submodules by name
- Parameters:
name (Union[str, Tuple[str]]) – The name of the attribute to search for
- Returns:
A nested dictionary of attributes that match name
- Return type:
dict
- bias
Bias vector of this module
- property class_name: str
Class name of self
- Type:
str
- evolve(input_data, record: bool = False) Tuple[Any, Any, Any]
Evolve the state of this module over input data
NOTE: THIS MODULE CLASS DOES NOT PROVIDE DOCUMENTATION FOR ITS EVOLVE METHOD. PLEASE UPDATE THE DOCUMENTATION FOR THIS MODULE.
- Parameters:
input_data – The input data with shape (T, size_in) to evolve with
record (bool) – If True, the module should record internal state during evolution and return the record. If False, no recording is required. Default: False.
- Returns:
(output, new_state, record)
output (np.ndarray): The output response of this module with shape (T, size_out)
new_state (dict): A dictionary containing the updated state of this and all submodules after evolution
record (dict): A dictionary containing the recorded state of this and all submodules, if requested using the record argument
- Return type:
tuple
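For a stateless linear module, the (output, new_state, record) contract can be sketched in NumPy. This is an illustrative model of the calling convention, not the LinearJax implementation; `evolve_linear` is a hypothetical name.

```python
import numpy as np

def evolve_linear(weight, bias, input_data, record=False):
    # Sketch of the evolve() contract described above for a linear
    # module: output has shape (T, size_out); new_state and record
    # are empty dicts, because a Linear layer holds no evolving
    # state and records nothing internally.
    output = input_data @ weight
    if bias is not None:
        output = output + bias
    new_state = {}        # nothing to update for a stateless module
    recording = {}        # nothing recorded, whether or not record=True
    return output, new_state, recording

w = np.array([[1.0, 0.0],
              [0.0, 2.0]])                      # (Nin, Nout) = (2, 2)
out, state, rec = evolve_linear(w, None, np.ones((5, 2)))  # (T, size_in)
```

For stateful modules (e.g. spiking neurons), `new_state` would carry the updated membrane state and `record` the per-timestep traces requested via the `record` argument.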
- classmethod from_graph(graph: LinearWeights) LinearMixin
Construct a LinearMixin object from a computational graph
- Parameters:
graph (LinearWeights) – the reference computational graph from which to restore the computational module
- Returns:
a LinearMixin object
- Return type:
LinearMixin
- property full_name: str
The full name of this module (class plus module name)
- Type:
str
- modules() Dict
Return a dictionary of all sub-modules of this module
- Returns:
A dictionary containing all sub-modules. Each item will be named with the sub-module name.
- Return type:
dict
- property name: str
The name of this module, or an empty string if None
- Type:
str
- parameters(family: Tuple | List | str | None = None) Dict
Return a nested dictionary of module and submodule Parameters
Use this method to inspect the Parameters from this and all submodules. The optional argument family allows you to search for Parameters in a particular family, for example "weights" for all weights of this module and nested submodules.
Although the family argument is an arbitrary string, reasonable choices are "weights"; "taus" for time constants; "biases" for biases; and so on.
Examples
Obtain a dictionary of all Parameters for this module (including submodules):
>>> mod.parameters()
dict{ ... }
Obtain a dictionary of Parameters from a particular family:
>>> mod.parameters("weights")
dict{ ... }
- Parameters:
family (str) – The family of Parameters to search for. Default: None; return all parameters.
- Returns:
A nested dictionary of Parameters of this module and all submodules
- Return type:
dict
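The family-based lookup can be sketched as a filter over a registry of (value, family) pairs. This is a deliberately simplified, hypothetical model: a flat dict rather than Rockpool's nested attribute registry, with invented attribute values.

```python
# Hypothetical flat registry mapping attribute name -> (value, family).
# Real modules register attributes per-submodule, yielding nested dicts.
registry = {
    "weight":  ([[1.0, 2.0]], "weights"),
    "bias":    ([0.0, 0.0],   "biases"),
    "tau_mem": ([0.02],       "taus"),
}

def parameters(registry, family=None):
    # Sketch of family-filtered lookup: with family=None every
    # parameter is returned; otherwise only entries whose family
    # string matches.
    return {
        name: value
        for name, (value, fam) in registry.items()
        if family is None or fam == family
    }
```

Because families are plain strings, following the "weights"/"taus"/"biases" conventions is what makes calls like `mod.parameters("weights")` work uniformly across a network.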
- reset_parameters()
Reset all parameters in this module
- Returns:
The updated module is returned for compatibility with the functional API
- Return type:
- reset_state() JaxModule
Reset the state of this module
- Returns:
The updated module is returned for compatibility with the functional API
- Return type:
JaxModule
- set_attributes(new_attributes: Iterable | MutableMapping | Mapping) JaxModule
Assign new attributes to this module and submodules
- Parameters:
new_attributes (Tree) – The tree of new attributes to assign to this module tree
- Return type:
JaxModule
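The functional pattern implied here (an updated module is returned, rather than the module being mutated in place) can be sketched with a frozen dataclass. `TinyModule` is a hypothetical stand-in, not a Rockpool class.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class TinyModule:
    # Hypothetical minimal module with a single attribute
    weight: tuple

    def set_attributes(self, new_attributes):
        # Functional update: return a new module with the given
        # attributes assigned, leaving the original untouched
        return replace(self, **new_attributes)

mod = TinyModule(weight=(1.0, 2.0))
mod = mod.set_attributes({"weight": (0.0, 0.0)})  # re-assign the result
```

This re-assignment style is what makes modules compatible with Jax transformations, which require pure functions over immutable pytrees.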
- property shape: tuple
The shape of this module
- Type:
tuple
- simulation_parameters(family: Tuple | List | str | None = None) Dict
Return a nested dictionary of module and submodule SimulationParameters
Use this method to inspect the SimulationParameters from this and all submodules. The optional argument family allows you to search for SimulationParameters in a particular family.
Examples
Obtain a dictionary of all SimulationParameters for this module (including submodules):
>>> mod.simulation_parameters()
dict{ ... }
- Parameters:
family (str) – The family of SimulationParameters to search for. Default: None; return all SimulationParameter attributes.
- Returns:
A nested dictionary of SimulationParameters of this module and all submodules
- Return type:
dict
- property size: int
(DEPRECATED) The output size of this module
- Type:
int
- property size_in: int
The input size of this module
- Type:
int
- property size_out: int
The output size of this module
- Type:
int
- property spiking_input: bool
If True, this module receives spiking input. If False, this module expects continuous input.
- Type:
bool
- property spiking_output
If True, this module sends spiking output. If False, this module sends continuous output.
- Type:
bool
- state(family: Tuple | List | str | None = None) Dict
Return a nested dictionary of module and submodule States
Use this method to inspect the States from this and all submodules. The optional argument family allows you to search for States in a particular family.
Examples
Obtain a dictionary of all States for this module (including submodules):
>>> mod.state()
dict{ ... }
- Parameters:
family (str) – The family of States to search for. Default: None; return all State attributes.
- Returns:
A nested dictionary of States of this module and all submodules
- Return type:
dict
- timed(output_num: int = 0, dt: float | None = None, add_events: bool = False)
Convert this module to a TimedModule
- Parameters:
output_num (int) – Specify which output of the module to take, if the module returns multiple output series. Default: 0, take the first (or only) output.
dt (float) – Used to provide a time-step for this module, if the module does not already have one. If self already defines a time-step, then self.dt will be used. Default: None
add_events (bool) – If True, the TimedModule will add events occurring on a single timestep on input and output. Default: False, don't add events.
- Returns:
TimedModule: A timed module that wraps this module
- tree_flatten() Tuple[tuple, tuple]
Flatten this module tree for Jax
- classmethod tree_unflatten(aux_data, children)
Unflatten a tree of modules from Jax to Rockpool
- weight
Weight matrix of this module