This page was generated from docs/advanced/QuantTorch.ipynb.

Torch transformation-in-training pipeline prototype

This notebook gives an overview of the prototype quantization-aware training pipeline for parameters and activations, and the facilities available for Torch-backed modules in Rockpool.

This is still work-in-progress and subject to change.

The Torch pipeline is based on Torch’s functional_call API, introduced in Torch 1.12.

Design goals

  • No need to modify pre-defined modules to make “magic quantization” modules

  • General solution that can be applied widely to modules and parameters

  • Convenient API for specifying transformations over parameters in a network in a “grouped” way, using Rockpool’s parameter families

  • Similar API for parameter- and activity-transformation

  • Quantization controllable at a fine-grained level

  • Provide useful and flexible transformation methods — can be used for QAT, dropout, pruning…

[1]:
# - Basic imports
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential, Residual

import torch
[2]:
# - Transformation pipeline imports
import rockpool.transform.torch_transform as tt
import rockpool.utilities.tree_utils as tu

Parameter transformations

The parameter transformation pipeline lets you apply a transformation to any parameter in the forward pass, before evolution, in a configurable way. You can use this to perform quantization-aware training, random parameter attacks, connection pruning, …

We’ll begin here with a simple Rockpool SNN that uses most of the features of network composition in Rockpool, and is compatible with Xylo.

[3]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net
[3]:
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Now we build a configuration that describes the desired transformation to apply to each parameter. We will transform weights with stochastic_rounding() and biases with dropout(). We can use parameter families to select which parameters to transform and which transformation to apply.

[4]:
# - Get the 'weights' parameter family, and specify stochastic rounding
tconfig = tt.make_param_T_config(
    net, lambda p: tt.stochastic_rounding(p, num_levels=2**2), "weights"
)
tconfig
[4]:
{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
[5]:
# - Now we add in the bias transformation
tconfig = tu.tree_update(
    tconfig, tt.make_param_T_config(net, lambda p: tt.dropout(p, 0.3), "biases")
)
tconfig
[5]:
{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}

We then use this quantization configuration tree to patch the network with transformation modules, using the make_param_T_network() helper function.

[6]:
# - We now use this configuration to patch the original network with transformation modules
tnet = tt.make_param_T_network(net, tconfig)
tnet
[6]:
TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Each of the transformed modules is now wrapped in a TWrapper module. These wrapper modules apply the required transformations in the forward pass, injecting the transformed parameters and then evolving the wrapped module as usual. The original module doesn’t need to know anything special, and simply uses the quantized parameters passed to it.

The parameters are held by the original modules, un-transformed, so that any parameter updates during training are applied to the un-transformed parameters.
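
As a rough illustration of the mechanism (and not the actual TWrapper implementation), a wrapper along the following lines could inject transformed parameters in the forward pass. The ToyQuantWrapper class and T_fn argument here are hypothetical, and the sketch assumes torch.func.functional_call from PyTorch 2.0:

# - Minimal sketch of injecting transformed parameters in the forward pass.
#   Illustrative only; `TWrapper` in Rockpool handles this for you.
import torch
from torch.func import functional_call

class ToyQuantWrapper(torch.nn.Module):
    def __init__(self, mod: torch.nn.Module, T_fn):
        super().__init__()
        self._mod = mod    # the wrapped module keeps the floating-point parameters
        self._T_fn = T_fn  # transformation applied only during the forward pass

    def forward(self, x):
        # - Transform the parameters on the fly, leaving the stored values untouched
        transformed = {n: self._T_fn(p) for n, p in self._mod.named_parameters()}
        # - Evolve the wrapped module using the transformed parameters
        return functional_call(self._mod, transformed, (x,))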

If we investigate the Module.parameters() of the network we can see this structure:

[7]:
tnet.parameters("weights")
[7]:
{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-0.8337, -0.4811, -1.0898, -0.7309, -0.2193],
           [ 0.8017,  0.7001, -0.1010, -1.2180,  1.4088],
           [-1.3568, -0.6440,  0.1773, -1.0603,  0.7051]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[ 0.2452, -0.1040, -0.6809, -0.5192, -1.0748],
            [-0.7669, -0.2566,  0.6481, -0.1794, -0.1849],
            [ 0.5732, -0.7922, -0.7750,  0.7075, -0.5805],
            [ 0.5087, -0.1379, -0.3209,  0.3345,  0.2006],
            [ 0.9543, -1.0458,  0.9624, -0.3144,  0.1960]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[-0.5553,  0.5043, -0.8002,  1.0671,  0.5552],
            [ 0.0891, -1.0943, -0.3825, -0.7623,  0.8179],
            [ 0.3327,  0.8391,  0.6506,  0.1105, -0.1883],
            [-0.0614, -0.9433, -0.6902, -0.4699,  0.3375],
            [-0.4792,  0.7939,  0.0106,  0.4522, -0.2737]], requires_grad=True)}}},
 '3_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-0.9582,  0.9870, -0.3827],
           [ 0.5585, -0.9113, -0.1586],
           [ 0.8854,  0.8445,  0.0221],
           [-0.8204,  0.3860,  0.8635],
           [ 0.7997,  0.4473, -1.0423]], requires_grad=True)}}}

These are the un-transformed parameters, in floating-point format. But if we evolve the module by calling it, the parameters will all be transformed in the forward pass:

[8]:
out, ns, rd = tnet(torch.ones(1, 10, 3))
out
/home/dylan/miniconda3/envs/torch2/lib/python3.8/site-packages/torch/nn/utils/stateless.py:216: UserWarning: This API is deprecated as of PyTorch 2.0 and will be removed in a future version of PyTorch. Please use torch.func.functional_call instead which is a drop-in replacement for this API.
  warnings.warn(
[8]:
tensor([[[  0.,   2.,   0.],
         [  2.,   8.,   0.],
         [  0.,  17.,   0.],
         [  0.,  32.,   0.],
         [  0.,  47.,   0.],
         [  0.,  65.,   0.],
         [  0.,  89.,   4.],
         [  0., 117.,  27.],
         [  0., 154.,  46.],
         [  0., 200.,  70.]]], grad_fn=<CopySlices>)

Training goes here!

Here you can train the model, interacting with it as any other Rockpool TorchModule.
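
For example, a minimal training loop might look like the sketch below. The input, target tensor, loss and optimiser settings are hypothetical placeholders, and the sketch assumes the .astorch() convenience for passing Rockpool parameters to a Torch optimiser:

# - A minimal training-loop sketch (hypothetical input, target and loss)
optimizer = torch.optim.Adam(tnet.parameters().astorch(), lr=1e-3)
target = torch.zeros(1, 10, 3)  # hypothetical target output

for epoch in range(10):
    optimizer.zero_grad()
    out, ns, rd = tnet(torch.ones(1, 10, 3))
    loss = torch.nn.functional.mse_loss(out, target)
    loss.backward()
    optimizer.step()

Note that the gradient updates are applied to the un-transformed floating-point parameters held by the wrapped modules, as described above.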

Once the model is trained, you might want to access the transformed parameters. At this point you have two options:

  1. You can apply the transformation manually so that the stored parameters are updated, using the helper function apply_T(). This will “burn in” the transformation, storing the transformed values as the “real” parameters within each module:

[9]:
ttnet = tt.apply_T(tnet, inplace=True)
ttnet
[9]:
TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

If we now examine the parameters, we will see the low-resolution quantized versions (still stored as floating-point numbers – this transformation did not force the parameters to be integers).

[10]:
ttnet.parameters("weights")
[10]:
{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-1.4088, -0.4696, -1.4088, -0.4696,  0.4696],
           [ 1.4088,  1.4088,  0.4696, -1.4088,  1.4088],
           [-1.4088, -0.4696, -0.4696, -1.4088,  1.4088]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[ 0.3583, -0.3583, -0.3583, -1.0748, -1.0748],
            [-1.0748, -0.3583,  0.3583, -0.3583, -0.3583],
            [ 0.3583, -0.3583, -0.3583,  0.3583, -0.3583],
            [ 0.3583, -0.3583, -0.3583,  0.3583,  0.3583],
            [ 1.0748, -1.0748,  1.0748, -0.3583,  0.3583]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[-1.0943,  0.3648, -1.0943,  1.0943,  0.3648],
            [ 0.3648, -1.0943, -0.3648, -1.0943,  0.3648],
            [ 0.3648,  0.3648,  0.3648,  0.3648, -0.3648],
            [ 0.3648, -1.0943, -0.3648, -0.3648,  0.3648],
            [-0.3648,  1.0943,  0.3648,  0.3648, -0.3648]], requires_grad=True)}}},
 '3_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-1.0423,  1.0423, -0.3474],
           [ 0.3474, -0.3474, -0.3474],
           [ 1.0423,  0.3474,  0.3474],
           [-0.3474,  0.3474,  1.0423],
           [ 1.0423,  0.3474, -1.0423]], requires_grad=True)}}}

You can now convert the network back to the original “unpatched” structure with the helper function remove_T_net().

[11]:
unpatched_net = tt.remove_T_net(ttnet, inplace=True)
unpatched_net
[11]:
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Compare this with the original network above.

  2. The second option is to “unpatch” the network with remove_T_net() and use post-training quantization with whatever method you prefer. This might be preferable if you have included “destructive” transformations such as dropout().

How to: Quantize to round numbers

We might want to quantize to integer levels, for example when targeting processors that use integer logic and representations for parameters (such as Xylo). This is possible with stochastic_rounding().

The cell below shows you how to use stochastic_rounding() to target signed integer parameter values. By default, stochastic_rounding() makes sure that zero in the input space maps to a zero in the output space.

[12]:
w = torch.rand((5, 5)) - 0.5

num_bits = 4

tt.stochastic_rounding(
    w,
    output_range=[-(2 ** (num_bits - 1)) + 1, 2 ** (num_bits - 1)],
    num_levels=2**num_bits,
)
[12]:
tensor([[ 1.,  1.,  2., -3., -7.],
        [-5., -2., -5.,  7., -1.],
        [-2.,  4., -6.,  4., -5.],
        [-7.,  7.,  6., -1., -6.],
        [-2.,  4.,  4., -5.,  2.]])

Activity transformations

There is a similar pipeline available for activity transformations. This can be used to transform the output of modules in the forward pass, without modifying the module code.

Let’s begin again with a simple SNN architecture:

[13]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net
[13]:
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

We need to build a configuration to patch the network with. We can conveniently specify which modules to transform according to the module class. Here we’ll perform rounding of output activations to 8-bit signed integers, using the function deterministic_rounding().

[14]:
# - Build a null configuration tree, which can be manipulated directly
tt.make_act_T_config(net)

# - Specify a transformation function as a lambda
T_fn = lambda p: tt.deterministic_rounding(
    p, output_range=[-128, 127], num_levels=2**8
)

# - Conveniently build a configuration tree by selecting a module class
tconf = tt.make_act_T_config(net, T_fn, LinearTorch)
tconf
[14]:
{'': None,
 '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '1_LIFTorch': {'': None},
 '2_TorchResidual': {'': None,
  '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'': None}},
 '3_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '4_LIFTorch': {'': None}}

Now we patch the network, analogously to the parameter transformation above, using the helper function make_act_T_network():

[15]:
# - Make a transformed network by patching with the configuration
tnet = tt.make_act_T_network(net, tconf)
tnet
[15]:
TorchSequential  with shape (3, 3) {
    ActWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        ActWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    ActWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Again, the network has been patched (this time with ActWrapper modules), each of which handles the transformations for a single wrapped module.

Now we evolve the module as usual, and check the outputs of the LinearTorch layers:

[16]:
# - We evolve the module as usual
out, ns, rd = tnet(torch.ones(1, 10, 3), record=True)
[17]:
# - Examine the recorded outputs from the network; the LinearTorch layers have quantised output
rd
[17]:
{'0_LinearTorch': {},
 '0_LinearTorch_output': tensor([[[  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.]]], requires_grad=True),
 '1_LIFTorch': {'vmem': tensor([[[-1.9025e+00,  6.8360e-01, -9.9879e+01, -1.2176e+02,  1.2213e-01],
           [-5.5218e+00,  8.6280e-01, -2.8989e+02, -3.5340e+02,  5.2538e-01],
           [-1.0686e+01,  5.6142e-01, -5.6102e+02, -6.8391e+02,  4.0171e-01],
           [-1.7236e+01,  9.4843e-01, -9.0489e+02, -1.1031e+03,  2.1217e-02],
           [-2.5024e+01,  2.7499e-01, -1.3138e+03, -1.6015e+03,  6.7770e-01],
           [-3.3914e+01,  1.0071e-02, -1.7805e+03, -2.1705e+03,  6.3689e-01],
           [-4.3779e+01,  6.7685e-01, -2.2984e+03, -2.8019e+03,  2.8247e-01],
           [-5.4504e+01,  8.9288e-01, -2.8615e+03, -3.4883e+03,  1.1093e-02],
           [-6.5982e+01,  4.5712e-01, -3.4640e+03, -4.2228e+03,  2.3026e-01],
           [-7.8112e+01,  2.8668e-01, -4.1009e+03, -4.9992e+03,  3.5631e-01]]],
         grad_fn=<CopySlices>),
  'isyn': tensor([[[[  -1.9025],
            [  64.6836],
            [ -99.8791],
            [-121.7574],
            [  17.1221]],

           [[  -3.7121],
            [ 126.2125],
            [-194.8870],
            [-237.5766],
            [  33.4092]],

           [[  -5.4335],
            [ 184.7407],
            [-285.2614],
            [-347.7472],
            [  48.9020]],

           [[  -7.0710],
            [ 240.4144],
            [-371.2281],
            [-452.5447],
            [  63.6391]],

           [[  -8.6286],
            [ 293.3728],
            [-453.0022],
            [-552.2313],
            [  77.6575]],

           [[ -10.1103],
            [ 343.7485],
            [-530.7881],
            [-647.0560],
            [  90.9922]],

           [[ -11.5196],
            [ 391.6673],
            [-604.7804],
            [-737.2561],
            [ 103.6766]],

           [[ -12.8603],
            [ 437.2491],
            [-675.1640],
            [-823.0571],
            [ 115.7424]],

           [[ -14.1355],
            [ 480.6078],
            [-742.1150],
            [-904.6735],
            [ 127.2197]],

           [[ -15.3486],
            [ 521.8519],
            [-805.8007],
            [-982.3095],
            [ 138.1373]]]], grad_fn=<CopySlices>),
  'irec': tensor([[[[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.],
            [0.],
            [0.]]]]),
  'spikes': tensor([[[  0.,  64.,   0.,   0.,  17.],
           [  0., 126.,   0.,   0.,  33.],
           [  0., 185.,   0.,   0.,  49.],
           [  0., 240.,   0.,   0.,  64.],
           [  0., 294.,   0.,   0.,  77.],
           [  0., 344.,   0.,   0.,  91.],
           [  0., 391.,   0.,   0., 104.],
           [  0., 437.,   0.,   0., 116.],
           [  0., 481.,   0.,   0., 127.],
           [  0., 522.,   0.,   0., 138.]]], grad_fn=<CopySlices>)},
 '1_LIFTorch_output': tensor([[[  0.,  64.,   0.,   0.,  17.],
          [  0., 126.,   0.,   0.,  33.],
          [  0., 185.,   0.,   0.,  49.],
          [  0., 240.,   0.,   0.,  64.],
          [  0., 294.,   0.,   0.,  77.],
          [  0., 344.,   0.,   0.,  91.],
          [  0., 391.,   0.,   0., 104.],
          [  0., 437.,   0.,   0., 116.],
          [  0., 481.,   0.,   0., 127.],
          [  0., 522.,   0.,   0., 138.]]], requires_grad=True),
 '2_TorchResidual': {'0_LinearTorch': {},
  '0_LinearTorch_output': tensor([[[ 15.,   8.,  -9.,   5.,   3.],
           [ 30.,  15., -17.,  10.,   6.],
           [ 45.,  23., -24.,  15.,  10.],
           [ 58.,  30., -31.,  19.,  13.],
           [ 71.,  37., -39.,  24.,  16.],
           [ 84.,  43., -45.,  28.,  18.],
           [ 95.,  49., -51.,  32.,  21.],
           [106.,  55., -57.,  35.,  23.],
           [117.,  60., -63.,  39.,  26.],
           [127.,  65., -68.,  43.,  28.]]], requires_grad=True),
  '1_LIFTorch': {'vmem': tensor([[[ 2.6844e-01,  6.0984e-01, -8.5611e+00,  7.5615e-01,  8.5369e-01],
            [ 3.2919e-01,  4.6384e-01, -2.6536e+01,  9.9693e-01,  1.4157e-01],
            [ 3.7964e-01,  3.1346e-01, -4.6506e+01,  3.3242e-01,  7.7141e-01],
            [ 5.3272e-01,  5.6500e-01, -6.4679e+01,  1.0973e-01, -3.1371e+00],
            [ 8.4690e-01,  9.1185e-01, -7.3795e+01,  8.6328e-01, -1.5641e+01],
            [ 7.7775e-01,  8.4974e-01, -6.3736e+01,  1.4055e-01, -3.8865e+01],
            [ 1.8718e-01,  4.4014e-01, -1.6837e+01,  2.6741e-02, -7.1991e+01],
            [ 8.7653e-01, -2.1199e+00,  7.3981e-01,  2.7437e-01, -1.0837e+02],
            [ 5.1117e-02, -1.1532e+02,  3.3788e-01,  2.0024e-01, -2.1593e+02],
            [ 4.4446e-01, -4.9026e+02,  1.3531e-01,  3.2864e-01, -5.3796e+02]]],
          grad_fn=<CopySlices>),
   'isyn': tensor([[[[  14.2684],
             [   7.6098],
             [  -8.5611],
             [   4.7561],
             [   2.8537]],

            [[  29.0738],
             [  15.8837],
             [ -18.3921],
             [  16.2777],
             [   4.3295]],

            [[  47.0665],
             [  24.8722],
             [ -21.2643],
             [  27.3841],
             [   1.6367]],

            [[  60.1716],
             [  31.2668],
             [ -20.4412],
             [  42.7935],
             [  -3.8709]],

            [[  78.3402],
             [  36.3744],
             [ -12.2705],
             [  56.7589],
             [ -12.6570]],

            [[  97.9722],
             [  35.9824],
             [   6.4595],
             [  72.3194],
             [ -23.9871]],

            [[ 122.4474],
             [  25.6318],
             [  43.7905],
             [  87.8930],
             [ -35.0216]],

            [[ 155.6985],
             [  -2.5386],
             [ 106.7561],
             [ 108.2489],
             [ -39.8943]],

            [[ 260.2173],
             [-113.2995],
             [ 253.6342],
             [ 148.9393],
             [-112.8409]],

            [[ 408.3958],
             [-380.5690],
             [ 546.8139],
             [ 239.1382],
             [-332.5651]]]], grad_fn=<CopySlices>),
   'irec': tensor([[[[   0.0000],
             [   0.0000],
             [   0.0000],
             [   0.0000],
             [   0.0000]],

            [[ -13.7040],
             [  -5.9117],
             [   6.2260],
             [   2.3561],
             [  -4.3022]],

            [[ -24.5942],
             [ -12.7363],
             [  20.0376],
             [  -2.4895],
             [ -12.6089]],

            [[ -41.8099],
             [ -22.0023],
             [  30.7751],
             [  -1.3965],
             [ -18.7061]],

            [[ -48.8148],
             [ -30.0275],
             [  46.5415],
             [  -7.1246],
             [ -25.4350]],

            [[ -59.3449],
             [ -41.5472],
             [  64.0612],
             [  -8.7316],
             [ -30.5599]],

            [[ -64.2468],
             [ -58.0363],
             [  90.5762],
             [ -11.9200],
             [ -33.8301]],

            [[ -64.7661],
             [ -83.3006],
             [ 125.4391],
             [  -9.0941],
             [ -29.9181]],

            [[   0.8605],
             [-176.5699],
             [ 222.8821],
             [   9.3266],
             [-104.7321]],

            [[  42.1174],
             [-351.7816],
             [ 389.2155],
             [  59.4598],
             [-264.7752]]]], grad_fn=<CopySlices>),
   'spikes': tensor([[[ 14.,   7.,   0.,   4.,   2.],
            [ 29.,  16.,   0.,  16.,   5.],
            [ 47.,  25.,   0.,  28.,   1.],
            [ 60.,  31.,   0.,  43.,   0.],
            [ 78.,  36.,   0.,  56.,   0.],
            [ 98.,  36.,   0.,  73.,   0.],
            [123.,  26.,   0.,  88.,   0.],
            [155.,   0.,  90., 108.,   0.],
            [261.,   0., 254., 149.,   0.],
            [408.,   0., 547., 239.,   0.]]], grad_fn=<CopySlices>)},
  '1_LIFTorch_output': tensor([[[ 14.,   7.,   0.,   4.,   2.],
           [ 29.,  16.,   0.,  16.,   5.],
           [ 47.,  25.,   0.,  28.,   1.],
           [ 60.,  31.,   0.,  43.,   0.],
           [ 78.,  36.,   0.,  56.,   0.],
           [ 98.,  36.,   0.,  73.,   0.],
           [123.,  26.,   0.,  88.,   0.],
           [155.,   0.,  90., 108.,   0.],
           [261.,   0., 254., 149.,   0.],
           [408.,   0., 547., 239.,   0.]]], requires_grad=True)},
 '2_TorchResidual_output': tensor([[[ 14.,  71.,   0.,   4.,  19.],
          [ 29., 142.,   0.,  16.,  38.],
          [ 47., 210.,   0.,  28.,  50.],
          [ 60., 271.,   0.,  43.,  64.],
          [ 78., 330.,   0.,  56.,  77.],
          [ 98., 380.,   0.,  73.,  91.],
          [123., 417.,   0.,  88., 104.],
          [155., 437.,  90., 108., 116.],
          [261., 481., 254., 149., 127.],
          [408., 522., 547., 239., 138.]]], requires_grad=True),
 '3_LinearTorch': {},
 '3_LinearTorch_output': tensor([[[   8.,   -5.,   11.],
          [  18.,   -9.,   22.],
          [  26.,  -12.,   31.],
          [  35.,  -14.,   40.],
          [  43.,  -17.,   48.],
          [  49.,  -19.,   55.],
          [  52.,  -20.,   60.],
          [  51.,  -36.,   73.],
          [  45.,  -65.,   95.],
          [  37., -111.,  127.]]], requires_grad=True),
 '4_LIFTorch': {'vmem': tensor([[[ 6.0984e-01, -4.7561e+00,  4.6352e-01],
           [ 9.4092e-01, -1.7609e+01,  3.2118e-01],
           [ 7.9974e-01, -4.0612e+01,  1.6784e-01],
           [ 6.2214e-01, -7.4647e+01,  2.0042e-01],
           [ 5.0996e-01, -1.2144e+02,  3.0400e-01],
           [ 3.6009e-01, -1.8156e+02,  1.3060e-01],
           [ 3.0579e-02, -2.5455e+02,  2.1970e-01],
           [ 5.3970e-01, -3.5423e+02,  5.7108e-01],
           [ 2.4579e-01, -5.0542e+02,  2.5729e-01],
           [ 5.5481e-02, -7.4660e+02,  2.6428e-02]]], grad_fn=<CopySlices>),
  'isyn': tensor([[[[   7.6098],
            [  -4.7561],
            [  10.4635]],

           [[  24.3608],
            [ -13.0853],
            [  30.8803]],

           [[  47.9047],
            [ -23.8618],
            [  58.8623]],

           [[  78.8614],
            [ -36.0153],
            [  94.0408]],

           [[ 115.9182],
            [ -50.4297],
            [ 135.1134]],

           [[ 156.8750],
            [ -66.0436],
            [ 180.8414]],

           [[ 198.6880],
            [ -81.8472],
            [ 229.0955]],

           [[ 237.5106],
            [-112.0997],
            [ 287.3621]],

           [[ 268.7324],
            [-168.4624],
            [ 363.7141]],

           [[ 290.8217],
            [-265.8329],
            [ 466.7817]]]], grad_fn=<CopySlices>),
  'irec': tensor([[[[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]],

           [[0.],
            [0.],
            [0.]]]]),
  'spikes': tensor([[[  7.,   0.,  10.],
           [ 24.,   0.,  31.],
           [ 48.,   0.,  59.],
           [ 79.,   0.,  94.],
           [116.,   0., 135.],
           [157.,   0., 181.],
           [199.,   0., 229.],
           [237.,   0., 287.],
           [269.,   0., 364.],
           [291.,   0., 467.]]], grad_fn=<CopySlices>)},
 '4_LIFTorch_output': tensor([[[  7.,   0.,  10.],
          [ 24.,   0.,  31.],
          [ 48.,   0.,  59.],
          [ 79.,   0.,  94.],
          [116.,   0., 135.],
          [157.,   0., 181.],
          [199.,   0., 229.],
          [237.,   0., 287.],
          [269.,   0., 364.],
          [291.,   0., 467.]]], requires_grad=True)}

As expected, the outputs of the LinearTorch layers are now signed 8-bit integer values, although still held in a floating-point representation.

Decay transformations

When training decays, the decay parameter of LIF neurons \(\exp{(-dt/\tau)}\) can be quantized to match the way that decay is implemented in Xylo:

bitshift subtraction: \(V_{mem} \rightarrow V_{mem} \cdot (1 - \frac{1}{2^N})\)
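
As a quick numerical illustration of the bitshift decay, the sketch below compares the exact exponential decay factor with the bitshift approximation, for hypothetical values of dt, τ and N:

# - Compare exponential decay with bitshift decay (illustrative values only)
import torch

dt, tau = 1e-3, 20e-3                       # hypothetical time-step and time constant
alpha = torch.exp(torch.tensor(-dt / tau))  # exact per-step decay factor
N = 4                                       # hypothetical bitshift value
alpha_bitshift = 1 - 1 / 2**N               # decay applied as V -= V / 2**N

print(float(alpha), alpha_bitshift)         # ~0.9512 vs. 0.9375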

[20]:
# - Build a network to use
# activate the decay training for the last layer
net_decay = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3, leak_mode="decays"),
)
net_decay
[20]:
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}
[21]:
tconfig_decay = tt.make_param_T_config(net_decay, lambda p: tt.t_decay(p), "decays")
print(tconfig_decay["4_LIFTorch"])
{'alpha': <function <lambda> at 0x7f8ed3137ca0>, 'beta': <function <lambda> at 0x7f8ed3137ca0>}
[22]:
t_net_decay = tt.make_param_T_network(net_decay, tconfig_decay)
print(t_net_decay)
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    TWrapper '4_LIFTorch' with shape (3, 3) {
        LIFTorch '_mod' with shape (3, 3)
    }
}
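
The decay-transformed network can then be evolved and trained in the same way as the parameter-transformed networks above; for example:

# - Evolve the decay-transformed network as before
out, ns, rd = t_net_decay(torch.ones(1, 10, 3))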

Building a network with bitshift decays

When the LIF neurons are built with leak_mode="bitshifts" (as below), the membrane and synaptic decays are applied directly as a bitshift subtraction. For quantization it is enough to round them.

[25]:
# - Build a network to use
# activate the decay training for the last layer
net_bitshift = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3, leak_mode="bitshifts"),
)
net_bitshift
[25]:
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}
[26]:
tconfig_bitshift = tt.make_param_T_config(
    net_bitshift, lambda p: tt.round_passthrough(p), "bitshifts"
)
print(tconfig_bitshift["4_LIFTorch"])
{'dash_mem': <function <lambda> at 0x7f8ed3063700>, 'dash_syn': <function <lambda> at 0x7f8ed3063700>}
[27]:
t_net_bitshift = tt.make_param_T_network(net_bitshift, tconfig_bitshift)
print(t_net_bitshift)
TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    TWrapper '4_LIFTorch' with shape (3, 3) {
        LIFTorch '_mod' with shape (3, 3)
    }
}
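
Once training is complete, the rounded bitshift parameters can be burned in and the wrappers removed, using the same helpers as for the weight transformations above (a minimal sketch):

# - Burn in the rounded bitshift parameters and unpatch the network
t_net_bitshift = tt.apply_T(t_net_bitshift, inplace=True)
deploy_net = tt.remove_T_net(t_net_bitshift, inplace=True)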