{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# 🛠 Low-level `Module` API" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The low-level API in Rockpool is designed for a minimal, efficient implementation of stateful neural networks.\n", "\n", "The :py:class:`.Module` base class provides facilities for configuring, simulating and examining networks of stateful neurons." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Constructing a `Module`" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "All :py:class:`.Module` subclasses accept, at minimum, a ``shape`` argument on construction. This should fully specify the input, output and internal dimensionality of the :py:class:`.Module`, so that the code can determine how many neurons to generate and the sizes of the state variables and parameters.\n", "\n", "Some :py:class:`.Module` subclasses also let you specify the module shape implicitly, by providing concrete parameter arrays, e.g. a vector of shape ``(N,)`` as the bias parameters for a set of ``N`` neurons. These concrete values are used to initialise the :py:class:`.Module`, and if the :py:class:`.Module` is reset, the parameters return to those concrete values.\n", "\n", "Otherwise, :py:class:`.Module` subclasses set reasonable default initialisation values for their parameters." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Rate  with shape (4,)\n",
       "
\n" ], "text/plain": [ "Rate with shape \u001b[1m(\u001b[0m\u001b[1;36m4\u001b[0m,\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Switch off warnings\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# - Use `rich` for pretty printing, if it is available\n", "try:\n", "    from rich import print\n", "except ImportError:\n", "    pass\n", "\n", "# - Example of constructing a module\n", "from rockpool.nn.modules import Rate\n", "import numpy as np\n", "\n", "# - Construct a Module with 4 neurons\n", "mod = Rate(4)\n", "print(mod)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Rate  with shape (4,)\n",
       "
\n" ], "text/plain": [ "Rate with shape \u001b[1m(\u001b[0m\u001b[1;36m4\u001b[0m,\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Construct a Module with concrete parameters\n", "mod = Rate(4, tau=np.ones(4))\n", "print(mod)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Evolving a `Module`" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "You evolve the state of a :py:class:`.Module` simply by calling it. :py:class:`.Module` subclasses expect clocked, rasterised data as ``numpy`` arrays with shape ``(T, Nin)`` or ``(batches, T, Nin)``. ``batches`` is the batch size; ``T`` is the number of time steps; and ``Nin`` is the input size of the module, ``mod.size_in``.\n", "\n", "Calling a :py:class:`.Module` has the following syntax: ::\n", "\n", "    output, new_state, recorded_state = mod(input: np.array, record: bool = False)\n", "\n", "The output of the module is returned as a ``numpy`` array with shape ``(batches, T, Nout)``, even if the input had no batch dimension. Here ``Nout`` is the output size of the module, ``mod.size_out``.\n", "\n", "``new_state`` is a state dictionary containing the final state of the module and all its submodules at the end of evolution. This will become more relevant when using the functional API (see :py:ref:`/in-depth/api-functional.ipynb`).\n", "\n", "``recorded_state`` is only populated if the argument ``record=True`` is passed to the module. In that case ``recorded_state`` is a nested dictionary containing the recorded state of the module and all submodules. Each element in ``recorded_state`` has shape ``(batches, T, ...)``, where ``T`` is the number of evolution time steps and the remaining dimensions are whatever is appropriate for that state variable." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Output shape: (1, 5, 4)\n",
       "
\n" ], "text/plain": [ "Output shape: \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, \u001b[1;36m5\u001b[0m, \u001b[1;36m4\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Generate and evolve over some input\n", "T = 5\n", "input = np.random.rand(T, mod.size_in)\n", "output, _, _ = mod(input)\n", "print(f\"Output shape: {output.shape}\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters:\n",
       "{\n",
       "    'rec_input': array([[[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]]]),\n",
       "    'x': array([[[0.00447242, 0.00335577, 0.00469451, 0.00364505],\n",
       "        [0.00536816, 0.00415351, 0.00567921, 0.00439317],\n",
       "        [0.00544684, 0.00446443, 0.00618463, 0.00441146],\n",
       "        [0.00637465, 0.00499181, 0.00717   , 0.00518635],\n",
       "        [0.00711199, 0.00548246, 0.00750611, 0.00532915]]])\n",
       "}\n",
       "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'rec_input'\u001b[0m: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'x'\u001b[0m: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.00447242\u001b[0m, \u001b[1;36m0.00335577\u001b[0m, \u001b[1;36m0.00469451\u001b[0m, \u001b[1;36m0.00364505\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.00536816\u001b[0m, \u001b[1;36m0.00415351\u001b[0m, \u001b[1;36m0.00567921\u001b[0m, \u001b[1;36m0.00439317\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.00544684\u001b[0m, \u001b[1;36m0.00446443\u001b[0m, \u001b[1;36m0.00618463\u001b[0m, \u001b[1;36m0.00441146\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.00637465\u001b[0m, \u001b[1;36m0.00499181\u001b[0m, \u001b[1;36m0.00717\u001b[0m , \u001b[1;36m0.00518635\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0.00711199\u001b[0m, \u001b[1;36m0.00548246\u001b[0m, \u001b[1;36m0.00750611\u001b[0m, \u001b[1;36m0.00532915\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n", 
"\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Request the recorded state\n", "output, _, recorded_state = mod(input, record=True)\n", "print(\"Recorded state:\", recorded_state)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Parameters, State and SimulationParameters" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Rockpool defines three types of parameters for :py:class:`.Module` s: :py:class:`.Parameter`, :py:class:`.State` and :py:class:`.SimulationParameter`.\n", "\n", ":py:class:`.Parameter` s are, roughly, any values that you would consider part of the configuration of a network. If you need to tell someone else how to specify your network (without going into the details of the simulation backend), you tell them about your :py:class:`.Parameter` s. Often the set of :py:class:`.Parameter` s will be the trainable parameters of a network.\n", "\n", ":py:class:`.State` s are any internal values that need to be maintained to track how the neurons, synapses or other elements of the dynamical system of a :py:class:`.Module` evolve over time. These could include neuron membrane potentials, synaptic currents, etc.\n", "\n", ":py:class:`.SimulationParameter` s are attributes that need to be specified for simulation purposes, but which in theory should not affect the network output and behaviour. For example, the time step ``dt`` of a :py:class:`.Module` is required by a forward Euler ODE solver, but the network configuration should be valid and usable regardless of the value of ``dt``. You also shouldn't need to specify ``dt`` when telling someone else about your network configuration.\n", "\n", "One more useful wrapper class is :py:class:`.Constant`. Use it to wrap any model parameter that should not be trainable.\n", "\n", "These classes are defined in :py:mod:`.rockpool.parameters`." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Building a network with `Module`s" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "To build a complex network in Rockpool, you need to define your own :py:class:`.Module` subclass. :py:class:`.Module` takes care of many things for you, allowing you to define a network architecture without much overhead.\n", "\n", "At minimum, you need to define an :py:meth:`.Module.__init__` method, which specifies network parameters (e.g. weights) and whichever submodules your network requires. The submodules take over the job of defining their own parameters and states.\n", "\n", "You also need to define an :py:meth:`.Module.evolve` method, which contains the \"plumbing\" of your network. This method specifies how data passes into your network, between submodules, and out again.\n", "\n", "We'll build a simple feed-forward layer containing a weight matrix and a set of neurons.\n", "\n", "Note that this simple example doesn't properly return the updated module state or the recorded state." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# - Build a simple network\n", "import numpy as np\n", "from rockpool.nn.modules import Module\n", "from rockpool.parameters import Parameter\n", "from rockpool.nn.modules import RateJax\n", "\n", "\n", "class ffwd_net(Module):\n", "    # - Provide an `__init__` method to specify required parameters and modules\n", "    #   Here you check, define and initialise whatever parameters and\n", "    #   state you need for your module.\n", "    def __init__(\n", "        self,\n", "        shape,\n", "        *args,\n", "        **kwargs,\n", "    ):\n", "        # - Call superclass initialisation\n", "        #   This is always required for a `Module` class\n", "        super().__init__(shape=shape, *args, **kwargs)\n", "\n", "        # - Specify weights attribute\n", "        #   We need a weights matrix for our input weights.\n", "        #   We specify the shape explicitly, and provide an initialisation function.\n", "        #   We also specify a family for the parameter, \"weights\". This is used to\n", "        #   query parameters conveniently, and is a good idea to provide.\n", "        self.w_ffwd = Parameter(\n", "            shape=self.shape,\n", "            init_func=lambda s: np.zeros(s),\n", "            family=\"weights\",\n", "        )\n", "\n", "        # - Specify and add a submodule\n", "        #   These will be the neurons in our layer, which receive the weighted\n", "        #   input signals. This submodule is configured automatically\n", "        #   internally, specifying its own required state and parameters\n", "        self.neurons = RateJax(self.shape[-1])\n", "\n", "    # - The `evolve` method contains the internal logic of your module\n", "    #   `evolve` takes care of passing data in and out of the module,\n", "    #   and between submodules if present.\n", "    def evolve(self, input_data, *args, **kwargs):\n", "        # - Pass input data through the input weights\n", "        x = input_data @ self.w_ffwd\n", "\n", "        # - Pass the signals through the neurons\n", "        x, _, _ = self.neurons(x)\n", "\n", "        # - Return the module output\n", "        return x, {}, {}" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Writing an `evolve()` method that returns state and record" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "To adhere to the :py:class:`.Module` API, your :py:meth:`.Module.evolve` method must return the updated set of states after evolution, and must support recording internal states if requested. The example below is a replacement :py:meth:`.Module.evolve` method for the network above, illustrating a convenient way to do this." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def evolve(self, input_data, record: bool = False, *args, **kwargs):\n", "    # - Initialise state and record dictionaries\n", "    new_state = {}\n", "    recorded_state = {}\n", "\n", "    # - Pass input data through the input weights\n", "    x = input_data @ self.w_ffwd\n", "\n", "    # - Add an internal signal record to the record dictionary\n", "    if record:\n", "        recorded_state[\"weighted_input\"] = x\n", "\n", "    # - Pass the signals through the neurons, passing through the `record` argument\n", "    x, submod_state, submod_record = self.neurons(x, record=record)\n", "\n", "    # - Store the submodule state under the submodule name\n", "    new_state[\"neurons\"] = submod_state\n", "\n", "    # - Include the submodule recorded state\n", "    recorded_state[\"neurons\"] = submod_record\n", "\n", "    # - Return the module output\n", "    return x, new_state, recorded_state" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Inspecting a `Module`" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "You can examine the internal parameters and state of a :py:class:`.Module` using the inspection methods :py:meth:`~.Module.parameters`, :py:meth:`~.Module.state` and :py:meth:`~.Module.simulation_parameters`.\n", "\n", ".. code-block:: python\n", "\n", "    params: dict = mod.parameters(family=None)\n", "    state: dict = mod.state(family=None)\n", "    simulation_parameters: dict = mod.simulation_parameters(family=None)\n", "\n", "Each method accepts an optional ``family`` string, which restricts the search to attributes of that family (e.g. ``\"weights\"`` or ``\"taus\"``). In each case the method returns a nested dictionary containing all matching registered attributes for the module and all submodules." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
ffwd_net  with shape (4, 6) {\n",
       "    RateJax 'neurons' with shape (6,)\n",
       "}\n",
       "
\n" ], "text/plain": [ "ffwd_net with shape \u001b[1m(\u001b[0m\u001b[1;36m4\u001b[0m, \u001b[1;36m6\u001b[0m\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " RateJax \u001b[32m'neurons'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Build a module for our network\n", "my_mod = ffwd_net((4, 6))\n", "print(my_mod)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters:\n",
       "{\n",
       "    'w_ffwd': array([[0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.]]),\n",
       "    'neurons': {\n",
       "        'tau': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32),\n",
       "        'bias': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32),\n",
       "        'threshold': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)\n",
       "    }\n",
       "}\n",
       "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'w_ffwd'\u001b[0m: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'neurons'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'tau'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'threshold'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, 
\u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Show module parameters\n", "print(\"Parameters:\", my_mod.parameters())" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
State:\n",
       "{\n",
       "    'neurons': {\n",
       "        'rng_key': DeviceArray([1251626347,  511538859], dtype=uint32),\n",
       "        'x': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)\n",
       "    }\n",
       "}\n",
       "
\n" ], "text/plain": [ "State:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'neurons'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'rng_key'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1251626347\u001b[0m, \u001b[1;36m511538859\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint32\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'x'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Show module state\n", "print(\"State:\", my_mod.state())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Module time constants:\n",
       "{'neurons': {'tau': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32)}}\n",
       "
\n" ], "text/plain": [ "Module time constants:\n", "\u001b[1m{\u001b[0m\u001b[32m'neurons'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'tau'\u001b[0m: \u001b[1;35mDeviceArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m, \u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35mfloat32\u001b[0m\u001b[1m)\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Module weights:\n",
       "{\n",
       "    'w_ffwd': array([[0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.]])\n",
       "}\n",
       "
\n" ], "text/plain": [ "Module weights:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'w_ffwd'\u001b[0m: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Return parameters from particular families\n", "print(\"Module time constants:\", my_mod.parameters(\"taus\"))\n", "print(\"Module weights:\", my_mod.parameters(\"weights\"))" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "You can of course access all attributes of a :py:class:`.Module` directly using standard Python \"dot\" indexing syntax:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
.w_ffwd: [[0. 0. 0. 0. 0. 0.]\n",
       " [0. 0. 0. 0. 0. 0.]\n",
       " [0. 0. 0. 0. 0. 0.]\n",
       " [0. 0. 0. 0. 0. 0.]]\n",
       "
\n" ], "text/plain": [ ".w_ffwd: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m. \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
.neurons.tau: [0.02 0.02 0.02 0.02 0.02 0.02]\n",
       "
\n" ], "text/plain": [ ".neurons.tau: \u001b[1m[\u001b[0m\u001b[1;36m0.02\u001b[0m \u001b[1;36m0.02\u001b[0m \u001b[1;36m0.02\u001b[0m \u001b[1;36m0.02\u001b[0m \u001b[1;36m0.02\u001b[0m \u001b[1;36m0.02\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Access parameters directly\n", "print(\".w_ffwd:\", my_mod.w_ffwd)\n", "print(\".neurons.tau:\", my_mod.neurons.tau)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## `Module` API reference" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Every :py:class:`.Module` provides the following attributes:\n", "\n", "================================== =======================================\n", "Attribute                          Description\n", "================================== =======================================\n", ":py:attr:`~.Module.class_name`     The name of the subclass\n", ":py:attr:`~.Module.name`           The attribute name that this module was assigned to. Will be ``None`` for a base-level module\n", ":py:attr:`~.Module.full_name`      The class name and module name together. Useful for printing\n", ":py:attr:`~.Module.spiking_input`  If ``True`` this module expects spiking input. Otherwise the input is real-valued\n", ":py:attr:`~.Module.spiking_output` If ``True`` this module produces spiking output. Otherwise the module outputs floating-point values\n", ":py:attr:`~.Module.shape`          The dimensions of the module. Can have any number of entries for complex modules. ``shape[0]`` is the input dimensionality; ``shape[-1]`` is the output dimensionality.\n", ":py:attr:`~.Module.size_in`        The number of input channels the module expects\n", ":py:attr:`~.Module.size_out`       The number of output channels the module produces\n", "================================== =======================================" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Every :py:class:`.Module` provides the following methods:\n", "\n", "========================================== =======================================\n", "Method                                     Description\n", "========================================== =======================================\n", ":py:meth:`~.Module.parameters`             Return a nested dictionary of module parameters, optionally restricting the search to a particular family of parameters such as weights\n", ":py:meth:`~.Module.state`                  Return a nested dictionary of module state\n", ":py:meth:`~.Module.simulation_parameters`  Return a nested dictionary of module simulation parameters\n", ":py:meth:`~.Module.modules`                Return a list of submodules of this module\n", "\n", ":py:meth:`~.Module.attributes_named`       Search for and return nested attributes matching a particular name\n", "\n", ":py:meth:`~.Module.set_attributes`         Set the parameter values for this and nested submodules\n", "\n", ":py:meth:`~.Module.reset_state`            Reset the state of this and nested submodules\n", ":py:meth:`~.Module.reset_parameters`       Reset the parameters of this and nested submodules to their initialisation defaults\n", "\n", ":py:meth:`~.Module._auto_batch`            Utility method to assist with handling batched data\n", "\n", ":py:meth:`~.Module.timed`                  Convert this module to the high-level :py:class:`.TimedModule` API\n", "========================================== =======================================" ] } ], "metadata": { "kernelspec": { "display_name": "py38", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 4 }