This page was generated from docs/in-depth/api-functional.ipynb.
[𝝺] Low-level functional API
Rockpool Modules and the JaxModule base class support a functional form for manipulating parameters and for evolution. This is particularly important when using Jax, since this library requires a functional programming style.
Functional evolution
First let’s set up a module to play with:
[1]:
# - Switch off warnings
import warnings
warnings.filterwarnings("ignore")
# - Rockpool imports
from rockpool.nn.modules import RateJax
# - Other useful imports
import numpy as np
try:
    from rich import print
except ImportError:
    pass
# - Construct a module
N = 3
mod = RateJax(N)
Now if we evolve the module, we get the outputs we expect:
[2]:
# - Set up some input
T = 10
input = np.random.rand(T, N)
output, new_state, record = mod(input)
[3]:
print("output:", output)
output: [[[0.01076346 0.02204564 0.02956902] [0.04010231 0.06479559 0.07025937] [0.0432203 0.06293815 0.11122491] [0.06027619 0.08317991 0.1307195 ] [0.05882018 0.09205481 0.15850767] [0.08602361 0.12148949 0.16591538] [0.10797263 0.11710763 0.18505278] [0.13634151 0.13502166 0.20233528] [0.15542242 0.15269144 0.20374872] [0.16027772 0.15767853 0.19997193]]]
[4]:
print("new_state:", new_state)
new_state: { 'x': DeviceArray([0.16027772, 0.15767853, 0.19997193], dtype=float32), 'rng_key': DeviceArray([2469880657, 3700232383], dtype=uint32) }
So far so good. The issue with jax is that jit-compiled modules and functions cannot have side-effects. For Rockpool, evolution almost always has side-effects, in terms of updating the internal state variables of each module.
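The effect of jit on side-effects can be seen in a small sketch that is independent of Rockpool. The function `step` and the list `trace_log` are hypothetical names used only for illustration. The Python-side append runs while Jax traces the function, and is skipped when the cached compiled computation is reused:

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state, invisible to the compiled computation


@jax.jit
def step(x):
    # Side-effect: this append executes only while Jax traces the function
    trace_log.append("traced")
    return 0.9 * x + 1.0


x0 = jnp.zeros(3, dtype=jnp.float32)
y1 = step(x0)  # first call: traces and compiles, so the append runs
y2 = step(x0)  # same shape and dtype: the cached computation is reused

n_side_effects = len(trace_log)
print("side-effects observed:", n_side_effects)
```

Both calls return the same values, but only the first one touches `trace_log`. This is why state updates need to be returned from the function rather than written to an object.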
In the case of the evolution above, we can see that the internal state was not updated during evolution:
[5]:
print("mod.state:", mod.state())
print(mod.state()["x"], " != ", new_state["x"])
mod.state: { 'rng_key': DeviceArray([ 237268104, 2681681569], dtype=uint32), 'x': DeviceArray([0., 0., 0.], dtype=float32) }
[0. 0. 0.] != [0.16027772 0.15767853 0.19997193]
The correct resolution is to assign new_state to the module after each evolution:
[6]:
mod = mod.set_attributes(new_state)
print(mod.state()["x"], " == ", new_state["x"])
[0.16027772 0.15767853 0.19997193] == [0.16027772 0.15767853 0.19997193]
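When evolving over several chunks of input, the same pattern is repeated for every chunk. The following is a minimal sketch using a hypothetical stand-in class, `ToyModule`, which is not part of Rockpool; it only mimics the call signature (output, new state, record) and the functional `set_attributes()` shown here, so that the loop can run without Jax or Rockpool installed:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for a module: a leaky accumulator."""

    x: float = 0.0  # state
    alpha: float = 0.5  # parameter

    def __call__(self, inputs):
        x = self.x
        outputs = []
        for u in inputs:
            x = self.alpha * x + u
            outputs.append(x)
        # - Return the new state instead of modifying `self`
        return outputs, {"x": x}, {}

    def set_attributes(self, new_attributes):
        # - Return an updated copy; the original object is left untouched
        return replace(self, **new_attributes)


mod = ToyModule()
for chunk in ([1.0, 1.0], [0.0, 0.0]):
    output, new_state, record = mod(chunk)
    # - Carry the state over to the next evolution
    mod = mod.set_attributes(new_state)

final_x = mod.x
print("final state:", final_x)
```

Without the re-assignment inside the loop, every chunk would start again from the initial state.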
You will have noticed the functional form of the call to set_attributes() above. This is addressed in the next section.
Functional state and attribute setting
Direct attribute assignment works at the top level, using standard Python syntax:
[7]:
new_tau = mod.tau * 0.4
mod.tau = new_tau
print(new_tau, " == ", mod.tau)
[0.008 0.008 0.008] == [0.008 0.008 0.008]
A functional form is also supported, via the set_attributes() method. This returns a copy of the module (and its submodules) with the updated attributes, which replaces the “old” module:
[8]:
params = mod.parameters()
params["tau"] = params["tau"] * 3.0
# - Note the functional calling style
mod = mod.set_attributes(params)
# - check that the attribute was set
print(params["tau"], " == ", mod.tau)
[0.024 0.024 0.024] == [0.024 0.024 0.024]
Functional module reset
Resetting the module state and parameters must also be done using a functional form:
[9]:
# - Reset the module state
mod = mod.reset_state()
# - Reset the module parameters
mod = mod.reset_parameters()
Jax flattening
JaxModule provides the methods tree_flatten() and tree_unflatten(), which are required to serialise and deserialise modules for Jax compilation and execution.
If you write a JaxModule subclass, it will be automatically registered with Jax as a pytree. You shouldn’t need to override tree_flatten() or tree_unflatten() in your modules.
Flattening and unflattening require that your __init__() method be callable with only a shape as input, which should be sufficient to specify the network architecture of your module and all submodules.
If that isn’t the case, then you may need to override tree_flatten() and tree_unflatten().
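For orientation, here is a minimal sketch of the generic Jax pytree protocol that these two methods implement. The class `Leaky` is hypothetical and is not a Rockpool class, and the way Rockpool itself splits a module into children and auxiliary data may differ; the sketch only assumes the standard `jax.tree_util` functions:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Leaky:
    """Hypothetical container with one array parameter (not a Rockpool class)."""

    def __init__(self, shape, tau=None):
        # - The shape alone is enough to build a default instance
        self.shape = shape
        self.tau = jnp.full(shape, 0.02) if tau is None else tau

    def tree_flatten(self):
        # - children: array leaves that Jax may trace
        # - aux_data: static, hashable construction data
        return (self.tau,), self.shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (tau,) = children
        return cls(aux_data, tau=tau)


mod = Leaky((3,))

# - Round trip through the flattened representation
leaves, treedef = jax.tree_util.tree_flatten(mod)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# - Tree utilities now see the object as a container of arrays
doubled = jax.tree_util.tree_map(lambda p: 2.0 * p, mod)
print(type(doubled).__name__, doubled.tau)
```

Note that `tree_unflatten()` rebuilds the object from the static data plus the leaves, which is why the constructor needs to work from the shape alone.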