{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# [𝝺] Low-level functional API" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Rockpool :py:class:`.Module` s and the :py:class:`.JaxModule` base class support a functional form for manipulating parameters and for evolution. This is particularly important when using Jax, since this library requires a functional programming style." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Functional evolution" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "First let's set up a module to play with:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# - Switch off warnings\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# - Rockpool imports\n", "from rockpool.nn.modules import RateJax\n", "\n", "# - Other useful imports\n", "import numpy as np\n", "\n", "try:\n", " from rich import print\n", "except ImportError:\n", " pass\n", "\n", "# - Construct a module\n", "N = 3\n", "mod = RateJax(N)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now if we evolve the module, we get the outputs we expect:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# - Set up some input\n", "T = 10\n", "input = np.random.rand(T, N)\n", "output, new_state, record = mod(input)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
output: [[[0.01076346 0.02204564 0.02956902]\n",
       "  [0.04010231 0.06479559 0.07025937]\n",
       "  [0.0432203  0.06293815 0.11122491]\n",
       "  [0.06027619 0.08317991 0.1307195 ]\n",
       "  [0.05882018 0.09205481 0.15850767]\n",
       "  [0.08602361 0.12148949 0.16591538]\n",
       "  [0.10797263 0.11710763 0.18505278]\n",
       "  [0.13634151 0.13502166 0.20233528]\n",
       "  [0.15542242 0.15269144 0.20374872]\n",
       "  [0.16027772 0.15767853 0.19997193]]]\n",
       "
\n" ], "text/plain": [ "output: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.01076346\u001b[0m \u001b[1;36m0.02204564\u001b[0m \u001b[1;36m0.02956902\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.04010231\u001b[0m \u001b[1;36m0.06479559\u001b[0m \u001b[1;36m0.07025937\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.0432203\u001b[0m \u001b[1;36m0.06293815\u001b[0m \u001b[1;36m0.11122491\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.06027619\u001b[0m \u001b[1;36m0.08317991\u001b[0m \u001b[1;36m0.1307195\u001b[0m \u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.05882018\u001b[0m \u001b[1;36m0.09205481\u001b[0m \u001b[1;36m0.15850767\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.08602361\u001b[0m \u001b[1;36m0.12148949\u001b[0m \u001b[1;36m0.16591538\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.10797263\u001b[0m \u001b[1;36m0.11710763\u001b[0m \u001b[1;36m0.18505278\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.13634151\u001b[0m \u001b[1;36m0.13502166\u001b[0m \u001b[1;36m0.20233528\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.15542242\u001b[0m \u001b[1;36m0.15269144\u001b[0m \u001b[1;36m0.20374872\u001b[0m\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.16027772\u001b[0m \u001b[1;36m0.15767853\u001b[0m \u001b[1;36m0.19997193\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(\"output:\", output)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
new_state:\n",
       "{\n",
       "    'x': DeviceArray([0.16027772, 0.15767853, 0.19997193], dtype=float32),\n",
       "    'rng_key': DeviceArray([2469880657, 3700232383], dtype=uint32)\n",
       "}\n",
       "
\n" ], "text/plain": [ "new_state:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'x'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.16027772\u001b[0m, \u001b[1;36m0.15767853\u001b[0m, \u001b[1;36m0.19997193\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'rng_key'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m2469880657\u001b[0m, \u001b[1;36m3700232383\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint32\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(\"new_state:\", new_state)" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "So far so good. The issue with ``jax`` is that ``jit``-compiled modules and functions cannot have side-effects. For Rockpool, evolution almost *always* has side-effects, in terms of updating the internal state variables of each module.\n", "\n", "In the case of the evolution above, we can see that the internal state was not updated during evolution:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
mod.state:\n",
       "{\n",
       "    'rng_key': DeviceArray([ 237268104, 2681681569], dtype=uint32),\n",
       "    'x': DeviceArray([0., 0., 0.], dtype=float32)\n",
       "}\n",
       "
\n" ], "text/plain": [ "mod.state:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'rng_key'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m237268104\u001b[0m, \u001b[1;36m2681681569\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'x'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
[0. 0. 0.]  !=  [0.16027772 0.15767853 0.19997193]\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m != \u001b[1m[\u001b[0m\u001b[1;36m0.16027772\u001b[0m \u001b[1;36m0.15767853\u001b[0m \u001b[1;36m0.19997193\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(\"mod.state:\", mod.state())\n", "print(mod.state()[\"x\"], \" != \", new_state[\"x\"])" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The correct resolution to this is to assign ``new_state`` to the module after each evolution:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[0.16027772 0.15767853 0.19997193]  ==  [0.16027772 0.15767853 0.19997193]\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[1;36m0.16027772\u001b[0m \u001b[1;36m0.15767853\u001b[0m \u001b[1;36m0.19997193\u001b[0m\u001b[1m]\u001b[0m == \u001b[1m[\u001b[0m\u001b[1;36m0.16027772\u001b[0m \u001b[1;36m0.15767853\u001b[0m \u001b[1;36m0.19997193\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "mod = mod.set_attributes(new_state)\n", "print(mod.state()[\"x\"], \" == \", new_state[\"x\"])" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "You will have noticed the functional form of the call to :py:meth:`~.JaxModule.set_attributes` above. This is addressed in the next section." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Functional state and attribute setting" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Direct attribute assignment works at the top level, using standard Python syntax:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[0.008 0.008 0.008]  ==  [0.008 0.008 0.008]\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[1;36m0.008\u001b[0m \u001b[1;36m0.008\u001b[0m \u001b[1;36m0.008\u001b[0m\u001b[1m]\u001b[0m == \u001b[1m[\u001b[0m\u001b[1;36m0.008\u001b[0m \u001b[1;36m0.008\u001b[0m \u001b[1;36m0.008\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "new_tau = mod.tau * 0.4\n", "mod.tau = new_tau\n", "print(new_tau, \" == \", mod.tau)" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "A functional form is also supported, via the :py:meth:`~.JaxModule.set_attributes` method. This returns a copy of the module (and its submodules) with the updated attributes, which you assign in place of the \"old\" module:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[0.024 0.024 0.024]  ==  [0.024 0.024 0.024]\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[1;36m0.024\u001b[0m \u001b[1;36m0.024\u001b[0m \u001b[1;36m0.024\u001b[0m\u001b[1m]\u001b[0m == \u001b[1m[\u001b[0m\u001b[1;36m0.024\u001b[0m \u001b[1;36m0.024\u001b[0m \u001b[1;36m0.024\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "params = mod.parameters()\n", "params[\"tau\"] = params[\"tau\"] * 3.0\n", "\n", "# - Note the functional calling style\n", "mod = mod.set_attributes(params)\n", "\n", "# - check that the attribute was set\n", "print(params[\"tau\"], \" == \", mod.tau)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Functional module reset" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Resetting the module state and parameters also must be done using a functional form:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# - Reset the module state\n", "mod = mod.reset_state()\n", "\n", "# - Reset the module parameters\n", "mod = mod.reset_parameters()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Jax flattening" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ":py:class:`.JaxModule` provides the methods :py:meth:`~.JaxModule.tree_flatten` and :py:meth:`~.JaxModule.tree_unflatten`, which are required to serialise and deserialise modules for Jax compilation and execution.\n", "\n", "If you write a :py:class:`.JaxModule` subclass, it will be automatically registered with Jax as a ``pytree``. 
You shouldn't normally need to override :py:meth:`~.JaxModule.tree_flatten` or :py:meth:`~.JaxModule.tree_unflatten` in your modules.\n", "\n", "Flattening and unflattening require that your :py:meth:`~.JaxModule.__init__` method be callable with only a shape as input, which should be sufficient to specify the network architecture of your module and all submodules.\n", "\n", "If that isn't the case, then you may need to override :py:meth:`~.JaxModule.tree_flatten` and :py:meth:`~.JaxModule.tree_unflatten`." ] } ], "metadata": { "kernelspec": { "display_name": "py38", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 4 }