This page was generated from docs/in-depth/api-functional.ipynb.
Low-level functional API
Rockpool Modules and the JaxModule base class support a functional form for manipulating parameters and for evolution. This is particularly important when using Jax, since Jax requires a functional programming style.
Functional evolution
First, let's set up a module to play with:
[1]:
# - Switch off warnings
import warnings
warnings.filterwarnings("ignore")
# - Rockpool imports
from rockpool.nn.modules import RateJax
# - Other useful imports
import numpy as np
# - Use rich's pretty printing if it is installed
try:
    from rich import print
except ImportError:
    pass
# - Construct a module
N = 3
mod = RateJax(N)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Could not import package: No module named 'iaf_nest'
Now if we evolve the module, we get the outputs we expect:
[2]:
# - Set up some input
T = 10
input = np.random.rand(T, N)
output, new_state, record = mod(input)
[3]:
print("output:", output)
output: [[[0.02114696 0.00989053 0.03182354] [0.04654327 0.03433599 0.06666339] [0.08230176 0.05535971 0.07213248] [0.12332115 0.05678442 0.08073001] [0.13181739 0.09974151 0.07767531] [0.13560757 0.09651617 0.09148949] [0.1767954 0.10431428 0.11117382] [0.21250004 0.1332448 0.12745431] [0.23403281 0.15806702 0.12969883] [0.23632057 0.19107696 0.1703079 ]]]
[4]:
print("new_state:", new_state)
new_state: { 'x': DeviceArray([0.23632057, 0.19107696, 0.1703079 ], dtype=float32), 'rng_key': array([ 808257824, 3588456238], dtype=uint32) }
So far so good. The issue with Jax is that jit-compiled modules and functions cannot have side-effects. For Rockpool, evolution almost always has side-effects, in the form of updates to the internal state variables of each module.
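To make the distinction concrete, here is a minimal pure-Python sketch (not Rockpool code) contrasting a side-effecting update with a pure one. It assumes a toy leaky rate neuron with the update x ← x + (dt / tau)·(−x + u), with dt and tau values chosen only for illustration; the names `StatefulRate` and `evolve_pure` are hypothetical. The pure version returns its new state instead of writing it back into an object, which is the style that jit-compiled code requires.

```python
from typing import Dict, List, Tuple

DT = 1e-3    # hypothetical time step (s)
TAU = 20e-3  # hypothetical time constant (s)

State = Dict[str, float]


class StatefulRate:
    """Side-effecting style: evolve() overwrites self.state."""

    def __init__(self) -> None:
        self.state: State = {"x": 0.0}

    def evolve(self, inputs: List[float]) -> List[float]:
        outputs = []
        for u in inputs:
            x = self.state["x"]
            # - The update is written back into the object: a side effect
            self.state["x"] = x + DT / TAU * (-x + u)
            outputs.append(self.state["x"])
        return outputs


def evolve_pure(state: State, inputs: List[float]) -> Tuple[List[float], State]:
    """Functional style: state in, (outputs, new_state) out, nothing mutated."""
    x = state["x"]
    outputs = []
    for u in inputs:
        x = x + DT / TAU * (-x + u)
        outputs.append(x)
    # - Return a fresh dict rather than modifying the caller's one
    return outputs, {**state, "x": x}


inputs = [1.0] * 5

stateful = StatefulRate()
out_a = stateful.evolve(inputs)  # stateful.state has changed behind the caller's back

state0: State = {"x": 0.0}
out_b, new_state = evolve_pure(state0, inputs)  # state0 is left untouched
```

Both versions compute the same outputs; the difference is that everything `evolve_pure` changes is visible in its return value, so the caller decides what to keep.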
In the case of the evolution above, we can see that the internal state was not updated during evolution:
[5]:
print("mod.state:", mod.state())
print(mod.state()["x"], " != ", new_state["x"])
mod.state: { 'rng_key': array([1432838885, 1403173352], dtype=uint32), 'x': DeviceArray([0., 0., 0.], dtype=float32) }
[0. 0. 0.] != [0.23632057 0.19107696 0.1703079 ]
The correct resolution is to assign new_state back to the module after each evolution:
[6]:
mod = mod.set_attributes(new_state)
print(mod.state()["x"], " == ", new_state["x"])
[0.23632057 0.19107696 0.1703079 ] == [0.23632057 0.19107696 0.1703079 ]
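The same toy update shows why reassigning the returned state matters. In this stdlib-only sketch, evolving a signal in two chunks, passing each call the state returned by the previous one, ends in the same state as one long evolution, whereas reusing the initial state silently restarts the second chunk from zero. The function `evolve` and its dt/tau defaults are hypothetical stand-ins for a module's functional evolution; the state is a dict in the spirit of new_state above.

```python
from typing import Dict, List, Tuple

State = Dict[str, float]


def evolve(
    state: State, inputs: List[float], dt: float = 1e-3, tau: float = 20e-3
) -> Tuple[List[float], State]:
    """Hypothetical pure evolution: returns outputs and a new state dict."""
    x = state["x"]
    outputs = []
    for u in inputs:
        x = x + dt / tau * (-x + u)
        outputs.append(x)
    return outputs, {**state, "x": x}


init: State = {"x": 0.0}
signal = [0.5, 1.0, 0.2, 0.8, 0.3, 0.9]

# - One long evolution
_, final_once = evolve(init, signal)

# - Two chunks, reassigning the returned state between calls
state = init
_, state = evolve(state, signal[:3])
_, state = evolve(state, signal[3:])

# - Two chunks, forgetting to reassign: the second chunk restarts from init
evolve(init, signal[:3])  # result discarded by mistake
_, final_forgot = evolve(init, signal[3:])
```

Here `state` ends up equal to `final_once`, while `final_forgot` does not.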
You will have noticed the functional form of the call to set_attributes() above. This is addressed in the next section.
Functional state and attribute setting
Direct attribute assignment works at the top level, using standard Python syntax:
[7]:
new_tau = mod.tau * 0.4
mod.tau = new_tau
print(new_tau, " == ", mod.tau)
[0.008 0.008 0.008] == [0.008 0.008 0.008]
A functional form is also supported, via the set_attributes() method. Here a copy of the module (and its submodules) is returned, which you use to replace the "old" module with one that has the updated attributes:
[8]:
params = mod.parameters()
params["tau"] = params["tau"] * 3.0
# - Note the functional calling style
mod = mod.set_attributes(params)
# - check that the attribute was set
print(params["tau"], " == ", mod.tau)
[0.024 0.024 0.024] == [0.024 0.024 0.024]
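The copy-and-replace semantics can be sketched with a frozen dataclass. This is only an analogy: `ToyModule` and its `set_attributes()` method are hypothetical, built on the standard-library `dataclasses.replace`, and are not how Rockpool implements the method. The property it shares with the call above is that the original object is left unchanged, so the caller must keep the returned one.

```python
import dataclasses
from typing import Any, Dict, Tuple


@dataclasses.dataclass(frozen=True)
class ToyModule:
    """Hypothetical immutable module with one parameter and one state."""

    tau: Tuple[float, ...]
    x: Tuple[float, ...]

    def set_attributes(self, new_attributes: Dict[str, Any]) -> "ToyModule":
        # - Return a modified copy; `self` is never mutated
        return dataclasses.replace(self, **new_attributes)


old = ToyModule(tau=(0.008,) * 3, x=(0.0,) * 3)
params = {"tau": tuple(t * 3.0 for t in old.tau)}

# - Functional calling style: keep the returned object
new = old.set_attributes(params)
```

Because the dataclass is frozen, assigning `old.tau = ...` would raise `dataclasses.FrozenInstanceError`, so the functional style is the only way to change it.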
Functional module reset
Resetting the module state and parameters must also be done using a functional form:
[9]:
# - Reset the module state
mod = mod.reset_state()
# - Reset the module parameters
mod = mod.reset_parameters()
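A reset follows the same pattern: build a fresh copy holding default values rather than zeroing fields in place. In this stdlib-only sketch, `ToyRate`, its default time constant and its `reset_state()` / `reset_parameters()` methods are hypothetical stand-ins chosen to mirror the calls above; Rockpool's own implementation may differ.

```python
import copy
from typing import List


class ToyRate:
    """Hypothetical module whose reset methods return modified copies."""

    DEFAULT_TAU = 0.02  # hypothetical default time constant (s)

    def __init__(self, n: int) -> None:
        self.n = n
        self.tau: List[float] = [self.DEFAULT_TAU] * n
        self.x: List[float] = [0.0] * n

    def reset_state(self) -> "ToyRate":
        new = copy.deepcopy(self)  # work on a copy, leave self alone
        new.x = [0.0] * self.n
        return new

    def reset_parameters(self) -> "ToyRate":
        new = copy.deepcopy(self)
        new.tau = [self.DEFAULT_TAU] * self.n
        return new


mod = ToyRate(3)
mod.x = [0.2, 0.1, 0.3]  # pretend some evolution has happened
mod.tau = [0.024] * 3    # pretend the parameters were modified
before = mod

# - Functional calling style, as in the cell above
mod = mod.reset_state()
mod = mod.reset_parameters()
```

If the return values were discarded, `mod` would still hold the old state and parameters, just as `before` does here.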
Jax flattening
JaxModule provides the methods tree_flatten() and tree_unflatten(), which are required to serialise and deserialise modules for Jax compilation and execution.
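The flatten/unflatten round trip can be illustrated with a stdlib-only toy that follows the same convention as classes registered with `jax.tree_util.register_pytree_node_class`: `tree_flatten()` returns the leaves plus static auxiliary data, and the classmethod `tree_unflatten(aux_data, children)` rebuilds an object from them. In this toy the auxiliary data is just the shape, so rebuilding calls `__init__(shape)` and then restores the leaves. `ToyJaxModule` and `map_leaves` are hypothetical, the toy does not import Jax, and it is a sketch of the general mechanism rather than a copy of Rockpool's implementation.

```python
from typing import Any, Callable, List, Tuple


class ToyJaxModule:
    """Hypothetical stand-in for a module; not a Rockpool class."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        # - The shape alone is enough to build this toy module
        self.shape = shape
        self.tau = [0.02] * shape[-1]
        self.x = [0.0] * shape[-1]

    def tree_flatten(self) -> Tuple[List[List[float]], Tuple[int, ...]]:
        children = [self.tau, self.x]  # leaves: the values Jax would trace
        aux_data = self.shape          # static data needed to rebuild the object
        return children, aux_data

    @classmethod
    def tree_unflatten(
        cls, aux_data: Tuple[int, ...], children: List[Any]
    ) -> "ToyJaxModule":
        obj = cls(aux_data)        # rebuild from the shape only
        obj.tau, obj.x = children  # then restore the leaves
        return obj


def map_leaves(fn: Callable[[float], float], mod: ToyJaxModule) -> ToyJaxModule:
    """Hypothetical helper: flatten, apply fn to every leaf value, unflatten."""
    children, aux_data = mod.tree_flatten()
    new_children = [[fn(v) for v in leaf] for leaf in children]
    return type(mod).tree_unflatten(aux_data, new_children)


mod = ToyJaxModule((3,))
mod.x = [0.1, 0.2, 0.3]
doubled = map_leaves(lambda v: 2.0 * v, mod)
```

`map_leaves` loosely mimics what a transformation does with a pytree: take it apart, work on the leaves, and rebuild a new object of the same class, leaving the original untouched.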
If you write a JaxModule subclass, it will automatically be registered with Jax as a pytree. You shouldn't need to override tree_flatten() or tree_unflatten() in your modules.
Flattening and unflattening require that your __init__() method be callable with only a shape as input, which should be sufficient to specify the network architecture of your module and all submodules.
If that isn't the case, then you may need to override tree_flatten() and tree_unflatten().
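If __init__() needs more than a shape, one option is to carry the extra constructor arguments in the auxiliary data so that unflattening can still rebuild the object. The sketch below uses the same stdlib-only toy convention as a class registered with `jax.tree_util.register_pytree_node_class`; the class `ToyWithActivation` and its `activation` argument are hypothetical, and whether a Rockpool module should override its methods in exactly this way is an assumption to check against the Rockpool source.

```python
from typing import List, Tuple

Aux = Tuple[Tuple[int, ...], str]


class ToyWithActivation:
    """Hypothetical module whose constructor needs more than a shape."""

    def __init__(self, shape: Tuple[int, ...], activation: str) -> None:
        # - `activation` is required, so a shape-only rebuild would fail
        self.shape = shape
        self.activation = activation
        self.w = [1.0] * shape[-1]

    def tree_flatten(self) -> Tuple[List[List[float]], Aux]:
        # - Put every constructor argument into the static aux data
        return [self.w], (self.shape, self.activation)

    @classmethod
    def tree_unflatten(cls, aux_data: Aux, children: List[List[float]]) -> "ToyWithActivation":
        shape, activation = aux_data
        obj = cls(shape, activation)
        (obj.w,) = children
        return obj


mod = ToyWithActivation((2,), activation="tanh")
mod.w = [0.5, -0.5]
children, aux = mod.tree_flatten()
rebuilt = ToyWithActivation.tree_unflatten(aux, children)
```

Keeping the auxiliary data as a tuple of plain, hashable values is a deliberate choice here, since Jax compares auxiliary data when matching tree structures.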