{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 🔥 Building Rockpool modules with Torch" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Rockpool provides torch-backed modules with standard dynamics, for simple integration with other torch-provided modules from ``torch.nn``.\n", "\n", "==========================  =======================================\n", "Class                       Description\n", "==========================  =======================================\n", ":py:class:`.RateTorch`      A layer of non-spiking firing-rate neurons, with trainable time constants, thresholds and biases per neuron; optionally supporting recurrent connectivity\n", ":py:class:`.LIFTorch`       A layer of leaky integrate-and-fire spiking neurons, optionally supporting recurrent connectivity. Trainable with surrogate gradient descent, with trainable time constants, biases and thresholds per neuron\n", ":py:class:`.ExpSynTorch`    Exponential synapses, with trainable time constants\n", ":py:class:`.LinearTorch`    Equivalent to a standard trainable linear weights layer, but fully supporting the Rockpool APIs\n", ":py:class:`.InstantTorch`   Wrap an arbitrary function as a Rockpool module\n", "==========================  =======================================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use the Rockpool Torch-backed classes" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The classes above can be used directly to build network architectures in Rockpool, including mixing classes from ``torch.nn``. Here we build a simple feed-forward dynamical rate network, including a dropout layer from torch."
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# - Switch off warnings\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# - Rich printing, if available\n", "try:\n", "    from rich import print\n", "except ImportError:\n", "    pass\n", "\n", "# - Import and configure matplotlib for plotting\n", "import sys\n", "!{sys.executable} -m pip install --quiet matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", "plt.rcParams[\"figure.figsize\"] = [12, 4]\n", "plt.rcParams[\"figure.dpi\"] = 300\n", "\n", "# - Torch imports\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TorchSequential with shape (2, 2) {\n", "    LinearTorch '0_LinearTorch' with shape (2, 5)\n", "    RateTorch '1_RateTorch' with shape (5,)\n", "    Dropout2d '2_Dropout2d' with shape (None,)\n", "    LinearTorch '3_LinearTorch' with shape (5, 2)\n", "    RateTorch '4_RateTorch' with shape (2,)\n", "}" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from rockpool.nn.modules import RateTorch, LinearTorch\n", "from rockpool.nn.combinators import Sequential\n", "\n", "Nin = 2\n", "Nhidden = 5\n", "Nout = 2\n", "\n", "# - Define a simple feed-forward network using the Torch backend\n", "net = Sequential(\n", "    LinearTorch((Nin, Nhidden)),\n", "    RateTorch((Nhidden,)),\n", "    nn.Dropout2d(0.25),\n", "    LinearTorch((Nhidden, Nout)),\n", "    RateTorch((Nout,)),\n", ")\n", "net" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# - Evolve the network on random data and plot\n", "data = torch.rand((1, 100, Nin))\n", "out, _, _ = net(data)\n", "plt.plot(out[0].detach());" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[\n",
       "    '0_LinearTorch',\n",
       "    '0_LinearTorch_output',\n",
       "    '1_RateTorch',\n",
       "    '1_RateTorch_output',\n",
       "    '2_Dropout2d',\n",
       "    '2_Dropout2d_output',\n",
       "    '3_LinearTorch',\n",
       "    '3_LinearTorch_output',\n",
       "    '4_RateTorch',\n",
       "    '4_RateTorch_output'\n",
       "]\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\n", " \u001b[32m'0_LinearTorch'\u001b[0m,\n", " \u001b[32m'0_LinearTorch_output'\u001b[0m,\n", " \u001b[32m'1_RateTorch'\u001b[0m,\n", " \u001b[32m'1_RateTorch_output'\u001b[0m,\n", " \u001b[32m'2_Dropout2d'\u001b[0m,\n", " \u001b[32m'2_Dropout2d_output'\u001b[0m,\n", " \u001b[32m'3_LinearTorch'\u001b[0m,\n", " \u001b[32m'3_LinearTorch_output'\u001b[0m,\n", " \u001b[32m'4_RateTorch'\u001b[0m,\n", " \u001b[32m'4_RateTorch_output'\u001b[0m\n", "\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Recording internal signals also works\n", "out, _, rd = net(data, record=True)\n", "print(list(rd.keys()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert an existing ``torch.nn.Module`` for use in Rockpool" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Torch modules implemented using ``torch.nn.Module`` can be converted directly to the Rockpool API using the method :py:meth:`.TorchModule.from_torch`. This method patches the module in place so that it adheres to the Rockpool low-level API, converting Torch calls and attributes into Rockpool calls and registered attributes.\n", "\n", "Here we show an example of a simple Torch module converted to a Rockpool object."
] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "# - Torch imports\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "# - Rockpool imports\n", "from rockpool.nn.modules import TorchModule\n", "\n", "# - Implement a Torch class\n", "class TorchNet(torch.nn.Module):\n", "    def __init__(self, *args, **kwargs):\n", "        super().__init__(*args, **kwargs)\n", "\n", "        # - Build some convolutional layers\n", "        self.conv1 = nn.Conv2d(1, 2, 3, 1)\n", "\n", "        # - Add a dropout layer\n", "        self.dropout1 = nn.Dropout2d(0.25)\n", "\n", "        # - Fully-connected layer\n", "        self.fc1 = nn.Linear(338, 10)\n", "\n", "        # - Register an example buffer\n", "        self.register_buffer(\"test_buf\", torch.zeros(3, 4))\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = F.relu(x)\n", "\n", "        x = F.max_pool2d(x, 2)\n", "        x = self.dropout1(x)\n", "\n", "        x = torch.flatten(x, 1)\n", "\n", "        x = self.fc1(x)\n", "        x = F.relu(x)\n", "\n", "        output = F.log_softmax(x, dim=1)\n", "        return output" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "# - Instantiate the network and test the Torch API\n", "\n", "# - One random single-channel 28x28 image, shaped (batch, channels, height, width)\n", "random_data = torch.rand((1, 1, 28, 28))\n", "\n", "# - Generate torch module and test evaluation\n", "mod = TorchNet()\n", "result = mod(random_data)" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
TorchNet 'TorchModulePatch' with shape (None,) {\n",
       "    Conv2d 'TorchModulePatch' with shape (None,)\n",
       "    Dropout2d 'TorchModulePatch' with shape (None,)\n",
       "    Linear 'TorchModulePatch' with shape (None,)\n",
       "}\n",
       "
\n" ], "text/plain": [ "TorchNet \u001b[32m'TorchModulePatch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " Conv2d \u001b[32m'TorchModulePatch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Dropout2d \u001b[32m'TorchModulePatch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Linear \u001b[32m'TorchModulePatch'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Convert object to Rockpool API, in-place\n", "TorchModule.from_torch(mod)\n", "print(mod)" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
tensor([[-2.2458, -1.9656, -2.3617, -2.3617, -2.3617, -2.3617, -2.3617, -2.3617,\n",
       "         -2.3617, -2.3617]], grad_fn=<LogSoftmaxBackward0>)\n",
       "
\n" ], "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-2.2458\u001b[0m, \u001b[1;36m-1.9656\u001b[0m, \u001b[1;36m-2.3617\u001b[0m, \u001b[1;36m-2.3617\u001b[0m, \u001b[1;36m-2.3617\u001b[0m, \u001b[1;36m-2.3617\u001b[0m, \u001b[1;36m-2.3617\u001b[0m, \u001b[1;36m-2.3617\u001b[0m,\n", " \u001b[1;36m-2.3617\u001b[0m, \u001b[1;36m-2.3617\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mLogSoftmaxBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Use the Rockpool API to evolve the module\n", "output, _, _ = mod(random_data)\n", "print(output)" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "The module attributes can be accessed through the Rockpool API via the :py:meth:`~.TorchModule.parameters`, :py:meth:`~.TorchModule.state` and :py:meth:`~.TorchModule.simulationparameters` methods. The attribute dictionaries returned by these methods support an additional method, :py:meth:`~.TorchModuleParameters.astorch`, which converts the dictionary to a generator yielding raw :py:class:`torch.Tensor` objects. For the parameter dictionary, this is equivalent to calling :py:meth:`torch.nn.Module.parameters`." ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters: \n",
       "{\n",
       "    'conv1': {\n",
       "        'weight': Parameter containing:\n",
       "tensor([[[[-0.1636, -0.3071,  0.0886],\n",
       "          [ 0.1826, -0.0988, -0.2805],\n",
       "          [ 0.1841, -0.1396, -0.0389]]],\n",
       "\n",
       "\n",
       "        [[[ 0.2814, -0.2359, -0.0974],\n",
       "          [-0.2386,  0.3125,  0.1958],\n",
       "          [ 0.3165,  0.0791, -0.0173]]]], requires_grad=True),\n",
       "        'bias': Parameter containing:\n",
       "tensor([-0.1282,  0.0541], requires_grad=True)\n",
       "    },\n",
       "    'dropout1': {},\n",
       "    'fc1': {\n",
       "        'weight': Parameter containing:\n",
       "tensor([[ 0.0261, -0.0336, -0.0126,  ..., -0.0055,  0.0495, -0.0540],\n",
       "        [ 0.0469,  0.0163, -0.0500,  ..., -0.0408, -0.0364, -0.0121],\n",
       "        [-0.0284,  0.0247,  0.0290,  ..., -0.0128,  0.0444,  0.0534],\n",
       "        ...,\n",
       "        [ 0.0184,  0.0500,  0.0326,  ..., -0.0116, -0.0092,  0.0071],\n",
       "        [ 0.0190,  0.0424, -0.0505,  ..., -0.0379, -0.0238,  0.0469],\n",
       "        [ 0.0119,  0.0063,  0.0538,  ..., -0.0211, -0.0373,  0.0374]],\n",
       "       requires_grad=True),\n",
       "        'bias': Parameter containing:\n",
       "tensor([-0.0106, -0.0069, -0.0463,  0.0346,  0.0390, -0.0147, -0.0386, -0.0023,\n",
       "         0.0171, -0.0368], requires_grad=True)\n",
       "    }\n",
       "}\n",
       "
\n" ], "text/plain": [ "Parameters: \n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'conv1'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1636\u001b[0m, \u001b[1;36m-0.3071\u001b[0m, \u001b[1;36m0.0886\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1826\u001b[0m, \u001b[1;36m-0.0988\u001b[0m, \u001b[1;36m-0.2805\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m-0.1396\u001b[0m, \u001b[1;36m-0.0389\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", "\n", "\n", " \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.2814\u001b[0m, \u001b[1;36m-0.2359\u001b[0m, \u001b[1;36m-0.0974\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.2386\u001b[0m, \u001b[1;36m0.3125\u001b[0m, \u001b[1;36m0.1958\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.3165\u001b[0m, \u001b[1;36m0.0791\u001b[0m, \u001b[1;36m-0.0173\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1282\u001b[0m, \u001b[1;36m0.0541\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'dropout1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'fc1'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0261\u001b[0m, \u001b[1;36m-0.0336\u001b[0m, \u001b[1;36m-0.0126\u001b[0m, \u001b[33m...\u001b[0m, 
\u001b[1;36m-0.0055\u001b[0m, \u001b[1;36m0.0495\u001b[0m, \u001b[1;36m-0.0540\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0469\u001b[0m, \u001b[1;36m0.0163\u001b[0m, \u001b[1;36m-0.0500\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0408\u001b[0m, \u001b[1;36m-0.0364\u001b[0m, \u001b[1;36m-0.0121\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0284\u001b[0m, \u001b[1;36m0.0247\u001b[0m, \u001b[1;36m0.0290\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0128\u001b[0m, \u001b[1;36m0.0444\u001b[0m, \u001b[1;36m0.0534\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33m...\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0184\u001b[0m, \u001b[1;36m0.0500\u001b[0m, \u001b[1;36m0.0326\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0116\u001b[0m, \u001b[1;36m-0.0092\u001b[0m, \u001b[1;36m0.0071\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0190\u001b[0m, \u001b[1;36m0.0424\u001b[0m, \u001b[1;36m-0.0505\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0379\u001b[0m, \u001b[1;36m-0.0238\u001b[0m, \u001b[1;36m0.0469\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0119\u001b[0m, \u001b[1;36m0.0063\u001b[0m, \u001b[1;36m0.0538\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0211\u001b[0m, \u001b[1;36m-0.0373\u001b[0m, \u001b[1;36m0.0374\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.0106\u001b[0m, \u001b[1;36m-0.0069\u001b[0m, \u001b[1;36m-0.0463\u001b[0m, \u001b[1;36m0.0346\u001b[0m, \u001b[1;36m0.0390\u001b[0m, \u001b[1;36m-0.0147\u001b[0m, \u001b[1;36m-0.0386\u001b[0m, \u001b[1;36m-0.0023\u001b[0m,\n", " \u001b[1;36m0.0171\u001b[0m, \u001b[1;36m-0.0368\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " 
\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
State: \n",
       "{\n",
       "    'test_buf': tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]]),\n",
       "    'conv1': {},\n",
       "    'dropout1': {},\n",
       "    'fc1': {}\n",
       "}\n",
       "
\n" ], "text/plain": [ "State: \n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'test_buf'\u001b[0m: \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'conv1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'dropout1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'fc1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Use the Rockpool API to access parameters\n", "print(\"Parameters: \", mod.parameters())\n", "print(\"State: \", mod.state())" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters.astorch(): \n",
       "[\n",
       "    Parameter containing:\n",
       "tensor([[[[-0.1636, -0.3071,  0.0886],\n",
       "          [ 0.1826, -0.0988, -0.2805],\n",
       "          [ 0.1841, -0.1396, -0.0389]]],\n",
       "\n",
       "\n",
       "        [[[ 0.2814, -0.2359, -0.0974],\n",
       "          [-0.2386,  0.3125,  0.1958],\n",
       "          [ 0.3165,  0.0791, -0.0173]]]], requires_grad=True),\n",
       "    Parameter containing:\n",
       "tensor([-0.1282,  0.0541], requires_grad=True),\n",
       "    Parameter containing:\n",
       "tensor([[ 0.0261, -0.0336, -0.0126,  ..., -0.0055,  0.0495, -0.0540],\n",
       "        [ 0.0469,  0.0163, -0.0500,  ..., -0.0408, -0.0364, -0.0121],\n",
       "        [-0.0284,  0.0247,  0.0290,  ..., -0.0128,  0.0444,  0.0534],\n",
       "        ...,\n",
       "        [ 0.0184,  0.0500,  0.0326,  ..., -0.0116, -0.0092,  0.0071],\n",
       "        [ 0.0190,  0.0424, -0.0505,  ..., -0.0379, -0.0238,  0.0469],\n",
       "        [ 0.0119,  0.0063,  0.0538,  ..., -0.0211, -0.0373,  0.0374]],\n",
       "       requires_grad=True),\n",
       "    Parameter containing:\n",
       "tensor([-0.0106, -0.0069, -0.0463,  0.0346,  0.0390, -0.0147, -0.0386, -0.0023,\n",
       "         0.0171, -0.0368], requires_grad=True)\n",
       "]\n",
       "
\n" ], "text/plain": [ "\u001b[1;35mParameters.astorch\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m: \n", "\u001b[1m[\u001b[0m\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1636\u001b[0m, \u001b[1;36m-0.3071\u001b[0m, \u001b[1;36m0.0886\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1826\u001b[0m, \u001b[1;36m-0.0988\u001b[0m, \u001b[1;36m-0.2805\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m-0.1396\u001b[0m, \u001b[1;36m-0.0389\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", "\n", "\n", " \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.2814\u001b[0m, \u001b[1;36m-0.2359\u001b[0m, \u001b[1;36m-0.0974\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.2386\u001b[0m, \u001b[1;36m0.3125\u001b[0m, \u001b[1;36m0.1958\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.3165\u001b[0m, \u001b[1;36m0.0791\u001b[0m, \u001b[1;36m-0.0173\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1282\u001b[0m, \u001b[1;36m0.0541\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0261\u001b[0m, \u001b[1;36m-0.0336\u001b[0m, \u001b[1;36m-0.0126\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0055\u001b[0m, \u001b[1;36m0.0495\u001b[0m, \u001b[1;36m-0.0540\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0469\u001b[0m, \u001b[1;36m0.0163\u001b[0m, \u001b[1;36m-0.0500\u001b[0m, 
\u001b[33m...\u001b[0m, \u001b[1;36m-0.0408\u001b[0m, \u001b[1;36m-0.0364\u001b[0m, \u001b[1;36m-0.0121\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0284\u001b[0m, \u001b[1;36m0.0247\u001b[0m, \u001b[1;36m0.0290\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0128\u001b[0m, \u001b[1;36m0.0444\u001b[0m, \u001b[1;36m0.0534\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33m...\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0184\u001b[0m, \u001b[1;36m0.0500\u001b[0m, \u001b[1;36m0.0326\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0116\u001b[0m, \u001b[1;36m-0.0092\u001b[0m, \u001b[1;36m0.0071\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0190\u001b[0m, \u001b[1;36m0.0424\u001b[0m, \u001b[1;36m-0.0505\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0379\u001b[0m, \u001b[1;36m-0.0238\u001b[0m, \u001b[1;36m0.0469\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0119\u001b[0m, \u001b[1;36m0.0063\u001b[0m, \u001b[1;36m0.0538\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0211\u001b[0m, \u001b[1;36m-0.0373\u001b[0m, \u001b[1;36m0.0374\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.0106\u001b[0m, \u001b[1;36m-0.0069\u001b[0m, \u001b[1;36m-0.0463\u001b[0m, \u001b[1;36m0.0346\u001b[0m, \u001b[1;36m0.0390\u001b[0m, \u001b[1;36m-0.0147\u001b[0m, \u001b[1;36m-0.0386\u001b[0m, \u001b[1;36m-0.0023\u001b[0m,\n", " \u001b[1;36m0.0171\u001b[0m, \u001b[1;36m-0.0368\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Convert the parameter dictionary to torch parameters\n", "print(\"Parameters.astorch(): \", list(mod.parameters().astorch()))" ] }, { "cell_type": 
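"raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Because :py:meth:`~.TorchModuleParameters.astorch` yields the same tensors as :py:meth:`torch.nn.Module.parameters`, it can be handed to a standard Torch optimiser. The cell below is a minimal sketch of a single Adam update on ``mod``. It assumes ``mod`` and ``random_data`` from the cells above, and uses an arbitrary placeholder loss purely to illustrate the mechanics; it is not a meaningful training objective." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# - Sketch: optimise the converted module with a standard Torch optimiser\n", "opt = torch.optim.Adam(mod.parameters().astorch(), lr=1e-3)\n", "\n", "# - Evolve using the Rockpool API, which returns (output, state, record)\n", "output, _, _ = mod(random_data)\n", "\n", "# - Placeholder loss, for illustration only: raise the log-probability of class 0\n", "loss = -output[0, 0]\n", "\n", "# - Take a single gradient step\n", "opt.zero_grad()\n", "loss.backward()\n", "opt.step()" ] }, { "cell_type": 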
"markdown", "metadata": {}, "source": [ "## Write a native Rockpool/Torch module using ``TorchModule``" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "You can also use :py:class:`.TorchModule` directly as a base class, in place of ``torch.nn.Module``. Usually this will be a drop-in replacement, without modifying the initialisation or evaluation code.\n", "\n", "The example here mimics the network above; only the inherited base class has been changed." ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "# - Implement a Rockpool class using the TorchModule base class\n", "class RockpoolNet(TorchModule):\n", "    def __init__(self, *args, **kwargs):\n", "        super().__init__(*args, **kwargs)\n", "\n", "        # - Build some convolutional layers\n", "        self.conv1 = nn.Conv2d(1, 2, 3, 1)\n", "\n", "        # - Add a dropout layer\n", "        self.dropout1 = nn.Dropout2d(0.25)\n", "\n", "        # - Fully-connected layer\n", "        self.fc1 = nn.Linear(338, 10)\n", "\n", "        # - Register an example buffer\n", "        self.register_buffer(\"test_buf\", torch.zeros(3, 4))\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = F.relu(x)\n", "\n", "        x = F.max_pool2d(x, 2)\n", "        x = self.dropout1(x)\n", "\n", "        x = torch.flatten(x, 1)\n", "\n", "        x = self.fc1(x)\n", "        x = F.relu(x)\n", "\n", "        output = F.log_softmax(x, dim=1)\n", "        return output" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RockpoolNet  with shape (None,) {\n",
       "    Conv2d 'conv1' with shape (None,)\n",
       "    Conv2d 'conv1' with shape (None,)\n",
       "    Dropout2d 'dropout1' with shape (None,)\n",
       "    Dropout2d 'dropout1' with shape (None,)\n",
       "    Linear 'fc1' with shape (None,)\n",
       "    Linear 'fc1' with shape (None,)\n",
       "}\n",
       "
\n" ], "text/plain": [ "RockpoolNet with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " Conv2d \u001b[32m'conv1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Conv2d \u001b[32m'conv1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Dropout2d \u001b[32m'dropout1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Dropout2d \u001b[32m'dropout1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Linear \u001b[32m'fc1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Linear \u001b[32m'fc1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Instantiate the Rockpool class directly\n", "rmod = RockpoolNet()\n", "print(rmod)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
tensor([[-2.3283, -2.3283, -2.3283, -2.3283, -2.3283, -2.2359, -2.2393, -2.2600,\n",
       "         -2.3283, -2.3283]], grad_fn=<LogSoftmaxBackward0>)\n",
       "
\n" ], "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-2.3283\u001b[0m, \u001b[1;36m-2.3283\u001b[0m, \u001b[1;36m-2.3283\u001b[0m, \u001b[1;36m-2.3283\u001b[0m, \u001b[1;36m-2.3283\u001b[0m, \u001b[1;36m-2.2359\u001b[0m, \u001b[1;36m-2.2393\u001b[0m, \u001b[1;36m-2.2600\u001b[0m,\n", " \u001b[1;36m-2.3283\u001b[0m, \u001b[1;36m-2.3283\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mLogSoftmaxBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Evaluate the module using the Rockpool API\n", "output, _, _ = rmod(random_data)\n", "print(output)" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters: \n",
       "{\n",
       "    'conv1': {\n",
       "        'weight': Parameter containing:\n",
       "tensor([[[[ 0.2828, -0.2562,  0.0924],\n",
       "          [ 0.1476,  0.0577, -0.0121],\n",
       "          [ 0.0073,  0.2832, -0.2923]]],\n",
       "\n",
       "\n",
       "        [[[ 0.2786,  0.0777, -0.0933],\n",
       "          [-0.1199,  0.0570, -0.2343],\n",
       "          [ 0.2991,  0.0064, -0.2985]]]], requires_grad=True),\n",
       "        'bias': Parameter containing:\n",
       "tensor([-0.2721, -0.2435], requires_grad=True)\n",
       "    },\n",
       "    'dropout1': {},\n",
       "    'fc1': {\n",
       "        'weight': Parameter containing:\n",
       "tensor([[-0.0320,  0.0249, -0.0329,  ...,  0.0196,  0.0241, -0.0021],\n",
       "        [-0.0400,  0.0448, -0.0266,  ..., -0.0056, -0.0111,  0.0318],\n",
       "        [ 0.0204,  0.0127, -0.0184,  ...,  0.0482, -0.0074,  0.0258],\n",
       "        ...,\n",
       "        [ 0.0328, -0.0477,  0.0297,  ..., -0.0182, -0.0296,  0.0217],\n",
       "        [-0.0037,  0.0421,  0.0048,  ...,  0.0057, -0.0409, -0.0112],\n",
       "        [-0.0214, -0.0465, -0.0319,  ...,  0.0304, -0.0220, -0.0306]],\n",
       "       requires_grad=True),\n",
       "        'bias': Parameter containing:\n",
       "tensor([ 0.0059,  0.0042, -0.0443,  0.0200,  0.0258,  0.0406,  0.0433, -0.0303,\n",
       "         0.0038,  0.0473], requires_grad=True)\n",
       "    }\n",
       "}\n",
       "
\n" ], "text/plain": [ "Parameters: \n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'conv1'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.2828\u001b[0m, \u001b[1;36m-0.2562\u001b[0m, \u001b[1;36m0.0924\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.1476\u001b[0m, \u001b[1;36m0.0577\u001b[0m, \u001b[1;36m-0.0121\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0073\u001b[0m, \u001b[1;36m0.2832\u001b[0m, \u001b[1;36m-0.2923\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", "\n", "\n", " \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.2786\u001b[0m, \u001b[1;36m0.0777\u001b[0m, \u001b[1;36m-0.0933\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.1199\u001b[0m, \u001b[1;36m0.0570\u001b[0m, \u001b[1;36m-0.2343\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.2991\u001b[0m, \u001b[1;36m0.0064\u001b[0m, \u001b[1;36m-0.2985\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.2721\u001b[0m, \u001b[1;36m-0.2435\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'dropout1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'fc1'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.0320\u001b[0m, \u001b[1;36m0.0249\u001b[0m, \u001b[1;36m-0.0329\u001b[0m, \u001b[33m...\u001b[0m, 
\u001b[1;36m0.0196\u001b[0m, \u001b[1;36m0.0241\u001b[0m, \u001b[1;36m-0.0021\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0400\u001b[0m, \u001b[1;36m0.0448\u001b[0m, \u001b[1;36m-0.0266\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0056\u001b[0m, \u001b[1;36m-0.0111\u001b[0m, \u001b[1;36m0.0318\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0204\u001b[0m, \u001b[1;36m0.0127\u001b[0m, \u001b[1;36m-0.0184\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0482\u001b[0m, \u001b[1;36m-0.0074\u001b[0m, \u001b[1;36m0.0258\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33m...\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0328\u001b[0m, \u001b[1;36m-0.0477\u001b[0m, \u001b[1;36m0.0297\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0182\u001b[0m, \u001b[1;36m-0.0296\u001b[0m, \u001b[1;36m0.0217\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0037\u001b[0m, \u001b[1;36m0.0421\u001b[0m, \u001b[1;36m0.0048\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0057\u001b[0m, \u001b[1;36m-0.0409\u001b[0m, \u001b[1;36m-0.0112\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0214\u001b[0m, \u001b[1;36m-0.0465\u001b[0m, \u001b[1;36m-0.0319\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0304\u001b[0m, \u001b[1;36m-0.0220\u001b[0m, \u001b[1;36m-0.0306\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0059\u001b[0m, \u001b[1;36m0.0042\u001b[0m, \u001b[1;36m-0.0443\u001b[0m, \u001b[1;36m0.0200\u001b[0m, \u001b[1;36m0.0258\u001b[0m, \u001b[1;36m0.0406\u001b[0m, \u001b[1;36m0.0433\u001b[0m, \u001b[1;36m-0.0303\u001b[0m,\n", " \u001b[1;36m0.0038\u001b[0m, \u001b[1;36m0.0473\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " 
\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
State: \n",
       "{\n",
       "    'test_buf': tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]]),\n",
       "    'conv1': {},\n",
       "    'dropout1': {},\n",
       "    'fc1': {}\n",
       "}\n",
       "
\n" ], "text/plain": [ "State: \n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'test_buf'\u001b[0m: \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'conv1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'dropout1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'fc1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Access parameters using the Rockpool API\n", "print(\"Parameters: \", rmod.parameters())\n", "print(\"State: \", rmod.state())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Converting from Rockpool/torch to pure torch" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Sometimes you may want to use Rockpool classes derived from :py:class:`.TorchModule` with other software that expects pure torch modules (e.g. MLflow or PyTorch Lightning).\n", "\n", "In that case, call the :py:meth:`~.TorchModule.to_torch` method, which converts a module in place to expose a pure torch interface. Below we demonstrate this using the ``RockpoolNet`` class defined above." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Rockpool API" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RockpoolNet  with shape (None,) {\n",
       "    Conv2d 'conv1' with shape (None,)\n",
       "    Conv2d 'conv1' with shape (None,)\n",
       "    Dropout2d 'dropout1' with shape (None,)\n",
       "    Dropout2d 'dropout1' with shape (None,)\n",
       "    Linear 'fc1' with shape (None,)\n",
       "    Linear 'fc1' with shape (None,)\n",
       "}\n",
       "
\n" ], "text/plain": [ "RockpoolNet with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m{\u001b[0m\n", " Conv2d \u001b[32m'conv1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Conv2d \u001b[32m'conv1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Dropout2d \u001b[32m'dropout1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Dropout2d \u001b[32m'dropout1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Linear \u001b[32m'fc1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " Linear \u001b[32m'fc1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Instantiate the Rockpool class\n", "net = RockpoolNet()\n", "print(net)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters:\n",
       "{\n",
       "    'conv1': {\n",
       "        'weight': Parameter containing:\n",
       "tensor([[[[-0.2905,  0.1654, -0.1893],\n",
       "          [-0.0804,  0.1523, -0.3320],\n",
       "          [ 0.0346,  0.1020,  0.0288]]],\n",
       "\n",
       "\n",
       "        [[[-0.1374,  0.3117,  0.0558],\n",
       "          [-0.3279,  0.1651,  0.3008],\n",
       "          [ 0.0011,  0.1701, -0.1425]]]], requires_grad=True),\n",
       "        'bias': Parameter containing:\n",
       "tensor([-0.1574, -0.1525], requires_grad=True)\n",
       "    },\n",
       "    'dropout1': {},\n",
       "    'fc1': {\n",
       "        'weight': Parameter containing:\n",
       "tensor([[ 0.0269, -0.0310,  0.0103,  ..., -0.0044, -0.0114, -0.0388],\n",
       "        [ 0.0409, -0.0286,  0.0256,  ...,  0.0166,  0.0072, -0.0476],\n",
       "        [-0.0356,  0.0108, -0.0136,  ..., -0.0108,  0.0268,  0.0322],\n",
       "        ...,\n",
       "        [-0.0467, -0.0277,  0.0084,  ..., -0.0023,  0.0034, -0.0107],\n",
       "        [-0.0191,  0.0524, -0.0005,  ...,  0.0305,  0.0221, -0.0304],\n",
       "        [-0.0199,  0.0537,  0.0393,  ..., -0.0248,  0.0081,  0.0438]],\n",
       "       requires_grad=True),\n",
       "        'bias': Parameter containing:\n",
       "tensor([ 0.0454,  0.0186, -0.0167, -0.0339,  0.0249,  0.0274, -0.0339, -0.0019,\n",
       "        -0.0250, -0.0271], requires_grad=True)\n",
       "    }\n",
       "}\n",
       "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m{\u001b[0m\n", " \u001b[32m'conv1'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.2905\u001b[0m, \u001b[1;36m0.1654\u001b[0m, \u001b[1;36m-0.1893\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0804\u001b[0m, \u001b[1;36m0.1523\u001b[0m, \u001b[1;36m-0.3320\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0346\u001b[0m, \u001b[1;36m0.1020\u001b[0m, \u001b[1;36m0.0288\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", "\n", "\n", " \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1374\u001b[0m, \u001b[1;36m0.3117\u001b[0m, \u001b[1;36m0.0558\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.3279\u001b[0m, \u001b[1;36m0.1651\u001b[0m, \u001b[1;36m0.3008\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0011\u001b[0m, \u001b[1;36m0.1701\u001b[0m, \u001b[1;36m-0.1425\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1574\u001b[0m, \u001b[1;36m-0.1525\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[32m'dropout1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'fc1'\u001b[0m: \u001b[1m{\u001b[0m\n", " \u001b[32m'weight'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0269\u001b[0m, \u001b[1;36m-0.0310\u001b[0m, \u001b[1;36m0.0103\u001b[0m, \u001b[33m...\u001b[0m, 
\u001b[1;36m-0.0044\u001b[0m, \u001b[1;36m-0.0114\u001b[0m, \u001b[1;36m-0.0388\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0409\u001b[0m, \u001b[1;36m-0.0286\u001b[0m, \u001b[1;36m0.0256\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0166\u001b[0m, \u001b[1;36m0.0072\u001b[0m, \u001b[1;36m-0.0476\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0356\u001b[0m, \u001b[1;36m0.0108\u001b[0m, \u001b[1;36m-0.0136\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0108\u001b[0m, \u001b[1;36m0.0268\u001b[0m, \u001b[1;36m0.0322\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33m...\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0467\u001b[0m, \u001b[1;36m-0.0277\u001b[0m, \u001b[1;36m0.0084\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0023\u001b[0m, \u001b[1;36m0.0034\u001b[0m, \u001b[1;36m-0.0107\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0191\u001b[0m, \u001b[1;36m0.0524\u001b[0m, \u001b[1;36m-0.0005\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0305\u001b[0m, \u001b[1;36m0.0221\u001b[0m, \u001b[1;36m-0.0304\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0199\u001b[0m, \u001b[1;36m0.0537\u001b[0m, \u001b[1;36m0.0393\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0248\u001b[0m, \u001b[1;36m0.0081\u001b[0m, \u001b[1;36m0.0438\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'bias'\u001b[0m: Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0454\u001b[0m, \u001b[1;36m0.0186\u001b[0m, \u001b[1;36m-0.0167\u001b[0m, \u001b[1;36m-0.0339\u001b[0m, \u001b[1;36m0.0249\u001b[0m, \u001b[1;36m0.0274\u001b[0m, \u001b[1;36m-0.0339\u001b[0m, \u001b[1;36m-0.0019\u001b[0m,\n", " \u001b[1;36m-0.0250\u001b[0m, \u001b[1;36m-0.0271\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", " 
\u001b[1m}\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Rockpool dictionary-based parameter API\n", "print(\"Parameters:\", net.parameters())" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
(\n",
       "    tensor([[-2.3239, -2.3239, -2.2561, -2.2330, -2.3239, -2.3239, -2.3239, -2.3239,\n",
       "         -2.2750, -2.3239]], grad_fn=<LogSoftmaxBackward0>),\n",
       "    {\n",
       "        'test_buf': tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]]),\n",
       "        'conv1': {},\n",
       "        'dropout1': {},\n",
       "        'fc1': {}\n",
       "    },\n",
       "    {}\n",
       ")\n",
       "
\n" ], "text/plain": [ "\u001b[1m(\u001b[0m\n", " \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-2.3239\u001b[0m, \u001b[1;36m-2.3239\u001b[0m, \u001b[1;36m-2.2561\u001b[0m, \u001b[1;36m-2.2330\u001b[0m, \u001b[1;36m-2.3239\u001b[0m, \u001b[1;36m-2.3239\u001b[0m, \u001b[1;36m-2.3239\u001b[0m, \u001b[1;36m-2.3239\u001b[0m,\n", " \u001b[1;36m-2.2750\u001b[0m, \u001b[1;36m-2.3239\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mLogSoftmaxBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m{\u001b[0m\n", " \u001b[32m'test_buf'\u001b[0m: \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m., \u001b[1;36m0\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[32m'conv1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'dropout1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[32m'fc1'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate one random 28x28 image\n", "random_data = torch.rand((1, 1, 28, 28))\n", "\n", "# - Rockpool standard calling semantics\n", "print(net(random_data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Torch API" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RockpoolNet(\n",
       "  (conv1): Conv2d 'conv1' with shape (None,)\n",
       "  (dropout1): Dropout2d 'dropout1' with shape (None,)\n",
       "  (fc1): Linear 'fc1' with shape (None,)\n",
       ")\n",
       "
\n" ], "text/plain": [ "\u001b[1;35mRockpoolNet\u001b[0m\u001b[1m(\u001b[0m\n", " \u001b[1m(\u001b[0mconv1\u001b[1m)\u001b[0m: Conv2d \u001b[32m'conv1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " \u001b[1m(\u001b[0mdropout1\u001b[1m)\u001b[0m: Dropout2d \u001b[32m'dropout1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", " \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: Linear \u001b[32m'fc1'\u001b[0m with shape \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m,\u001b[1m)\u001b[0m\n", "\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Convert in-place to the pure Torch API\n", "net.to_torch()\n", "print(net)" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Parameters:\n",
       "[\n",
       "    Parameter containing:\n",
       "tensor([[[[-0.2905,  0.1654, -0.1893],\n",
       "          [-0.0804,  0.1523, -0.3320],\n",
       "          [ 0.0346,  0.1020,  0.0288]]],\n",
       "\n",
       "\n",
       "        [[[-0.1374,  0.3117,  0.0558],\n",
       "          [-0.3279,  0.1651,  0.3008],\n",
       "          [ 0.0011,  0.1701, -0.1425]]]], requires_grad=True),\n",
       "    Parameter containing:\n",
       "tensor([-0.1574, -0.1525], requires_grad=True),\n",
       "    Parameter containing:\n",
       "tensor([[ 0.0269, -0.0310,  0.0103,  ..., -0.0044, -0.0114, -0.0388],\n",
       "        [ 0.0409, -0.0286,  0.0256,  ...,  0.0166,  0.0072, -0.0476],\n",
       "        [-0.0356,  0.0108, -0.0136,  ..., -0.0108,  0.0268,  0.0322],\n",
       "        ...,\n",
       "        [-0.0467, -0.0277,  0.0084,  ..., -0.0023,  0.0034, -0.0107],\n",
       "        [-0.0191,  0.0524, -0.0005,  ...,  0.0305,  0.0221, -0.0304],\n",
       "        [-0.0199,  0.0537,  0.0393,  ..., -0.0248,  0.0081,  0.0438]],\n",
       "       requires_grad=True),\n",
       "    Parameter containing:\n",
       "tensor([ 0.0454,  0.0186, -0.0167, -0.0339,  0.0249,  0.0274, -0.0339, -0.0019,\n",
       "        -0.0250, -0.0271], requires_grad=True)\n",
       "]\n",
       "
\n" ], "text/plain": [ "Parameters:\n", "\u001b[1m[\u001b[0m\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.2905\u001b[0m, \u001b[1;36m0.1654\u001b[0m, \u001b[1;36m-0.1893\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0804\u001b[0m, \u001b[1;36m0.1523\u001b[0m, \u001b[1;36m-0.3320\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0346\u001b[0m, \u001b[1;36m0.1020\u001b[0m, \u001b[1;36m0.0288\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", "\n", "\n", " \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1374\u001b[0m, \u001b[1;36m0.3117\u001b[0m, \u001b[1;36m0.0558\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.3279\u001b[0m, \u001b[1;36m0.1651\u001b[0m, \u001b[1;36m0.3008\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0011\u001b[0m, \u001b[1;36m0.1701\u001b[0m, \u001b[1;36m-0.1425\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.1574\u001b[0m, \u001b[1;36m-0.1525\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0269\u001b[0m, \u001b[1;36m-0.0310\u001b[0m, \u001b[1;36m0.0103\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0044\u001b[0m, \u001b[1;36m-0.0114\u001b[0m, \u001b[1;36m-0.0388\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m0.0409\u001b[0m, \u001b[1;36m-0.0286\u001b[0m, \u001b[1;36m0.0256\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0166\u001b[0m, 
\u001b[1;36m0.0072\u001b[0m, \u001b[1;36m-0.0476\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0356\u001b[0m, \u001b[1;36m0.0108\u001b[0m, \u001b[1;36m-0.0136\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0108\u001b[0m, \u001b[1;36m0.0268\u001b[0m, \u001b[1;36m0.0322\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33m...\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0467\u001b[0m, \u001b[1;36m-0.0277\u001b[0m, \u001b[1;36m0.0084\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0023\u001b[0m, \u001b[1;36m0.0034\u001b[0m, \u001b[1;36m-0.0107\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0191\u001b[0m, \u001b[1;36m0.0524\u001b[0m, \u001b[1;36m-0.0005\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m0.0305\u001b[0m, \u001b[1;36m0.0221\u001b[0m, \u001b[1;36m-0.0304\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m-0.0199\u001b[0m, \u001b[1;36m0.0537\u001b[0m, \u001b[1;36m0.0393\u001b[0m, \u001b[33m...\u001b[0m, \u001b[1;36m-0.0248\u001b[0m, \u001b[1;36m0.0081\u001b[0m, \u001b[1;36m0.0438\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m,\n", " Parameter containing:\n", "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.0454\u001b[0m, \u001b[1;36m0.0186\u001b[0m, \u001b[1;36m-0.0167\u001b[0m, \u001b[1;36m-0.0339\u001b[0m, \u001b[1;36m0.0249\u001b[0m, \u001b[1;36m0.0274\u001b[0m, \u001b[1;36m-0.0339\u001b[0m, \u001b[1;36m-0.0019\u001b[0m,\n", " \u001b[1;36m-0.0250\u001b[0m, \u001b[1;36m-0.0271\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mrequires_grad\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# - Now returns the torch parameters API\n", "print(\"Parameters:\", list(net.parameters()))" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
tensor([[-2.3227, -2.3227, -2.3227, -2.2396, -2.3227, -2.3227, -2.3227, -2.3227,\n",
       "         -2.2124, -2.3227]], grad_fn=<LogSoftmaxBackward0>)\n",
       "
\n" ], "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-2.3227\u001b[0m, \u001b[1;36m-2.3227\u001b[0m, \u001b[1;36m-2.3227\u001b[0m, \u001b[1;36m-2.2396\u001b[0m, \u001b[1;36m-2.3227\u001b[0m, \u001b[1;36m-2.3227\u001b[0m, \u001b[1;36m-2.3227\u001b[0m, \u001b[1;36m-2.3227\u001b[0m,\n", " \u001b[1;36m-2.2124\u001b[0m, \u001b[1;36m-2.3227\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mgrad_fn\u001b[0m=\u001b[1m<\u001b[0m\u001b[1;95mLogSoftmaxBackward0\u001b[0m\u001b[1m>\u001b[0m\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate one random 28x28 image\n", "random_data = torch.rand((1, 1, 28, 28))\n", "\n", "# - Now uses torch calling semantics\n", "print(net(random_data))" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:py38]", "language": "python", "name": "conda-env-py38-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.11" } }, "nbformat": 4, "nbformat_minor": 4 }