{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# How To: Configure and perform constrained optimization in Rockpool" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Spiking neural networks (SNNs) present a more complex optimisation problem than standard deep neural networks (DNNs).\n", "This is due not only to the complex dynamics of spiking neurons, but also to the additional classes of parameters present in SNNs.\n", "\n", "DNNs usually optimise only linear weights and biases, which share a common scale and can take any finite value.\n", "SNNs, on the other hand, contain time-constant parameters in several formulations, each of which is only valid within a constrained range.\n", "For example, synaptic and membrane time constants (:math:`\\tau`) must be positive.\n", "Decay formulations of the synaptic and membrane time constants must lie in the open interval ``(0, 1)``.\n", "Firing thresholds are usually also strictly positive.\n", "\n", "To handle this, Rockpool provides convenient ways to access individual classes of parameters in a complex network via the :py:meth:`.Module.parameters` interface, as well as a set of tools for easily configuring and imposing boundary constraints during optimisation.\n", "\n", "This How To guide shows you how to use the :py:mod:`.training.torch_loss` and :py:mod:`.training.jax_loss` packages, together with the :py:mod:`.utilities.tree_utils` mini-library, to set up constrained optimisation problems." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# - Make sure additional required packages are installed\n", "import sys\n", "!{sys.executable} -m pip install --quiet rich torch jax optax\n", "\n", "from rich import print\n", "\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.figsize'] = [12, 4]\n", "plt.rcParams['figure.dpi'] = 300\n", "\n", "import numpy as np\n", "import torch, jax, optax" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## `torch` interface for constrained optimization" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Rockpool supports both ``torch`` and ``jax`` optimisation backends, with a common API for setting up constrained optimisation.\n", "Here, we demonstrate the ``torch`` interface to set parameter constraints for a single LIF module." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Module: LIFTorch  with shape (1, 1)\n",
                            "
\n" ], "text/plain": [ "Module: LIFTorch with shape \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Parameters:\n",
                            "{\n",
                            "    'tau_mem': Parameter containing:\n",
                            "tensor([0.0200], requires_grad=True),\n",
                            "    'tau_syn': Parameter containing:\n",
                            "tensor([[0.0200]], requires_grad=True),\n",
                            "    'bias': Parameter containing:\n",
                            "tensor([0.], requires_grad=True),\n",
                            "    'threshold': Parameter containing:\n",
                            "tensor([1.], requires_grad=True)\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'tau_mem'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'tau_syn'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'threshold'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Import the LIF module we will use\n", "from rockpool.nn.modules import LIFTorch\n", "\n", "# - Create a single LIF module\n", "net = LIFTorch(1)\n", "print('Module:', net)\n", "print('Parameters:', net.parameters())" ] }, { "attachments": {}, "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Even this single spiking neuron has four parameters: two time constants for synapse and membrane (``tau_syn`` and ``tau_mem``), which must be positive; a ``bias`` parameter, which can take any value; and a ``threshold`` parameter, which should also be positive.\n", "\n", "Suppose either time constant becomes negative during training. 
In that case, the dynamics of the module become undefined and most likely unstable, causing both the network dynamics and training to break down.\n", "\n", "Rockpool provides a cost function :py:func:`~.training.torch_loss.bounds_cost`, which imposes bound constraints on parameters by penalising values that fall outside chosen lower and upper bounds.\n", "We also provide a helper function :py:func:`~.training.torch_loss.make_bounds`, which helps you specify which parameters should be constrained, and how.\n", "\n", "Below we show how the cost function behaves as a parameter approaches and then violates the bounds ``(0, 1)``." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# - Import the `make_bounds` and `bounds_cost` helper functions\n", "from rockpool.training.torch_loss import make_bounds, bounds_cost\n", "\n", "# - Evaluate the cost for a single parameter `x`, bounded to (0, 1)\n", "xs = np.linspace(-1, 2, 1001)\n", "cost = [bounds_cost({'x': torch.tensor(x)}, {'x': 0.}, {'x': 1.}) for x in xs]\n", "\n", "# - Plot the cost, marking the lower and upper bounds with red dotted lines\n", "plt.figure()\n", "plt.plot(xs, cost)\n", "plt.plot([0, 0], [0, 3], 'r:')\n", "plt.plot([1, 1], [0, 3], 'r:')\n", "plt.xlabel('Parameter value')\n", "plt.ylabel('Cost');" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "When no bounds are violated, the cost evaluates to zero.\n", "At the bounds, a cost of 1 is imposed, and the cost grows as the violation increases.\n", "\n", "Now let's see how to create and apply bounds to the parameters of a LIF module.\n", "\n", "The :py:func:`.training.torch_loss.make_bounds` function takes the parameters of a Rockpool network and generates lower- and upper-bound configuration dictionaries.\n", "These dictionaries mirror the nested structure of the network parameters." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{'tau_mem': -inf, 'tau_syn': -inf, 'bias': -inf, 'threshold': -inf}\n",
                            "{'tau_mem': inf, 'tau_syn': inf, 'bias': inf, 'threshold': inf}\n",
                            "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\u001b[32m'tau_mem'\u001b[0m: -inf, \u001b[32m'tau_syn'\u001b[0m: -inf, \u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m\n", "\u001b[1m{\u001b[0m\u001b[32m'tau_mem'\u001b[0m: inf, \u001b[32m'tau_syn'\u001b[0m: inf, \u001b[32m'bias'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Call `make_bounds` on the parameters of the module\n", "lb, ub = make_bounds(net.parameters())\n", "print(lb, ub)" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "By default, no parameters are constrained: the lower and upper bounds are set to negative and positive infinity, respectively.\n", "To impose a constraint, replace these entries with finite lower and upper bounds.\n", "Let's use (0 ms, 200 ms) as the constraints for ``tau_syn``." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{'tau_mem': -inf, 'tau_syn': 0.0, 'bias': -inf, 'threshold': -inf}\n",
                            "{'tau_mem': inf, 'tau_syn': 0.2, 'bias': inf, 'threshold': inf}\n",
                            "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\u001b[32m'tau_mem'\u001b[0m: -inf, \u001b[32m'tau_syn'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m\n", "\u001b[1m{\u001b[0m\u001b[32m'tau_mem'\u001b[0m: inf, \u001b[32m'tau_syn'\u001b[0m: \u001b[1;36m0.2\u001b[0m, \u001b[32m'bias'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lb['tau_syn'] = 0.\n", "ub['tau_syn'] = 200e-3\n", "print(lb, ub)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
tensor(0., grad_fn=<SumBackward0>)\n",
                            "
\n" ], "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mSumBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Evaluate the boundary constraint cost\n", "print(bounds_cost(net.parameters(), lb, ub))" ] }, { "attachments": {}, "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "In this case, no bounds are violated, so the cost is zero.\n", "\n", "Now let's look at a more complex network, with several layers and nested modules.\n", "We'll define the network to use two different leak-parameter formulations (``decays`` and ``taus``), each of which needs its own constraints." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Network: TorchSequential  with shape (2, 5) {\n",
                            "    LinearTorch '0_LinearTorch' with shape (2, 3)\n",
                            "    LIFTorch '1_LIFTorch' with shape (3, 3)\n",
                            "    TorchResidual '2_TorchResidual' with shape (3, 3) {\n",
                            "        LinearTorch '0_LinearTorch' with shape (3, 3)\n",
                            "        LIFTorch '1_LIFTorch' with shape (3, 3)\n",
                            "    }\n",
                            "    LinearTorch '3_LinearTorch' with shape (3, 5)\n",
                            "    LIFTorch '4_LIFTorch' with shape (5, 5)\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "Network: TorchSequential with shape \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m5\u001b[0m\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " LinearTorch \u001b[32m'0_LinearTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", " LIFTorch \u001b[32m'1_LIFTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", " TorchResidual \u001b[32m'2_TorchResidual'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " LinearTorch \u001b[32m'0_LinearTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", " LIFTorch \u001b[32m'1_LIFTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m\n", " LinearTorch \u001b[32m'3_LinearTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m5\u001b[0m\u001b[1m)\u001b[0m\n", " LIFTorch \u001b[32m'4_LIFTorch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m5\u001b[0m, \u001b[1;36m5\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Parameters:\n",
                            "{\n",
                            "    '0_LinearTorch': {\n",
                            "        'weight': Parameter containing:\n",
                            "tensor([[-1.7288,  0.3565,  0.7264],\n",
                            "        [-0.9882, -1.2147,  1.2812]], requires_grad=True)\n",
                            "    },\n",
                            "    '1_LIFTorch': {\n",
                            "        'alpha': Parameter containing:\n",
                            "tensor([0.5000, 0.5000, 0.5000], requires_grad=True),\n",
                            "        'beta': Parameter containing:\n",
                            "tensor([[0.5000],\n",
                            "        [0.5000],\n",
                            "        [0.5000]], requires_grad=True),\n",
                            "        'bias': Parameter containing:\n",
                            "tensor([0., 0., 0.], requires_grad=True),\n",
                            "        'threshold': Parameter containing:\n",
                            "tensor([1., 1., 1.], requires_grad=True)\n",
                            "    },\n",
                            "    '2_TorchResidual': {\n",
                            "        '0_LinearTorch': {\n",
                            "            'weight': Parameter containing:\n",
                            "tensor([[ 0.7141, -1.3781, -0.7695],\n",
                            "        [-0.8757, -0.6188, -0.4058],\n",
                            "        [-0.8914,  0.4774,  0.1480]], requires_grad=True)\n",
                            "        },\n",
                            "        '1_LIFTorch': {\n",
                            "            'alpha': Parameter containing:\n",
                            "tensor([0.5000, 0.5000, 0.5000], requires_grad=True),\n",
                            "            'beta': Parameter containing:\n",
                            "tensor([[0.5000],\n",
                            "        [0.5000],\n",
                            "        [0.5000]], requires_grad=True),\n",
                            "            'bias': Parameter containing:\n",
                            "tensor([0., 0., 0.], requires_grad=True),\n",
                            "            'threshold': Parameter containing:\n",
                            "tensor([1., 1., 1.], requires_grad=True)\n",
                            "        }\n",
                            "    },\n",
                            "    '3_LinearTorch': {\n",
                            "        'weight': Parameter containing:\n",
                            "tensor([[ 0.4936,  1.3029, -1.0018,  1.1015, -1.4031],\n",
                            "        [ 0.1428,  1.2207,  1.1742,  0.8693, -0.3616],\n",
                            "        [ 0.1114, -1.3645, -1.3504, -0.6766, -0.4882]], requires_grad=True)\n",
                            "    },\n",
                            "    '4_LIFTorch': {\n",
                            "        'tau_mem': Parameter containing:\n",
                            "tensor([0.0200, 0.0200, 0.0200, 0.0200, 0.0200], requires_grad=True),\n",
                            "        'tau_syn': Parameter containing:\n",
                            "tensor([[0.0200],\n",
                            "        [0.0200],\n",
                            "        [0.0200],\n",
                            "        [0.0200],\n",
                            "        [0.0200]], requires_grad=True),\n",
                            "        'bias': Parameter containing:\n",
                            "tensor([0., 0., 0., 0., 0.], requires_grad=True),\n",
                            "        'threshold': Parameter containing:\n",
                            "tensor([1., 1., 1., 1., 1.], requires_grad=True)\n",
                            "    }\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-1.7288\u001b[0m, \u001b[1;36m0.3565\u001b[0m, \u001b[1;36m0.7264\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.9882\u001b[0m, \u001b[1;36m-1.2147\u001b[0m, \u001b[1;36m1.2812\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'alpha'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'beta'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'threshold'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m, 
\u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'2_TorchResidual'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.7141\u001b[0m, \u001b[1;36m-1.3781\u001b[0m, \u001b[1;36m-0.7695\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.8757\u001b[0m, \u001b[1;36m-0.6188\u001b[0m, \u001b[1;36m-0.4058\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.8914\u001b[0m, \u001b[1;36m0.4774\u001b[0m, \u001b[1;36m0.1480\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'alpha'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'beta'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " 
\u001b[32m'threshold'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.4936\u001b[0m, \u001b[1;36m1.3029\u001b[0m, \u001b[1;36m-1.0018\u001b[0m, \u001b[1;36m1.1015\u001b[0m, \u001b[1;36m-1.4031\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1428\u001b[0m, \u001b[1;36m1.2207\u001b[0m, \u001b[1;36m1.1742\u001b[0m, \u001b[1;36m0.8693\u001b[0m, \u001b[1;36m-0.3616\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1114\u001b[0m, \u001b[1;36m-1.3645\u001b[0m, \u001b[1;36m-1.3504\u001b[0m, \u001b[1;36m-0.6766\u001b[0m, \u001b[1;36m-0.4882\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'4_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'tau_mem'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m, \u001b[1;36m0.0200\u001b[0m, \u001b[1;36m0.0200\u001b[0m, \u001b[1;36m0.0200\u001b[0m, \u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'tau_syn'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m,\n", " 
\u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.0200\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'threshold'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from rockpool.nn.modules import LinearTorch\n", "from rockpool.nn.combinators import Sequential, Residual\n", "\n", "net = Sequential(\n", " LinearTorch((2, 3)),\n", " LIFTorch(3, leak_mode=\"decays\"),\n", "\n", " Residual(\n", " LinearTorch((3, 3)),\n", " LIFTorch(3, leak_mode=\"decays\"),\n", " ),\n", "\n", " LinearTorch((3, 5)),\n", " LIFTorch(5, leak_mode=\"taus\"),\n", ")\n", "print('Network:', net)\n", "print('Parameters:', net.parameters())" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "This is a deeply nested network with a complex set of parameters.\n", "Luckily, Rockpool provides several convenient tools that make it easy to build constraints even for complex networks.\n", "\n", "The :py:meth:`.Module.parameters` method allows you to easily extract families of parameters, helping you identify 
all time constants, for example.\n", "The :py:meth:`.Module.attributes_named` method lets you select attributes by name, for example every ``tau_syn`` parameter in the network.\n", "\n", "The mini-library :py:mod:`~.rockpool.utilities.tree_utils` helps you manipulate the nested parameter and constraint dictionaries to set chosen bounds.\n", "Here we'll use the :py:func:`.tree_utils.set_matching` function to set bounds for the entries that match a chosen subset of parameters." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{\n",
                            "    '0_LinearTorch': {'weight': -inf},\n",
                            "    '1_LIFTorch': {'alpha': 0.5, 'beta': 0.5, 'bias': -inf, 'threshold': -inf},\n",
                            "    '2_TorchResidual': {\n",
                            "        '0_LinearTorch': {'weight': -inf},\n",
                            "        '1_LIFTorch': {'alpha': 0.5, 'beta': 0.5, 'bias': -inf, 'threshold': -inf}\n",
                            "    },\n",
                            "    '3_LinearTorch': {'weight': -inf},\n",
                            "    '4_LIFTorch': {'tau_mem': 0.0, 'tau_syn': 0.0, 'bias': -inf, 'threshold': -inf}\n",
                            "}\n",
                            "{\n",
                            "    '0_LinearTorch': {'weight': inf},\n",
                            "    '1_LIFTorch': {'alpha': 1.0, 'beta': 1.0, 'bias': inf, 'threshold': inf},\n",
                            "    '2_TorchResidual': {\n",
                            "        '0_LinearTorch': {'weight': inf},\n",
                            "        '1_LIFTorch': {'alpha': 1.0, 'beta': 1.0, 'bias': inf, 'threshold': inf}\n",
                            "    },\n",
                            "    '3_LinearTorch': {'weight': inf},\n",
                            "    '4_LIFTorch': {'tau_mem': inf, 'tau_syn': 0.5, 'bias': inf, 'threshold': inf}\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'alpha'\u001b[0m: \u001b[1;36m0.5\u001b[0m, \u001b[32m'beta'\u001b[0m: \u001b[1;36m0.5\u001b[0m, \u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'2_TorchResidual'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'alpha'\u001b[0m: \u001b[1;36m0.5\u001b[0m, \u001b[32m'beta'\u001b[0m: \u001b[1;36m0.5\u001b[0m, \u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'4_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'tau_mem'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'tau_syn'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'alpha'\u001b[0m: \u001b[1;36m1.0\u001b[0m, \u001b[32m'beta'\u001b[0m: \u001b[1;36m1.0\u001b[0m, \u001b[32m'bias'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'2_TorchResidual'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'alpha'\u001b[0m: \u001b[1;36m1.0\u001b[0m, \u001b[32m'beta'\u001b[0m: 
\u001b[1;36m1.0\u001b[0m, \u001b[32m'bias'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LinearTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'4_LIFTorch'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'tau_mem'\u001b[0m: inf, \u001b[32m'tau_syn'\u001b[0m: \u001b[1;36m0.5\u001b[0m, \u001b[32m'bias'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Import the tree utilities library\n", "import rockpool.utilities.tree_utils as tu\n", "\n", "# - Make template lower and upper bounds\n", "lb, ub = make_bounds(net.parameters())\n", "\n", "# - Set lower bounds on \"decays\" and \"taus\" family parameters\n", "lb = tu.set_matching(lb, net.parameters('decays'), 0.5)\n", "lb = tu.set_matching(lb, net.parameters('taus'), 0.)\n", "\n", "# - Set upper bounds on \"decays\" family parameters\n", "ub = tu.set_matching(ub, net.parameters('decays'), 1.)\n", "\n", "# - Set an upper bound on a specific parameter name\n", "ub = tu.set_matching(ub, net.attributes_named('tau_syn'), 500e-3)\n", "\n", "print(lb, ub)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
tensor(0., grad_fn=<SumBackward0>)\n",
                            "
\n" ], "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mSumBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Evaluate the boundary constraint cost for the full set of network parameters\n", "print(bounds_cost(net.parameters(), lb, ub))" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Defining and evaluating boundary losses for constrained optimisation is straightforward, even for complex networks.\n", "To impose the constraints during training, add :py:func:`.torch_loss.bounds_cost` as an extra term in the loss function, as in the example below." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from torch.optim import Adam\n", "from torch.nn import MSELoss\n", "\n", "# - Initialise the optimiser over all network parameters\n", "optimizer = Adam(net.parameters().astorch(), lr=1e-3)\n", "func_loss = MSELoss()\n", "\n", "# - Dummy dataset of (input, target) pairs, each shaped (batch, time, channels)\n", "dataset = [(torch.rand(1, 1, 2), torch.rand(1, 1, 5))]\n", "\n", "# - Optimiser loop over dataset\n", "for inputs, target in dataset:\n", "    optimizer.zero_grad()\n", "    output, _, _ = net(inputs)\n", "\n", "    # - Evaluate the functional loss plus the boundary constraint loss\n", "    loss = func_loss(output, target) + bounds_cost(net.parameters(), lb, ub)\n", "\n", "    # - Perform the backward step\n", "    loss.backward()\n", "    optimizer.step()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## `jax` interface for constrained optimization" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The ``jax`` interface for constrained optimisation mirrors the ``torch`` interface.\n", "Here we set up the same kind of constrained 
optimisation problem as above." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# - Import the Rockpool NN modules\n", "from rockpool.nn.modules import LIFJax, LinearJax\n", "from rockpool.nn.combinators import Sequential\n", "\n", "# - Import tools from ``jax_loss`` instead of ``torch_loss``\n", "from rockpool.training.jax_loss import make_bounds, bounds_cost\n", "\n", "# - Import the tree utility package\n", "from rockpool.utilities import tree_utils as tu" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Network: JaxSequential  with shape (2, 5) {\n",
                            "    LinearJax '0_LinearJax' with shape (2, 3)\n",
                            "    LIFJax '1_LIFJax' with shape (3, 3)\n",
                            "    LinearJax '2_LinearJax' with shape (3, 5)\n",
                            "    LIFJax '3_LIFJax' with shape (5, 5)\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "Network: JaxSequential with shape \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m5\u001b[0m\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " LinearJax \u001b[32m'0_LinearJax'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", " LIFJax \u001b[32m'1_LIFJax'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", " LinearJax \u001b[32m'2_LinearJax'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m5\u001b[0m\u001b[1m)\u001b[0m\n", " LIFJax \u001b[32m'3_LIFJax'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m5\u001b[0m, \u001b[1;36m5\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Parameters:\n",
                            "{\n",
                            "    '0_LinearJax': {\n",
                            "        'weight': array([[ 1.45853284, -1.09957019, -0.31666267],\n",
                            "       [ 0.63834515,  1.47955107,  1.03358989]])\n",
                            "    },\n",
                            "    '1_LIFJax': {\n",
                            "        'tau_mem': DeviceArray([0.02, 0.02, 0.02], dtype=float32),\n",
                            "        'tau_syn': DeviceArray([[0.02],\n",
                            "             [0.02],\n",
                            "             [0.02]], dtype=float32),\n",
                            "        'bias': DeviceArray([0., 0., 0.], dtype=float32),\n",
                            "        'threshold': DeviceArray([1., 1., 1.], dtype=float32)\n",
                            "    },\n",
                            "    '2_LinearJax': {\n",
                            "        'weight': array([[ 0.3648217 ,  0.34105733,  1.23947428, -0.42756732,  0.6361447 ],\n",
                            "       [-0.19837451,  0.61290813,  1.25214626,  1.206278  ,  0.70346237],\n",
                            "       [ 0.73964399, -1.02753273, -0.28541291, -1.10618743,  0.78135608]])\n",
                            "    },\n",
                            "    '3_LIFJax': {\n",
                            "        'tau_mem': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32),\n",
                            "        'tau_syn': DeviceArray([[0.02],\n",
                            "             [0.02],\n",
                            "             [0.02],\n",
                            "             [0.02],\n",
                            "             [0.02]], dtype=float32),\n",
                            "        'bias': DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n",
                            "        'threshold': DeviceArray([1., 1., 1., 1., 1.], dtype=float32)\n",
                            "    }\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m1.45853284\u001b[0m, \u001b[1;36m-1.09957019\u001b[0m, \u001b[1;36m-0.31666267\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.63834515\u001b[0m, \u001b[1;36m1.47955107\u001b[0m, \u001b[1;36m1.03358989\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'tau_mem'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'tau_syn'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'threshold'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'2_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: 
\u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.3648217\u001b[0m , \u001b[1;36m0.34105733\u001b[0m, \u001b[1;36m1.23947428\u001b[0m, \u001b[1;36m-0.42756732\u001b[0m, \u001b[1;36m0.6361447\u001b[0m \u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.19837451\u001b[0m, \u001b[1;36m0.61290813\u001b[0m, \u001b[1;36m1.25214626\u001b[0m, \u001b[1;36m1.206278\u001b[0m , \u001b[1;36m0.70346237\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.73964399\u001b[0m, \u001b[1;36m-1.02753273\u001b[0m, \u001b[1;36m-0.28541291\u001b[0m, \u001b[1;36m-1.10618743\u001b[0m, \u001b[1;36m0.78135608\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'tau_mem'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'tau_syn'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, 
\u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'threshold'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m., \u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Set up a simple network\n", "net = Sequential(\n", "    LinearJax((2, 3)),\n", "    LIFJax(3),\n", "    LinearJax((3, 5)),\n", "    LIFJax(5),\n", ")\n", "print('Network:', net)\n", "print('Parameters:', net.parameters())" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "We again use :py:func:`.training.jax_loss.make_bounds` to build a template configuration for constrained optimisation.\n", "We then use the :py:mod:`.utilities.tree_utils` library together with :py:meth:`.Module.parameters` to set lower-bound constraints on the time constants." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{\n",
                            "    '0_LinearJax': {'weight': -inf},\n",
                            "    '1_LIFJax': {'bias': -inf, 'tau_mem': -inf, 'tau_syn': -inf, 'threshold': -inf},\n",
                            "    '2_LinearJax': {'weight': -inf},\n",
                            "    '3_LIFJax': {'bias': -inf, 'tau_mem': -inf, 'tau_syn': -inf, 'threshold': -inf}\n",
                            "}\n",
                            "{\n",
                            "    '0_LinearJax': {'weight': inf},\n",
                            "    '1_LIFJax': {'bias': inf, 'tau_mem': inf, 'tau_syn': inf, 'threshold': inf},\n",
                            "    '2_LinearJax': {'weight': inf},\n",
                            "    '3_LIFJax': {'bias': inf, 'tau_mem': inf, 'tau_syn': inf, 'threshold': inf}\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'tau_mem'\u001b[0m: -inf, \u001b[32m'tau_syn'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'2_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'tau_mem'\u001b[0m: -inf, \u001b[32m'tau_syn'\u001b[0m: -inf, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'bias'\u001b[0m: inf, \u001b[32m'tau_mem'\u001b[0m: inf, \u001b[32m'tau_syn'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'2_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'bias'\u001b[0m: inf, \u001b[32m'tau_mem'\u001b[0m: inf, \u001b[32m'tau_syn'\u001b[0m: inf, \u001b[32m'threshold'\u001b[0m: inf\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Build a template configuration\n", "lb, ub = make_bounds(net.parameters())\n", "print(lb, ub)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{\n",
                            "    '0_LinearJax': {'weight': -inf},\n",
                            "    '1_LIFJax': {'bias': -inf, 'tau_mem': 0.0, 'tau_syn': 0.0, 'threshold': -inf},\n",
                            "    '2_LinearJax': {'weight': -inf},\n",
                            "    '3_LIFJax': {'bias': -inf, 'tau_mem': 0.0, 'tau_syn': 0.0, 'threshold': -inf}\n",
                            "}\n",
                            "
\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[32m'0_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'1_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'tau_mem'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'tau_syn'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'2_LinearJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'weight'\u001b[0m: -inf\u001b[1m}\u001b[0m,\n", " \u001b[32m'3_LIFJax'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'bias'\u001b[0m: -inf, \u001b[32m'tau_mem'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'tau_syn'\u001b[0m: \u001b[1;36m0.0\u001b[0m, \u001b[32m'threshold'\u001b[0m: -inf\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Set lower bounds for time constants\n", "lb = tu.set_matching(lb, net.parameters('taus'), 0.)\n", "print(lb)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "raw_mimetype": "text/restructuredtext" }, "outputs": [ { "data": { "text/html": [ "
0.0\n",
                            "
\n" ], "text/plain": [ "\u001b[1;36m0.0\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Evaluate the boundary constraint cost\n", "print(bounds_cost(net.parameters(), lb, ub))" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "As these examples show, the Rockpool interface for setting bounds is almost identical between ``torch`` and ``jax``.\n", "Below we show a very simple ``jax`` optimisation loop that incorporates the boundary constraints during optimisation." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# - Initialise the Adam optimiser with the initial network parameters\n", "optimizer = optax.adam(1e-4)\n", "params = net.parameters()\n", "opt_state = optimizer.init(params)\n", "\n", "# - Use an MSE loss\n", "func_loss = lambda o, t: jax.numpy.mean((o - t) ** 2)\n", "\n", "# - Network evaluation and loss function\n", "def eval_loss(params, net, input, target):\n", "    # - Apply the parameters being optimised to the network, so that gradients flow through them\n", "    net = net.set_attributes(params)\n", "    output, _, _ = net(input)\n", "    loss = func_loss(output, target) + bounds_cost(params, lb, ub)\n", "\n", "    return loss\n", "\n", "# - Dummy dataset\n", "dataset = [(np.random.rand(1, 1, 2), np.random.rand(1, 1, 5))]\n", "\n", "# - Loop over dataset, evaluating loss and applying updates\n", "for input, target in dataset:\n", "    loss_value, grads = jax.value_and_grad(eval_loss)(params, net, input, target)\n", "    updates, opt_state = optimizer.update(grads, opt_state, params)\n", "    params = optax.apply_updates(params, updates)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Next steps" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "See :ref:`/in-depth/jax-training.ipynb` for a ``jax`` training example that includes constraints." 
] } ], "metadata": { "kernelspec": { "display_name": "py38", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }