This page was generated from /builds/synsense/rockpool/docs/in-depth/api-functional.ipynb.

[𝝺] Low-level functional API

Rockpool Modules and the JaxModule base class support a functional form for manipulating parameters and for evolution. This is particularly important when using Jax, since that library requires a functional programming style: functions return their results (including any updated state) rather than modifying objects in place.
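Here is a minimal sketch of what "functional" means in this context. It uses plain NumPy and is not Rockpool code. The function name `evolve`, the toy leaky-integrator dynamics and the `tau`/`dt` values are all assumptions chosen for illustration. A pure evolution function takes parameters, state and input, and returns the output together with a *new* state. Nothing is modified in place:

```python
import numpy as np


def evolve(params: dict, state: dict, inputs: np.ndarray):
    """Hypothetical toy leaky rate integrator written as a pure function.

    Not Rockpool's RateJax dynamics; it only illustrates the calling style.
    Nothing is modified in place: the final state is returned to the caller.
    """
    dt, tau = params["dt"], params["tau"]
    x = state["x"]
    outputs = []
    for u in inputs:  # one time step per row of `inputs`
        x = x + (dt / tau) * (-x + u)  # builds a new array; `state["x"]` is untouched
        outputs.append(np.maximum(x, 0.0))  # rectified output rate
    return np.stack(outputs), {"x": x}


params = {"tau": np.full(3, 20e-3), "dt": 1e-3}
state = {"x": np.zeros(3)}
inputs = np.random.default_rng(0).random((10, 3))

outputs, new_state = evolve(params, state, inputs)
```

After the call, `state["x"]` is still all zeros. The caller must pass `new_state` into the next call, much like the `(output, new_state, record)` contract shown in the cells below.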

Functional evolution

First let’s set up a module to play with:

[1]:
# - Switch off warnings
import warnings

warnings.filterwarnings("ignore")

# - Rockpool imports
from rockpool.nn.modules import RateJax

# - Other useful imports
import numpy as np

try:
    from rich import print
except ImportError:
    pass

# - Construct a module
N = 3
mod = RateJax(N)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Could not import package: No module named 'iaf_nest'

Now if we evolve the module, we get the outputs we expect:

[2]:
# - Set up some input
T = 10
input = np.random.rand(T, N)
output, new_state, record = mod(input)
[3]:
print("output:", output)
output: [[[0.02114696 0.00989053 0.03182354]
  [0.04654327 0.03433599 0.06666339]
  [0.08230176 0.05535971 0.07213248]
  [0.12332115 0.05678442 0.08073001]
  [0.13181739 0.09974151 0.07767531]
  [0.13560757 0.09651617 0.09148949]
  [0.1767954  0.10431428 0.11117382]
  [0.21250004 0.1332448  0.12745431]
  [0.23403281 0.15806702 0.12969883]
  [0.23632057 0.19107696 0.1703079 ]]]
[4]:
print("new_state:", new_state)
new_state:
{
    'x': DeviceArray([0.23632057, 0.19107696, 0.1703079 ], dtype=float32),
    'rng_key': array([ 808257824, 3588456238], dtype=uint32)
}

So far so good. The issue with Jax is that jit-compiled modules and functions cannot have side-effects. In Rockpool, however, evolution almost always has a side-effect: it updates the internal state variables of each module.

In the case of the evolution above, we can see that the internal state was not updated during evolution:

[5]:
print("mod.state:", mod.state())
print(mod.state()["x"], " != ", new_state["x"])
mod.state:
{
    'rng_key': array([1432838885, 1403173352], dtype=uint32),
    'x': DeviceArray([0., 0., 0.], dtype=float32)
}
[0. 0. 0.]  !=  [0.23632057 0.19107696 0.1703079 ]

The correct way to resolve this is to assign new_state back to the module after each evolution:

[6]:
mod = mod.set_attributes(new_state)
print(mod.state()["x"], " == ", new_state["x"])
[0.23632057 0.19107696 0.1703079 ]  ==  [0.23632057 0.19107696 0.1703079 ]
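The sketch below shows why this reassignment matters when input is processed in chunks. It uses a hypothetical NumPy stand-in, `ToyRateModule`, that only mimics the calling convention shown above: `mod(input)` returns `(output, new_state, record)`, and `set_attributes()` returns a modified copy. Its dynamics and internals are assumptions, not Rockpool's implementation. Evolving in two chunks, with the returned state assigned after each one, gives the same result as a single long evolution. Forgetting to reassign silently restarts from the initial state:

```python
import copy

import numpy as np


class ToyRateModule:
    """Hypothetical stand-in mimicking the Rockpool calling convention."""

    def __init__(self, n, tau=20e-3, dt=1e-3):
        self.tau = np.full(n, tau)
        self.dt = dt
        self.x = np.zeros(n)

    def state(self):
        return {"x": self.x}

    def set_attributes(self, new_attributes):
        # Functional update: return a modified copy and leave `self` untouched
        new_mod = copy.deepcopy(self)
        for name, value in new_attributes.items():
            setattr(new_mod, name, value)
        return new_mod

    def __call__(self, inputs):
        # Evolution without side-effects: the final state is only returned
        x = self.x
        outputs = []
        for u in inputs:
            x = x + (self.dt / self.tau) * (-x + u)
            outputs.append(np.maximum(x, 0.0))
        return np.stack(outputs), {"x": x}, {}


inputs = np.random.default_rng(0).random((20, 3))

# Evolve in two chunks, assigning the returned state after each chunk
mod = ToyRateModule(3)
out_a, new_state, _ = mod(inputs[:10])
mod = mod.set_attributes(new_state)
out_b, new_state, _ = mod(inputs[10:])
mod = mod.set_attributes(new_state)

# Evolve a fresh module over the whole input in a single call
out_full, full_state, _ = ToyRateModule(3)(inputs)

# Without reassignment, the second chunk restarts from the zero state
stale = ToyRateModule(3)
stale(inputs[:10])
out_b_stale, _, _ = stale(inputs[10:])
```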

You will have noticed the functional form of the call to set_attributes() above. This is addressed in the next section.

Functional state and attribute setting

Direct attribute assignment works at the top level, using standard Python syntax:

[7]:
new_tau = mod.tau * 0.4
mod.tau = new_tau
print(new_tau, " == ", mod.tau)
[0.008 0.008 0.008]  ==  [0.008 0.008 0.008]

A functional form is also supported, via the set_attributes() method. This returns a copy of the module (and its submodules) with the updated attributes, which then replaces the “old” module:

[8]:
params = mod.parameters()
params["tau"] = params["tau"] * 3.0

# - Note the functional calling style
mod = mod.set_attributes(params)

# - check that the attribute was set
print(params["tau"], " == ", mod.tau)
[0.024 0.024 0.024]  ==  [0.024 0.024 0.024]
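Because set_attributes() returns a copy, the old module object is left as it was. The sketch below illustrates this with a hypothetical `ToyModule`. Its use of `copy.deepcopy` is an assumption about how such a copy could be made, not a description of Rockpool's internals. Any reference still held to the old module keeps the old parameter values:

```python
import copy

import numpy as np


class ToyModule:
    """Hypothetical module with a functional `set_attributes()`, for illustration only."""

    def __init__(self, n):
        self.tau = np.full(n, 20e-3)

    def parameters(self):
        return {"tau": self.tau}

    def set_attributes(self, new_attributes):
        new_mod = copy.deepcopy(self)  # the original module is not modified
        for name, value in new_attributes.items():
            setattr(new_mod, name, value)
        return new_mod


old_mod = ToyModule(3)
params = old_mod.parameters()
params["tau"] = params["tau"] * 3.0  # rebinds the dict entry; `old_mod.tau` is unchanged
new_mod = old_mod.set_attributes(params)
```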

Functional module reset

Resetting the module state and parameters must also be done using the functional form:

[9]:
# - Reset the module state
mod = mod.reset_state()

# - Reset the module parameters
mod = mod.reset_parameters()
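The sketch below shows why the reassignment `mod = ...` is needed. It uses a hypothetical `ToyModule` whose reset methods return a new object, so calling them without keeping the result has no effect. The stored initial values and the use of `copy.deepcopy` are assumptions made for illustration:

```python
import copy

import numpy as np


class ToyModule:
    """Hypothetical module with functional reset methods, for illustration only."""

    def __init__(self, n, tau=20e-3):
        self._init_tau = np.full(n, tau)  # remembered so parameters can be reset
        self.tau = self._init_tau.copy()
        self.x = np.zeros(n)

    def reset_state(self):
        new_mod = copy.deepcopy(self)
        new_mod.x = np.zeros_like(self.x)
        return new_mod

    def reset_parameters(self):
        new_mod = copy.deepcopy(self)
        new_mod.tau = self._init_tau.copy()
        return new_mod


mod = ToyModule(3)
mod.x = np.array([0.2, 0.1, 0.3])  # pretend the module has been evolved
mod.tau = mod.tau * 3.0            # and its parameters changed

mod.reset_state()                  # result discarded: `mod.x` is unchanged
x_before = mod.x.copy()

mod = mod.reset_state()            # correct functional usage
mod = mod.reset_parameters()
```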

Jax flattening

JaxModule provides the methods tree_flatten() and tree_unflatten(), which are required to serialise and deserialise modules for Jax compilation and execution.
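To make the flatten/unflatten pair concrete, here is a toy class written by hand in plain NumPy. It is hypothetical and is not how JaxModule implements these methods. It follows the protocol Jax uses for pytree classes. `tree_flatten()` returns `(children, aux_data)`, where the children are the array leaves and `aux_data` is static, hashable information. The classmethod `tree_unflatten(aux_data, children)` rebuilds an instance from these. In this sketch the rebuild calls `cls(shape)`, which is why a shape-only constructor is convenient:

```python
import numpy as np


class ToyPytreeModule:
    """Hypothetical module implementing the pytree-class flatten/unflatten protocol."""

    def __init__(self, shape):
        # Only the shape is needed to build the architecture
        self.shape = tuple(shape)
        self.w = np.zeros(self.shape)
        self.tau = np.full(self.shape[-1], 20e-3)

    def tree_flatten(self):
        children = (self.w, self.tau)  # array leaves that Jax would trace
        aux_data = self.shape          # static, hashable data needed to rebuild
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = cls(aux_data)            # works because __init__ needs only a shape
        obj.w, obj.tau = children      # restore the leaf values
        return obj


mod = ToyPytreeModule((2, 3))
mod.w = np.ones((2, 3))
children, aux_data = mod.tree_flatten()
rebuilt = ToyPytreeModule.tree_unflatten(aux_data, children)
```

With Jax installed, a hand-written class like this would be registered by decorating it with `jax.tree_util.register_pytree_node_class`. For JaxModule subclasses, Rockpool handles registration itself.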

If you write a JaxModule subclass, it will be automatically registered with Jax as a pytree. You shouldn’t need to override tree_flatten() or tree_unflatten() in your modules.

Flattening and unflattening require that your __init__() method be callable with only a shape as input, and that this shape be sufficient to specify the network architecture of your module and all its submodules.

If that isn’t the case, then you may need to override tree_flatten() and tree_unflatten().
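One possible approach for a constructor that needs more than a shape is sketched below with a hypothetical class. The extra constructor argument (here an `activation` string) is stored in `aux_data` alongside the shape, so that `tree_unflatten()` can pass it back to `__init__()`. This is an illustration of the general pytree protocol under that assumption. It is not Rockpool's documented override pattern:

```python
import numpy as np


class ToyModuleWithOption:
    """Hypothetical module whose __init__ needs more than a shape."""

    def __init__(self, shape, activation):
        self.shape = tuple(shape)
        self.activation = activation  # static configuration, e.g. "relu" or "tanh"
        self.w = np.zeros(self.shape)

    def tree_flatten(self):
        children = (self.w,)
        # Store everything __init__ needs, not just the shape
        aux_data = (self.shape, self.activation)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        shape, activation = aux_data
        obj = cls(shape, activation)
        (obj.w,) = children
        return obj


mod = ToyModuleWithOption((2, 3), activation="tanh")
mod.w = np.full((2, 3), 0.5)
rebuilt = ToyModuleWithOption.tree_unflatten(*reversed(mod.tree_flatten()))
```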