{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Graph mapping" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "You can extend the computational graphing capabilities of Rockpool by adding new :py:class:`.graph.GraphModule` subclasses. These classes can be converted between each other, and the graph can be analysed in order to map networks onto hardware." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Subclassing `GraphModule`" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "All :py:class:`~.graph.GraphModule` classes are :py:class:`dataclass` es, and use the ``@dataclass`` decorator. As shown below, you must decorate your subclass with ``@dataclass(eq = False, repr = False)`` in order to be compatible with the graph mapping subsystem.\n", "\n", "The subsystem requires that equality is defined by object ID (hence ``eq = False``). It also provides a human-readable :py:meth:`~.graph.GraphModule.__repr__` method (hence ``repr = False``, to avoid using the ``dataclass`` ``__repr__()`` method)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# - Switch off warnings\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# - Rockpool imports\n", "import rockpool.graph as rg\n", "from dataclasses import dataclass\n", "\n", "\n", "@dataclass(eq=False, repr=False)\n", "class MyGraphModule(rg.GraphModule):\n", "    # - Define parameters as for any dataclass\n", "    param1: float\n", "    param2: int\n", "    param3: list" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ":py:class:`~.graph.GraphModule` provides a :py:meth:`~.graph.GraphModule.__post_init__` method that can be used to perform any validity checks after initialisation. 
:py:meth:`~.graph.GraphModule.__post_init__` also ensures that the :py:attr:`~.graph.GraphModule.input_nodes` and :py:attr:`~.graph.GraphModule.output_nodes` are correctly connected to the module being created.\n", "\n", "If you override :py:meth:`~.graph.GraphModule.__post_init__`, you *must* call ``super().__post_init__()``." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class MyGraphModule(MyGraphModule):\n", "    # - Any initialisation checks can be performed in a __post_init__ method\n", "    def __post_init__(self, *args, **kwargs):\n", "        # - You *must* call super().__post_init__()\n", "        super().__post_init__(*args, **kwargs)\n", "\n", "        # - Dataclass fields must be accessed through `self`\n", "        if self.param1 < self.param2:\n", "            raise ValueError(\"param1 must be >= param2\")" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ":py:class:`~.graph.GraphModule` provides several methods:\n", "\n", "================================================== =================\n", "Method                                             Purpose\n", "================================================== =================\n", ":py:meth:`~.graph.GraphModule._factory`            Factory method to instantiate an object with self-created input and output nodes\n", ":py:meth:`~.graph.GraphModule.__post_init__`       Perform any post-initialisation checks on the module\n", ":py:meth:`~.graph.GraphModule.add_input`           Add a :py:class:`~.graph.GraphNode` as an input of this module\n", ":py:meth:`~.graph.GraphModule.add_output`          Add a :py:class:`~.graph.GraphNode` as an output of this module\n", ":py:meth:`~.graph.GraphModule.remove_input`        Remove a :py:class:`~.graph.GraphNode` as an input of this module\n", ":py:meth:`~.graph.GraphModule.remove_output`       Remove a :py:class:`~.graph.GraphNode` as an output of this module\n", ":py:meth:`~.graph.GraphModule._convert_from`       Class method: Try to convert a different :py:class:`~.graph.GraphModule` to an object of the current subclass\n", "================================================== 
=================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transforming `GraphModule` s" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ":py:class:`~.graph.GraphModule` provides a method :py:meth:`~.graph.GraphModule._convert_from`, which is used to transform :py:class:`~.graph.GraphModule` objects between various subclasses. These conversion rules must be defined explicitly; there is no automatic conversion between classes. If you do not override :py:meth:`~.graph.GraphModule._convert_from`, then you will not be able to convert other :py:class:`~.graph.GraphModule` subclasses to objects of your class.\n", "\n", "Below is an example implementation of :py:meth:`~.graph.GraphModule._convert_from`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import rockpool.graph as rg\n", "from typing import List\n", "from dataclasses import dataclass\n", "\n", "\n", "@dataclass(eq=False, repr=False)\n", "class MyNeurons(rg.GenericNeurons):\n", "    thresholds: List[int]\n", "    dt: float\n", "\n", "    @classmethod\n", "    def _convert_from(cls, mod: rg.GraphModule) -> rg.GraphModule:\n", "        if isinstance(mod, cls):\n", "            # - No need to do anything\n", "            return mod\n", "\n", "        elif isinstance(mod, rg.LIFNeuronWithSynsRealValue):\n", "            # - Convert from a real-valued LIF neuron\n", "            # - A value for `dt` is required for the conversion\n", "            if mod.dt is None:\n", "                raise ValueError(\n", "                    f\"Graph module of type {type(mod).__name__} has no `dt` set, so cannot convert time constants when converting to {cls.__name__}.\"\n", "                )\n", "\n", "            # - Get thresholds from source module, rounded to integers\n", "            thresholds = np.round(np.array(mod.threshold)).astype(int).tolist()\n", "\n", "            # - Build a new module of this class to insert into the graph\n", "            neurons = cls._factory(\n", "                len(mod.input_nodes),\n", "                len(mod.output_nodes),\n", "                mod.name,\n", "                thresholds,\n", "                mod.dt,\n", "            )\n", "\n", "            # - Replace the target module and return\n", "            rg.replace_module(mod, neurons)\n", "            return neurons\n", "\n", "        else:\n", "            raise ValueError(\n", "                f\"Graph module of type {type(mod).__name__} cannot be converted to a {cls.__name__}\"\n", "            )" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "In the example above, the rules match specific subclasses of :py:class:`~.graph.GraphModule`, and convert them explicitly." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a mapper" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The steps in mapping a graph onto some target hardware are generally:\n", "\n", " - Run a design rule check (DRC); once the design rules pass, you can make many assumptions about the graph structure\n", " - Convert neuron graph module types to types that match the hardware\n", " - Assign hardware IDs to neurons, weights, inputs and outputs\n", " - Pull required data from the graph and build an equivalent hardware configuration\n", "\n", "Currently there is a mapper for the Xylo architecture in :py:func:`.devices.xylo.mapper`. Look through the code there for an example of building a mapper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DRC checks" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The suggested way to perform a DRC is to write a set of functions, each of which defines a single design rule as a check over a graph. If the design rule is violated, the function should raise an error.\n", "\n", "Below are examples of a few design rules." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import rockpool.graph as rg\n", "from typing import List, Callable\n", "\n", "# - Define an error class for DRC violations\n", "class DRCError(ValueError):\n", "    pass\n", "\n", "\n", "def output_nodes_have_neurons_as_source(graph: rg.GraphModule):\n", "    # - All output nodes must have a source that is a neuron\n", "    for n in graph.output_nodes:\n", "        for s in n.source_modules:\n", "            if not isinstance(s, rg.GenericNeurons):\n", "                raise DRCError(\n", "                    f\"All network outputs must be directly from neurons.\\nA network output node {n} has a source {s} which is not a neuron.\"\n", "                )\n", "\n", "\n", "def first_module_is_a_weight(graph: rg.GraphModule):\n", "    # - The first module after the input must be a set of weights\n", "    for inp in graph.input_nodes:\n", "        for sink in inp.sink_modules:\n", "            if not isinstance(sink, rg.LinearWeights):\n", "                raise DRCError(\n", "                    f\"The network input must go first through a weight.\\nA network input node {inp} has a sink module {sink} which is not a LinearWeights.\"\n", "                )\n", "\n", "\n", "def le_16_input_channels(graph: rg.GraphModule):\n", "    # - At most 16 input channels are supported\n", "    if len(graph.input_nodes) > 16:\n", "        raise DRCError(\n", "            f\"Xylo only supports up to 16 input channels. The network requires {len(graph.input_nodes)} input channels.\"\n", "        )" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "A suggested way to collect the rules and perform a DRC is shown below." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# - Collect a list of DRC functions\n", "xylo_drc = [\n", "    output_nodes_have_neurons_as_source,\n", "    first_module_is_a_weight,\n", "    le_16_input_channels,\n", "]\n", "\n", "\n", "def check_drc(graph: rg.GraphModule, design_rules: List[Callable[[rg.GraphModule], None]]):\n", "    \"\"\"\n", "    Perform a design rule check\n", "    \"\"\"\n", "    for dr in design_rules:\n", "        try:\n", "            dr(graph)\n", "        except DRCError as e:\n", "            # - Re-raise, naming the rule that failed and keeping the original traceback\n", "            raise DRCError(\n", "                f\"Design rule {dr.__name__} triggered an error:\\n\"\n", "                + \"\".join([f\"{msg}\" for msg in e.args])\n", "            ) from e\n", "\n", "\n", "# - To perform the DRC, use the function like so:\n", "# check_drc(graph, xylo_drc)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:wavesense]", "language": "python", "name": "conda-env-wavesense-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.11" } }, "nbformat": 4, "nbformat_minor": 4 }