nn.modules.LinearJax

class nn.modules.LinearJax(*args, **kwargs)

Bases: LinearMixin, JaxModule

Encapsulates a linear weight matrix, with a Jax backend

Attributes overview

class_name

Class name of self

full_name

The full name of this module (class plus module name)

name

The name of this module, or an empty string if None

shape

The shape of this module

size

(DEPRECATED) The output size of this module

size_in

The input size of this module

size_out

The output size of this module

spiking_input

If True, this module receives spiking input.

spiking_output

If True, this module sends spiking output.

weight

Weight matrix of this module

bias

Bias vector of this module

Methods overview

__init__(shape[, weight, bias, has_bias, ...])

Encapsulate a linear weight matrix, with optional biases

as_graph()

Convert this module to a computational graph

attributes_named(name)

Search for attributes of this or submodules by name

evolve(input_data[, record])

Evolve the state of this module over input data

from_graph(graph)

from_graph constructs a LinearMixin object from the computational graph

modules()

Return a dictionary of all sub-modules of this module

parameters([family])

Return a nested dictionary of module and submodule Parameters

reset_parameters()

Reset all parameters in this module

reset_state()

Reset the state of this module

set_attributes(new_attributes)

Assign new attributes to this module and submodules

simulation_parameters([family])

Return a nested dictionary of module and submodule SimulationParameters

state([family])

Return a nested dictionary of module and submodule States

timed([output_num, dt, add_events])

Convert this module to a TimedModule

tree_flatten()

Flatten this module tree for Jax

tree_unflatten(aux_data, children)

Unflatten a tree of modules from Jax to Rockpool

__init__(shape: tuple, weight=None, bias=None, has_bias: bool = False, weight_init_func: Callable = kaiming, bias_init_func: Callable = uniform_sqrt, *args, **kwargs)

Encapsulate a linear weight matrix, with optional biases

Linear wraps a single weight matrix and passes input data through it by matrix multiplication, using the matrix as a set of weights. The shape of the matrix must be specified as a tuple (Nin, Nout). Linear optionally provides biases.
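As a rough illustration of the data flow described above, the following NumPy sketch multiplies input of shape (T, Nin) by a weight matrix of shape (Nin, Nout) and optionally adds a bias. The function name linear_forward is hypothetical and is not part of Rockpool; this is a sketch of the matrix product only, not the module's actual implementation.

```python
import numpy as np


def linear_forward(x, weight, bias=None):
    """Hypothetical sketch (not Rockpool API): pass x through a weight matrix.

    x:      input of shape (T, Nin)
    weight: matrix of shape (Nin, Nout)
    bias:   optional vector of shape (Nout,)
    Returns an array of shape (T, Nout).
    """
    # Each time step (row of x) is multiplied by the weight matrix
    y = np.dot(x, weight)
    if bias is not None:
        # The bias vector is broadcast over all time steps
        y = y + bias
    return y


x = np.ones((10, 3))                # T = 10 time steps, Nin = 3 inputs
w = np.arange(12.0).reshape(3, 4)   # Nin = 3, Nout = 4
y = linear_forward(x, w)
print(y.shape)  # (10, 4)
```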

A weight initialisation function may be specified. By default the weights will use Kaiming initialisation (kaiming()).

A bias initialisation function may also be specified, if biases are used. By default the biases will be initialised as uniform random over the range \((-\sqrt{1/N}, \sqrt{1/N})\).
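The two default initialisation ranges can be sketched with NumPy as follows: weights uniform on (-sqrt(6/Nin), sqrt(6/Nin)), as stated in the parameter list below, and biases uniform on (-sqrt(1/N), sqrt(1/N)). The helper names are hypothetical stand-ins for kaiming() and uniform_sqrt(), and taking N as the length of the bias vector is an assumption, since the text does not define N.

```python
import numpy as np


def kaiming_uniform_sketch(shape, rng=None):
    """Hypothetical sketch: uniform on (-sqrt(6/Nin), sqrt(6/Nin)) for shape (Nin, Nout)."""
    rng = np.random.default_rng() if rng is None else rng
    n_in = shape[0]
    lim = np.sqrt(6.0 / n_in)
    return rng.uniform(-lim, lim, size=shape)


def uniform_sqrt_sketch(shape, rng=None):
    """Hypothetical sketch: uniform on (-sqrt(1/N), sqrt(1/N)).

    ASSUMPTION: N is taken as the first (here, only) dimension of `shape`.
    """
    rng = np.random.default_rng() if rng is None else rng
    lim = np.sqrt(1.0 / shape[0])
    return rng.uniform(-lim, lim, size=shape)


rng = np.random.default_rng(0)
w = kaiming_uniform_sketch((3, 4), rng)  # weights for Nin = 3, Nout = 4
b = uniform_sqrt_sketch((4,), rng)       # one bias per output (assumption)
```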

Warning

Standard DNN libraries by default include a bias on linear layers. These are usually not used for SNNs, where the bias is configured on the spiking neuron module. Linear layers in Rockpool use a default of has_bias = False. You can force the presence of a bias on the linear layer with has_bias = True on initialisation.

Examples

Build a linear weight matrix with shape (3, 4), and no biases:

>>> Linear((3, 4))
Linear  with shape (3, 4)

Build a linear weight matrix with shape (2, 5), which will be initialised with zeros:

>>> Linear((2, 5), weight_init_func = lambda s: np.zeros(s))
Linear  with shape (2, 5)

Provide a concrete initialisation for the linear weights:

>>> Linear((2, 2), weight = np.array([[1, 2], [3, 4]]))
Linear  with shape (2, 2)

Build a linear layer including biases:

>>> mod = Linear((2, 2), has_bias = True)
>>> mod.parameters()
{'weight': array([[ 0.56655314,  0.64411151],
        [-1.43016068, -1.538719  ]]),
 'bias': array([-0.58513867, -0.32314069])}

Parameters:
  • shape (tuple) – The desired shape of the weight matrix. Must have two entries (Nin, Nout)

  • weight_init_func (Callable) – The initialisation function to use for the weights. Default: Kaiming initialisation; uniform on the range \((-\sqrt{6/N_{in}}, \sqrt{6/N_{in}})\)

  • weight (Optional[np.array]) – A concrete weight matrix to assign to the weights on initialisation. weight.shape must match the shape argument

  • has_bias (bool) – A boolean flag indicating that this linear layer should have a bias parameter. Default: False, no bias parameter

  • bias_init_func (Callable) – The initialisation function to use for the biases. Default: uniform_sqrt(); uniform on the range \((-\sqrt{1/N}, \sqrt{1/N})\)

  • bias (Optional[np.array]) – A concrete bias vector to assign to the biases on initialisation. bias.shape must be (Nout,), with one entry per output
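The shape constraints in the parameter list above can be sketched as a small validation routine. The function name check_concrete_params is hypothetical, and it assumes the bias has one entry per output (Nout), which is what adding a bias to an output of shape (T, Nout) requires.

```python
import numpy as np


def check_concrete_params(shape, weight=None, bias=None):
    """Hypothetical sketch: validate concrete initial values against `shape`."""
    if len(shape) != 2:
        raise ValueError(f"shape must have two entries (Nin, Nout), got {shape}")
    n_in, n_out = shape
    # The concrete weight matrix must match (Nin, Nout) exactly
    if weight is not None and np.shape(weight) != (n_in, n_out):
        raise ValueError(f"weight.shape {np.shape(weight)} does not match {shape}")
    # ASSUMPTION: the bias carries one value per output channel
    if bias is not None and np.shape(bias) != (n_out,):
        raise ValueError(f"bias.shape {np.shape(bias)} must be ({n_out},)")


check_concrete_params((2, 2), weight=np.array([[1, 2], [3, 4]]))  # passes silently
```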

_auto_batch(data: Array, states: Tuple = (), target_shapes: Tuple | None = None) Tuple[Array, Tuple[Array]]

Automatically replicate states over batches and verify input dimensions

Usage:
>>> data, (state0, state1, state2) = self._auto_batch(data, (self.state0, self.state1, self.state2))

This will verify that data has the correct final dimension (i.e. self.size_in). If data has only two dimensions (T, Nin), then it will be augmented to (1, T, Nin). The individual states will be replicated out from shape (a, b, c, ...) to (n_batches, a, b, c, ...) and returned.

Parameters:
  • data (np.ndarray) – Input data tensor. Either (batches, T, Nin) or (T, Nin)

  • states (Tuple) – Tuple of state variables. Each will be replicated out over batches by prepending a batch dimension

Returns:

(np.ndarray, Tuple[np.ndarray]) data, states
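The batching behaviour described for _auto_batch can be approximated in NumPy as below. This is a sketch under stated assumptions, not Rockpool's implementation: auto_batch_sketch and its size_in argument are hypothetical, and the target_shapes argument is not modelled.

```python
import numpy as np


def auto_batch_sketch(data, states=(), size_in=None):
    """Hypothetical sketch of auto-batching input data and states."""
    data = np.asarray(data)
    # Promote (T, Nin) input to (1, T, Nin)
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    if data.ndim != 3:
        raise ValueError("data must have shape (T, Nin) or (batches, T, Nin)")
    # Verify the final dimension against the expected input size
    if size_in is not None and data.shape[-1] != size_in:
        raise ValueError(f"expected final dimension {size_in}, got {data.shape[-1]}")
    n_batches = data.shape[0]
    # Replicate each state from (a, b, ...) to (n_batches, a, b, ...)
    batched_states = tuple(
        np.broadcast_to(s, (n_batches,) + np.shape(s)).copy() for s in states
    )
    return data, batched_states


data, (v,) = auto_batch_sketch(np.zeros((5, 3)), (np.zeros(4),), size_in=3)
print(data.shape, v.shape)  # (1, 5, 3) (1, 4)
```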

static _dot(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: Array | ndarray | bool_ | number | bool | int | float | complex, *, precision: str | Precision | tuple[str, str] | tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision] | None = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None) Array

Compute the dot product of two arrays.

JAX implementation of numpy.dot().

This differs from jax.numpy.matmul() in two respects:

  • if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply(), while the result of matmul is an error.

  • if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.

Parameters:
  • a – first input array, of shape (..., N).

  • b – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.

  • precision – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.

  • preferred_element_type – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.

Returns:

array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.

See also

  • jax.numpy.matmul(): broadcasted batched matmul.

  • jax.lax.dot_general(): general batched matrix multiplication.

Examples

For scalar inputs, dot computes the element-wise product:

>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)

For vector or matrix inputs, dot computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)

For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
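To make the difference between stacked and broadcast batch dimensions concrete, the sketch below uses NumPy, since jax.numpy.dot() implements numpy.dot(). For 3-D inputs, dot() pairs every batch of a with every batch of b, so the matmul() result is the slice where the two batch indices coincide.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 2, 4))
b = rng.standard_normal((3, 4, 1))

stacked = np.dot(a, b)       # shape (3, 2, 3, 1): all pairs of batches
broadcast = np.matmul(a, b)  # shape (3, 2, 1): matching batches only

# stacked[i, :, k, :] equals a[i] @ b[k]; matmul keeps only the i == k terms
for i in range(3):
    assert np.allclose(stacked[i, :, i, :], broadcast[i])

# The same stacked product written explicitly with einsum
assert np.allclose(stacked, np.einsum("ijn,knl->ijkl", a, b))
```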

_force_set_attributes

(bool) If True, do not sanity-check attributes when setting.

_get_attribute_family(type_name: str, family: Tuple | List | str | None = None) dict

Search for attributes of this module and submodules that match a given family

This method can be used to conveniently get all weights for a network; or all time constants; or any other family of parameters. Parameter families are defined simply by a string: "weights" for weights; "taus" for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months’ time).

Parameters:
  • type_name (str) – The class of parameters to search for. Must be one of ["Parameter", "SimulationParameter", "State"] or another future subclass of ParameterBase

  • family (Union[str, Tuple[str]]) – A string, or a list or tuple of strings, defining one or more attribute families to search for

Returns:

A nested dictionary of attributes that match the provided type_name and family

Return type:

dict
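The family mechanism can be pictured as a recursive search over a tree of modules, where each attribute carries a free-form family string. The toy classes below (ToyParameter, ToyModule) and the get_family function are hypothetical stand-ins written only to illustrate the idea; they are not Rockpool classes, and Rockpool's attribute registry differs in detail.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Sequence, Union


@dataclass
class ToyParameter:
    """Hypothetical stand-in for a registered attribute."""
    value: Any
    family: Optional[str] = None  # e.g. "weights", "taus", "biases"


@dataclass
class ToyModule:
    """Hypothetical stand-in for a module with attributes and submodules."""
    attributes: Dict[str, ToyParameter] = field(default_factory=dict)
    submodules: Dict[str, "ToyModule"] = field(default_factory=dict)


def get_family(mod: ToyModule, family: Union[str, Sequence[str], None] = None) -> dict:
    """Return a nested dict of attribute values whose family matches."""
    if isinstance(family, str):
        family = (family,)
    # Keep attributes of this module matching any requested family (all, if None)
    found = {
        name: p.value
        for name, p in mod.attributes.items()
        if family is None or p.family in family
    }
    # Recurse into submodules, keeping only non-empty results
    for sub_name, sub in mod.submodules.items():
        sub_found = get_family(sub, family)
        if sub_found:
            found[sub_name] = sub_found
    return found


lin = ToyModule({
    "weight": ToyParameter([[1.0]], "weights"),
    "bias": ToyParameter([0.0], "biases"),
})
net = ToyModule({"tau": ToyParameter(0.02, "taus")}, {"lin": lin})
print(get_family(net, "weights"))  # {'lin': {'weight': [[1.0]]}}
```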

_get_attribute_registry() Tuple[Dict, Dict]

Return or initialise the attribute registry for this module

Returns:

registered_attributes, registered_modules

Return type:

(tuple)

_has_registered_attribute(name: str) bool

Check if the module has a registered attribute

Parameters:

name (str) – The name of the attribute to check

Returns:

True if the attribute name is in the attribute registry, False otherwise.

Return type:

bool

_in_Module_init

(bool) If this attribute exists and is True, the module is currently within its __init__ chain.

_name: str | None

Name of this module, if assigned

_register_attribute(name: str, val: ParameterBase)

Record an attribute in the attribute registry

Parameters:
  • name (str) – The name of the attribute to register

  • val (ParameterBase) – The ParameterBase subclass object to register. e.g. Parameter, SimulationParameter or State.

_register_module(name: str, mod)

Add a submodule to the module registry

Parameters:
  • name (str) – The name of the submodule, extracted from the assigned attribute name

  • mod (JaxModule) – The submodule to register

_reset_attribute(name: str) ModuleBase

Reset an attribute to its initialisation value

Parameters:

name (str) – The name of the attribute to reset

Returns:

The updated module, returned for compatibility with the functional API

Return type:

Module
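One way to picture resetting an attribute to its initialisation value is to store the initialisation function and shape alongside the value, and call the function again on reset. The ResettableParam class below is hypothetical and only sketches this idea; it is not Rockpool's Parameter class.

```python
from typing import Any, Callable, Tuple

import numpy as np


class ResettableParam:
    """Hypothetical sketch: a value that remembers how it was initialised."""

    def __init__(self, shape: Tuple[int, ...], init_func: Callable[[Tuple[int, ...]], Any]):
        self.shape = shape
        self.init_func = init_func
        self.value = init_func(shape)

    def reset(self) -> "ResettableParam":
        # Re-draw the value from the stored initialisation function
        self.value = self.init_func(self.shape)
        # Return self, in the style of the functional API
        return self


p = ResettableParam((2, 3), np.zeros)
p.value = p.value + 1.0  # e.g. after training
p = p.reset()            # back to the initialisation value (all zeros)
```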

_rockpool_pytree_registry = []

The internal registry of registered JaxModules

_shape

The shape of this module

_spiking_input: bool

Whether this module receives spiking input

_spiking_output: bool

Whether this module produces spiking output

_submodulenames: List[str]

Registry of sub-module names

_wrap_recorded_state(recorded_dict: dict, t_start: float) Dict[str, TimeSeries]

Convert a recorded dictionary to a TimeSeries representation

This method is optional, and is provided to make the timed() conversion to a TimedModule work better. Override this method in your custom Module to wrap each element of your recorded state dictionary as a TimeSeries.

Parameters:
  • recorded_dict (dict) – A recorded state dictionary as returned by evolve()

  • t_start (float) – The initial time of the recorded state, to use as the starting point of the time series

Returns:

The mapped recorded state dictionary, wrapped as TimeSeries objects

Return type:

Dict[str, TimeSeries]
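A minimal sketch of the idea behind _wrap_recorded_state: each recorded array of shape (T, N) is paired with a time base that starts at t_start. Rockpool would wrap these as TimeSeries objects; here a hypothetical SimpleSeries dataclass stands in so the example runs without Rockpool, and the dt argument is an assumption about how a clocked module spaces its samples.

```python
from dataclasses import dataclass
from typing import Dict

import numpy as np


@dataclass
class SimpleSeries:
    """Hypothetical stand-in for a TimeSeries: sample times plus samples."""
    times: np.ndarray
    samples: np.ndarray
    name: str


def wrap_recorded_state_sketch(
    recorded_dict: Dict[str, np.ndarray], t_start: float, dt: float
) -> Dict[str, SimpleSeries]:
    """Hypothetical sketch: wrap each recorded array with a time base."""
    wrapped = {}
    for key, samples in recorded_dict.items():
        samples = np.asarray(samples)
        # ASSUMPTION: one sample per time step dt, starting at t_start
        times = t_start + dt * np.arange(samples.shape[0])
        wrapped[key] = SimpleSeries(times, samples, name=key)
    return wrapped


rec = {"vmem": np.zeros((4, 2))}
ts = wrap_recorded_state_sketch(rec, t_start=1.0, dt=0.5)
print(ts["vmem"].times.tolist())  # [1.0, 1.5, 2.0, 2.5]
```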

as_graph() GraphModuleBase

Convert this module to a computational graph

Returns:

The computational graph corresponding to this module

Return type:

GraphModuleBase

Raises:

NotImplementedError – If as_graph() is not implemented for this subclass

attributes_named(name: Tuple[str] | List[str] | str) dict

Search for attributes of this or submodules by name

Parameters:

name (Union[str, Tuple[str]]) – The name of the attribute to search for

Returns:

A nested dictionary of attributes that match name

Return type:

dict

bias

Bias vector of this module

property class_name: str

Class name of self

Type:

str

evolve(input_data, record: bool = False) Tuple[Any, Any, Any]

Evolve the state of this module over input data

Pass input_data through the weight matrix of this module, adding the bias vector if the module was initialised with has_bias = True.

Parameters:
  • input_data – The input data with shape (T, size_in) to evolve with

  • record (bool) – If True, the module should record internal state during evolution and return the record. If False, no recording is required. Default: False.

Returns:

(output, new_state, record)

  • output (np.ndarray): The output response of this module with shape (T, size_out)

  • new_state (dict): A dictionary containing the updated state of this and all submodules after evolution

  • record (dict): A dictionary containing the recorded state of this and all submodules, if requested using the record argument

Return type:

tuple
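The (output, new_state, record) convention can be sketched for a stateless linear layer as a pure function. This is a hypothetical NumPy illustration of the return convention, not Rockpool's evolve(); returning empty dictionaries for the state and the record reflects the assumption that a plain weight matrix has no internal state to update or record.

```python
from typing import Optional, Tuple

import numpy as np


def linear_evolve_sketch(
    weight: np.ndarray,
    input_data: np.ndarray,
    bias: Optional[np.ndarray] = None,
    record: bool = False,
) -> Tuple[np.ndarray, dict, dict]:
    """Hypothetical sketch of evolve() for a stateless linear layer.

    input_data has shape (T, size_in); weight has shape (size_in, size_out).
    """
    output = input_data @ weight  # shape (T, size_out)
    if bias is not None:
        output = output + bias  # bias broadcast over time steps
    # ASSUMPTION: a plain weight matrix has no state to update
    new_state: dict = {}
    # ASSUMPTION: nothing to record either; `record` is accepted to match the signature
    record_dict: dict = {}
    return output, new_state, record_dict


out, new_state, rec = linear_evolve_sketch(np.eye(3), np.ones((5, 3)), record=True)
print(out.shape)  # (5, 3)
```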

classmethod from_graph(graph: LinearWeights) LinearMixin

from_graph constructs a LinearMixin object from the computational graph

Parameters:

graph (LinearWeights) – The reference computational graph from which to construct the module

Returns:

a LinearMixin object

Return type:

LinearMixin

property full_name: str

The full name of this module (class plus module name)

Type:

str

modules() Dict

Return a dictionary of all sub-modules of this module

Returns:

A dictionary containing all sub-modules. Each item will be named with the sub-module name.

Return type:

dict

property name: str

The name of this module, or an empty string if None

Type:

str

parameters(family: Tuple | List | str | None = None) Dict

Return a nested dictionary of module and submodule Parameters

Use this method to inspect the Parameters from this and all submodules. The optional argument family allows you to search for Parameters in a particular family — for example "weights" for all weights of this module and nested submodules.

Although the family argument is an arbitrary string, reasonable choices are "weights" for weights, "taus" for time constants and "biases" for biases.

Examples

Obtain a dictionary of all Parameters for this module (including submodules):

>>> mod.parameters()
dict{ ... }

Obtain a dictionary of Parameters from a particular family:

>>> mod.parameters("weights")
dict{ ... }

Parameters:

family (str) – The family of Parameters to search for. Default: None; return all parameters.

Returns:

A nested dictionary of Parameters of this module and all submodules

Return type:

dict

reset_parameters()

Reset all parameters in this module

Returns:

The updated module is returned for compatibility with the functional API

Return type:

Module

reset_state() JaxModule

Reset the state of this module

Returns:

The updated module is returned for compatibility with the functional API

Return type:

Module

set_attributes(new_attributes: Iterable | MutableMapping | Mapping) JaxModule

Assign new attributes to this module and submodules

Parameters:

new_attributes (Tree) – The tree of new attributes to assign to this module tree

Return type:

JaxModule

property shape: tuple

The shape of this module

Type:

tuple

simulation_parameters(family: Tuple | List | str | None = None) Dict

Return a nested dictionary of module and submodule SimulationParameters

Use this method to inspect the SimulationParameters from this and all submodules. The optional argument family allows you to search for SimulationParameters in a particular family.

Examples

Obtain a dictionary of all SimulationParameters for this module (including submodules):

>>> mod.simulation_parameters()
dict{ ... }

Parameters:

family (str) – The family of SimulationParameters to search for. Default: None; return all SimulationParameter attributes.

Returns:

A nested dictionary of SimulationParameters of this module and all submodules

Return type:

dict

property size: int

(DEPRECATED) The output size of this module

Type:

int

property size_in: int

The input size of this module

Type:

int

property size_out: int

The output size of this module

Type:

int

property spiking_input: bool

If True, this module receives spiking input. If False, this module expects continuous input.

Type:

bool

property spiking_output

If True, this module sends spiking output. If False, this module sends continuous output.

Type:

bool

state(family: Tuple | List | str | None = None) Dict

Return a nested dictionary of module and submodule States

Use this method to inspect the States from this and all submodules. The optional argument family allows you to search for States in a particular family.

Examples

Obtain a dictionary of all States for this module (including submodules):

>>> mod.state()
dict{ ... }

Parameters:

family (str) – The family of States to search for. Default: None; return all State attributes.

Returns:

A nested dictionary of States of this module and all submodules

Return type:

dict

timed(output_num: int = 0, dt: float | None = None, add_events: bool = False)

Convert this module to a TimedModule

Parameters:
  • output_num (int) – Specify which output of the module to take, if the module returns multiple output series. Default: 0, take the first (or only) output.

  • dt (float) – Used to provide a time-step for this module, if the module does not already have one. If self already defines a time-step, then self.dt will be used. Default: None

  • add_events (bool) – If True, the TimedModule will add events occurring on a single time step on input and output. Default: False, don’t add events.

Returns:

A timed module that wraps this module

Return type:

TimedModule
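The rule for choosing the time step in timed() can be sketched as below. The helper resolve_dt is hypothetical, and raising an error when neither the module nor the caller provides a time step is an assumption made for the sketch, not documented behaviour.

```python
from typing import Optional


def resolve_dt(module_dt: Optional[float], dt: Optional[float]) -> float:
    """Hypothetical sketch: prefer the module's own time step, else the argument."""
    if module_dt is not None:
        # A module that already defines a time step keeps it
        return module_dt
    if dt is not None:
        return dt
    # ASSUMPTION: the sketch refuses to guess a time step
    raise ValueError("a time step must be provided for a module without one")


print(resolve_dt(None, 1e-3))  # 0.001
```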

tree_flatten() Tuple[tuple, tuple]

Flatten this module tree for Jax

classmethod tree_unflatten(aux_data, children)

Unflatten a tree of modules from Jax to Rockpool
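Jax treats registered classes as pytrees: tree_flatten() splits an object into array-like leaves (children) and static auxiliary data, and the classmethod tree_unflatten(aux_data, children) rebuilds it. The toy class below sketches this protocol in plain Python so that it runs without Jax; with Jax installed, such a class would typically be registered with the jax.tree_util.register_pytree_node_class decorator. ToyLinear and its fields are hypothetical and do not reflect how Rockpool splits a module.

```python
from typing import Any, Tuple


class ToyLinear:
    """Hypothetical module with one array-like leaf and one static field."""

    def __init__(self, weight: Any, name: str):
        self.weight = weight  # leaf: would be transformed by Jax (e.g. grad, jit)
        self.name = name      # static: carried through untouched

    def tree_flatten(self) -> Tuple[tuple, tuple]:
        children = (self.weight,)
        aux_data = (self.name,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (name,) = aux_data
        (weight,) = children
        return cls(weight, name)


mod = ToyLinear([[1.0, 2.0]], "lin")
children, aux = mod.tree_flatten()
# Transform the leaves, as a Jax transformation would, then rebuild the object
new_children = tuple([[2 * v for v in row] for row in leaf] for leaf in children)
doubled = ToyLinear.tree_unflatten(aux, new_children)
print(doubled.weight, doubled.name)  # [[2.0, 4.0]] lin
```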

weight

Weight matrix of this module