This page was generated from docs/in-depth/api-low-level.ipynb.
🛠 Low-level Module API
The low-level API in Rockpool is designed for minimal, efficient implementation of stateful neural networks.
The Module base class provides facilities for configuring, simulating and examining networks of stateful neurons.
Constructing a Module
All Module subclasses accept, at a minimum, a shape argument on construction. This argument should fully specify the input, output and internal dimensionality of the Module. The code can then determine how many neurons to generate and the sizes of the state variables and parameters.
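The following is a minimal sketch of how a shape fixes a module's dimensions. The helper sizes_from_shape is hypothetical and is not a Rockpool function. It assumes three things: the first entry of shape is the input size, the last entry is the output size, and a bare integer N means a module with N inputs and N outputs. This matches the Rate(4) and ffwd_net((4, 6)) examples on this page, but check the documentation of each Module subclass for its exact convention.

```python
from typing import Tuple, Union


def sizes_from_shape(shape: Union[int, Tuple[int, ...]]) -> Tuple[int, int]:
    """Hypothetical helper: return (size_in, size_out) for a module shape.

    Assumes shape[0] is the input size and shape[-1] is the output size.
    """
    # - Treat a bare integer N as the one-element shape (N,)
    if isinstance(shape, int):
        shape = (shape,)
    if len(shape) == 0:
        raise ValueError("shape must have at least one entry")
    return shape[0], shape[-1]


print(sizes_from_shape(4))       # (4, 4): a layer of 4 neurons
print(sizes_from_shape((4, 6)))  # (4, 6): 4 input channels feeding 6 neurons
```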
Some Module subclasses also let you specify the module shape implicitly by providing concrete parameter arrays. For example, you can pass a vector of shape (N,) as the bias parameters for a set of neurons. These concrete parameter values are used to initialise the Module. If the Module is reset, its parameters return to those concrete values.
Otherwise, all Module subclasses set reasonable default initialisation values for their parameters.
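The numpy sketch below shows these two initialisation paths and the reset behaviour. ToyRateParams and its reset() method are hypothetical and are not Rockpool classes or methods. The default time constant of 0.02 is an illustrative assumption, borrowed from the RateJax defaults printed later on this page.

```python
import numpy as np


class ToyRateParams:
    """Hypothetical stand-in for parameter initialisation and reset (not Rockpool code)."""

    def __init__(self, n=None, tau=None, default_tau: float = 0.02):
        if tau is not None:
            # - A concrete array fixes both the number of neurons and the initial values
            init_tau = np.array(tau, dtype=float)
            if init_tau.ndim != 1 or (n is not None and init_tau.shape != (n,)):
                raise ValueError(f"tau must be a vector of shape (n,), got {init_tau.shape}")
        elif n is not None:
            # - No concrete values: use the same default for every neuron
            init_tau = np.full(n, default_tau)
        else:
            raise ValueError("Provide either n or a concrete tau array")
        self._init_tau = init_tau
        self.tau = init_tau.copy()

    @property
    def shape(self):
        return self._init_tau.shape

    def reset(self):
        # - Return to the initialisation values, concrete or default
        self.tau = self._init_tau.copy()


toy_params = ToyRateParams(tau=[0.01, 0.02, 0.03])
print(toy_params.shape)  # (3,)
toy_params.tau *= 10     # e.g. parameters changed during training
toy_params.reset()
print(toy_params.tau)    # back to [0.01 0.02 0.03]
```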
[1]:
# - Switch off warnings
import warnings
warnings.filterwarnings("ignore")
# - Useful imports
try:
    from rich import print
except ImportError:
    pass
# - Example of constructing a module
from rockpool.nn.modules import Rate
import numpy as np
# - Construct a Module with 4 neurons
mod = Rate(4)
print(mod)
Rate with shape (4,)
[2]:
# - Construct a Module with concrete parameters
mod = Rate(4, tau=np.ones(4))
print(mod)
Rate with shape (4,)
Evolving a Module
You evolve the state of a Module by calling it. Module subclasses expect clocked, rasterised data as numpy arrays with shape (T, Nin) or (batches, T, Nin). Here batches is the number of batches, T is the number of time steps, and Nin is the input size of the module (mod.size_in).
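Rockpool adds a batch dimension internally when you pass unbatched input. You can see this in the example below, where a (5, 4) input produces output of shape (1, 5, 4). The sketch below illustrates only the shape convention. The helper ensure_batched is hypothetical and is not part of Rockpool: it promotes (T, Nin) input to (1, T, Nin) and passes (batches, T, Nin) input through unchanged.

```python
import numpy as np


def ensure_batched(input_data) -> np.ndarray:
    """Hypothetical helper: return input with shape (batches, T, Nin)."""
    input_data = np.asarray(input_data)
    if input_data.ndim == 2:
        # - (T, Nin): treat as a single batch
        return input_data[np.newaxis, :, :]
    if input_data.ndim == 3:
        # - Already (batches, T, Nin)
        return input_data
    raise ValueError(f"Expected (T, Nin) or (batches, T, Nin), got shape {input_data.shape}")


T, Nin = 5, 4
single = np.random.rand(T, Nin)
batched = np.random.rand(3, T, Nin)
print(ensure_batched(single).shape)   # (1, 5, 4)
print(ensure_batched(batched).shape)  # (3, 5, 4)
```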
Calling a Module has the following syntax:
output, new_state, recorded_state = mod(input: np.array, record: bool = False)
Calling the Module returns the module output as a numpy array with shape (batches, T, Nout), where Nout is the output size of the module (mod.size_out).
new_state is a state dictionary containing the final state of the module, and of all its submodules, at the end of evolution. This becomes more relevant when using the functional API (see [𝝺] Low-level functional API).
recorded_state is only populated if the argument record=True is passed to the module. In that case recorded_state is a nested dictionary containing the recorded state of the module and all its submodules. Each element in recorded_state should have shape (T, ...). T is the number of evolution time steps, and the remaining dimensions are whatever is appropriate for that state variable. A leading batch dimension may also be present, as in the (1, 5, 4) arrays in the example below.
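To show where each return value comes from, the sketch below implements a module that follows the same calling convention. ToyRateModule is hypothetical and is not a Rockpool class; it uses numpy only. It assumes leaky rate dynamics integrated with forward Euler, x += dt / tau * (-x + input), which may differ from the equations Rockpool's Rate module actually uses.

```python
import numpy as np


class ToyRateModule:
    """Hypothetical module following the (output, new_state, recorded_state) convention.

    Not Rockpool code: a numpy-only sketch of leaky rate neurons updated with
    forward Euler, x += dt / tau * (-x + input).
    """

    def __init__(self, n: int, tau: float = 0.02, dt: float = 1e-3):
        self.size_in = self.size_out = n
        self.tau = tau
        self.dt = dt
        self.x = np.zeros(n)  # internal state, one value per neuron

    def __call__(self, input_data, record: bool = False):
        input_data = np.asarray(input_data, dtype=float)
        if input_data.ndim == 2:
            # - (T, Nin) input is treated as a single batch
            input_data = input_data[np.newaxis]
        batches, T, _ = input_data.shape

        # - Every batch starts from the current module state
        x = np.tile(self.x, (batches, 1))
        x_trace = np.zeros((batches, T, self.size_out))
        for t in range(T):
            x = x + self.dt / self.tau * (-x + input_data[:, t, :])
            x_trace[:, t, :] = x

        # - Store the final state of the last batch (an arbitrary choice in this sketch)
        self.x = x[-1].copy()
        new_state = {"x": self.x.copy()}
        # - Only fill the record dictionary when asked to
        recorded_state = {"x": x_trace.copy()} if record else {}
        return x_trace, new_state, recorded_state


toy_mod = ToyRateModule(4)
output, new_state, recorded_state = toy_mod(np.random.rand(5, 4), record=True)
print(output.shape)               # (1, 5, 4)
print(new_state["x"].shape)       # (4,)
print(recorded_state["x"].shape)  # (1, 5, 4)
```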
[3]:
# - Generate and evolve over some input
T = 5
input = np.random.rand(T, mod.size_in)
output, _, _ = mod(input)
print(f"Output shape: {output.shape}")
Output shape: (1, 5, 4)
[4]:
# - Request the recorded state
output, _, recorded_state = mod(input, record=True)
print("Recorded state:", recorded_state)
Recorded state: { 'rec_input': array([[[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]]), 'x': array([[[0.00447242, 0.00335577, 0.00469451, 0.00364505], [0.00536816, 0.00415351, 0.00567921, 0.00439317], [0.00544684, 0.00446443, 0.00618463, 0.00441146], [0.00637465, 0.00499181, 0.00717 , 0.00518635], [0.00711199, 0.00548246, 0.00750611, 0.00532915]]]) }
Parameters, State and SimulationParameters
Rockpool defines three types of parameters for Modules: Parameter, State and SimulationParameter.
Parameters are, roughly, any values you would consider part of the configuration of a network. If you need to tell someone else how to specify your network (without going into the details of the simulation backend), you tell them about your Parameters. Often the set of Parameters will be the trainable parameters of a network.
States are any internal values that must be maintained to track how the neurons, synapses or other components of a Module's dynamical system evolve over time. Examples include neuron membrane potentials and synaptic currents.
SimulationParameters are attributes that must be specified for simulation purposes but should not, in principle, affect the network output or behaviour. For example, a forward Euler ODE solver requires the time step dt of a Module. However, the network configuration should be valid and usable regardless of the value of dt. You also shouldn't need to specify dt when telling someone else about your network configuration.
One more useful wrapper class is Constant. Use it to wrap any model parameter that should not be trainable.
These classes are defined in rockpool.parameters.
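The numpy sketch below makes the Parameter / State / SimulationParameter split concrete. It integrates a single leaky rate unit, dx/dt = (-x + i_in) / tau, with forward Euler. In Rockpool terms, tau would be a Parameter, x a State and dt a SimulationParameter. The function simulate_leaky is hypothetical and is not part of Rockpool. Shrinking dt changes the number of simulation steps. As long as the step is small enough, both runs approximate the same continuous-time solution, which is why dt belongs to the simulation rather than to the network description.

```python
import numpy as np


def simulate_leaky(tau: float, i_in: float, duration: float, dt: float) -> float:
    """Hypothetical forward-Euler simulation of dx/dt = (-x + i_in) / tau.

    Returns x at time `duration`, starting from x = 0.
    """
    x = 0.0  # State: starts at rest
    n_steps = int(round(duration / dt))
    for _ in range(n_steps):
        x += dt / tau * (-x + i_in)
    return x


tau, i_in, duration = 20e-3, 1.0, 50e-3  # Parameter, input and simulated time
exact = i_in * (1 - np.exp(-duration / tau))  # analytic solution at `duration`
for dt in (1e-3, 1e-4):  # two choices of the SimulationParameter dt
    x_end = simulate_leaky(tau, i_in, duration, dt)
    print(f"dt = {dt:g} s: x = {x_end:.4f} (exact {exact:.4f})")
```

With dt = 1 ms the Euler result differs from the exact value by roughly half a percent. With dt = 0.1 ms the difference is about ten times smaller.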
Building a network with Modules
To build a complex network in Rockpool, you define your own Module subclass. Module takes care of many things for you, so you can define a network architecture with little overhead.
At a minimum, you need to define a Module.__init__() method. This method specifies the network parameters (e.g. weights) and whichever submodules your network requires. The submodules take over the job of defining their own parameters and states.
You also need to define a Module.evolve() method, which contains the “plumbing” of your network. This method specifies how data is passed through your network, between submodules, and out again.
We’ll build a simple feed-forward layer containing some weights and a set of neurons.
Note that this simple example does not correctly return the updated module state or the recorded state.
[5]:
# - Build a simple network
from rockpool.nn.modules import Module
from rockpool.parameters import Parameter
from rockpool.nn.modules import RateJax
import numpy as np
class ffwd_net(Module):
    # - Provide an `__init__` method to specify required parameters and modules
    #   Here you check, define and initialise whatever parameters and
    #   state you need for your module.
    def __init__(
        self,
        shape,
        *args,
        **kwargs,
    ):
        # - Call superclass initialisation
        #   This is always required for a `Module` class
        super().__init__(shape, *args, **kwargs)
        # - Specify weights attribute
        #   We need a weights matrix for our input weights.
        #   We specify the shape explicitly, and provide an initialisation function.
        #   We also specify a family for the parameter, "weights". This is used to
        #   query parameters conveniently, and is a good idea to provide.
        self.w_ffwd = Parameter(
            shape=self.shape,
            init_func=lambda s: np.zeros(s),
            family="weights",
        )
        # - Specify and add a submodule
        #   These will be the neurons in our layer, which receive the weighted
        #   input signals. This submodule is configured automatically and
        #   internally specifies its required state and parameters.
        self.neurons = RateJax(self.shape[-1])
    # - The `evolve` method contains the internal logic of your module
    #   `evolve` takes care of passing data in and out of the module,
    #   and between submodules if present.
    def evolve(self, input_data, *args, **kwargs):
        # - Pass input data through the input weights
        x = input_data @ self.w_ffwd
        # - Pass the signals through the neurons
        x, _, _ = self.neurons(x)
        # - Return the module output (empty state and record dictionaries)
        return x, {}, {}
Writing an evolve() method that returns state and record
To adhere to the Module API, your Module.evolve() method must return the updated set of states after evolution, and must support recording internal states if requested. The example below replaces the Module.evolve() method of the network above, illustrating a convenient way to do this.
[6]:
# - Replacement for `ffwd_net.evolve()`: place it in the class body instead of the earlier version
def evolve(self, input_data, record: bool = False, *args, **kwargs):
    # - Initialise state and record dictionaries
    new_state = {}
    recorded_state = {}
    # - Pass input data through the input weights
    x = input_data @ self.w_ffwd
    # - Add an internal signal record to the record dictionary
    if record:
        recorded_state["weighted_input"] = x
    # - Pass the signals through the neurons, passing through the `record` argument
    x, submod_state, submod_record = self.neurons(x, record=record)
    # - Store the submodule state under the submodule name
    new_state["neurons"] = submod_state
    # - Include the submodule record under the same name
    recorded_state["neurons"] = submod_record
    # - Return the module output
    return x, new_state, recorded_state
Inspecting a Module
You can examine the internal parameters and state of a Module using a set of convenient inspection methods: parameters(), state() and simulation_parameters().
params: dict = mod.parameters(family: str = None)
state: dict = mod.state(family: str = None)
simulation_parameters: dict = mod.simulation_parameters(family: str = None)
In each case the method returns a nested dictionary containing all registered attributes of that type for the module and all its submodules. If a family is given, only attributes from that family are returned.
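The pure-Python sketch below shows one way such family filtering can work. It assumes that each registered attribute is stored together with its family name. The function filter_family and the registry layout are hypothetical, not Rockpool internals, and the family name "biases" is purely illustrative. The sketch reproduces the structure of the outputs in the examples below: a submodule with no matching attributes simply disappears from the result.

```python
from typing import Any, Dict, Optional


def filter_family(registry: Dict[str, Any], family: Optional[str] = None) -> Dict[str, Any]:
    """Return a nested dict of values, keeping only leaves whose family matches.

    Leaves are (value, family) tuples; nested dicts stand for submodules.
    Submodules with no matching entries are dropped from the result.
    """
    result = {}
    for name, entry in registry.items():
        if isinstance(entry, dict):
            # - Recurse into a submodule and keep it only if something matched
            sub = filter_family(entry, family)
            if sub:
                result[name] = sub
        else:
            value, entry_family = entry
            if family is None or entry_family == family:
                result[name] = value
    return result


# - Hypothetical registry loosely mirroring `ffwd_net` (family names are illustrative)
registry = {
    "w_ffwd": ([[0.0] * 6 for _ in range(4)], "weights"),
    "neurons": {
        "tau": ([0.02] * 6, "taus"),
        "bias": ([0.0] * 6, "biases"),
    },
}
print(filter_family(registry, "taus"))           # {'neurons': {'tau': [0.02, 0.02, 0.02, 0.02, 0.02, 0.02]}}
print(list(filter_family(registry, "weights")))  # ['w_ffwd']
```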
[7]:
# - Build a module for our network
my_mod = ffwd_net((4, 6))
print(my_mod)
ffwd_net with shape (4, 6) { RateJax 'neurons' with shape (6,) }
[8]:
# - Show module parameters
print("Parameters:", my_mod.parameters())
Parameters: { 'w_ffwd': array([[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.]]), 'neurons': { 'tau': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32), 'bias': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), 'threshold': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32) } }
[9]:
# - Show module state
print("State:", my_mod.state())
State: { 'neurons': { 'rng_key': DeviceArray([1251626347, 511538859], dtype=uint32), 'x': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32) } }
[10]:
# - Return parameters from particular families
print("Module time constants:", my_mod.parameters("taus"))
print("Module weights:", my_mod.parameters("weights"))
Module time constants: {'neurons': {'tau': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32)}}
Module weights: { 'w_ffwd': array([[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.]]) }
You can also access all attributes of a Module directly, using standard Python “dot” indexing syntax:
[11]:
# - Access parameters directly
print(".w_ffwd:", my_mod.w_ffwd)
print(".neurons.tau:", my_mod.neurons.tau)
.w_ffwd: [[0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0.]]
.neurons.tau: [0.02 0.02 0.02 0.02 0.02 0.02]
Module API reference
Every Module provides the following attributes:
| Attribute | Description |
|---|---|
| | The name of the subclass |
| | The attribute name that this module was assigned to. Will be None for a base-level module |
| | The class name and module name together. Useful for printing |
| | If |
| | If |
| | The dimensions of the module. Can have any number of entries, for complex modules. |
| | The number of input channels the module expects |
| | The number of output channels the module produces |
Every Module provides the following methods:
| Method | Description |
|---|---|
| | Return a nested dictionary of module parameters, optionally restricting the search to a particular family of parameters such as weights |
| | Return a nested dictionary of module state |
| | Return a nested dictionary of module simulation parameters |
| | Return a list of submodules of this module |
| | Search for and return nested attributes matching a particular name |
| | Set the parameter values for this and nested submodules |
| | Reset the state of this and nested submodules |
| | Reset the parameters of this and nested submodules to their initialisation defaults |
| | Utility method to assist with handling batched data |
| | Convert this module to the high-level |