"""
Defines the parameter and activation transformation-in-training pipeline for :py:class:`.TorchModule` networks
See Also:
:ref:`/advanced/QuantTorch.ipynb`
Examples:
Construct a network, and patch it to round each weight parameter:
>>> net = Sequential(...)
>>> T_fn = lambda p: stochastic_rounding(p, num_levels = 2**num_bits)
>>> T_config = make_param_T_config(net, T_fn, 'weights')
>>> T_net = make_param_T_network(net, T_config)
Train here. To burn-in and remove the transformations:
>>> burned_in_net = apply_T(T_net)
>>> unpatched_net = remove_T_net(burned_in_net)
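To also transform module activations during training (a sketch using the activation-transform helpers defined below; here outputs are quantised to ``2**8`` levels):
>>> act_fn = lambda x: deterministic_rounding(x, num_levels = 2**8)
>>> act_T_config = make_act_T_config(net, act_fn)
>>> act_T_net = make_act_T_network(net, act_T_config)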
"""
from rockpool.utilities.backend_management import torch_version_satisfied
if not torch_version_satisfied(1, 12):
raise ModuleNotFoundError(
"torch version 1.12.0 or greater is required. The `torch_transform` package is not available."
)
import torch
from rockpool.nn.modules.module import Module, ModuleBase
from rockpool.nn.modules.torch.torch_module import TorchModule
from rockpool.typehints import Tensor, P_Callable, Tree
from typing import Optional, Tuple, List, Callable, Dict
from rockpool.graph import GraphModuleBase
import copy
import rockpool.utilities.tree_utils as tu
__all__ = [
"stochastic_rounding",
"stochastic_channel_rounding",
"deterministic_rounding",
"dropout",
"make_param_T_config",
"make_param_T_network",
"make_backward_passthrough",
"make_act_T_config",
"make_act_T_network",
"int_quant",
"t_decay",
]
def make_backward_passthrough(function: Callable) -> Callable:
"""
Wrap a function to pass the gradient directly through in the backward pass
Args:
function (Callable): A function to wrap
Returns:
Callable: The wrapped function, which applies ``function`` in the forward pass and passes gradients straight through in the backward pass
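Examples:
The wrapped function is applied in the forward pass, while the backward pass skips it entirely (an illustrative check, using the ``floor_passthrough`` wrapper defined below):
>>> x = torch.tensor([1.7], requires_grad = True)
>>> y = floor_passthrough(x)
>>> y.backward()
>>> x.grad
tensor([1.])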
"""
class Wrapper(torch.autograd.Function):
"""A torch.autograd.Function that wraps a function with a pass-through gradient."""
@staticmethod
def forward(ctx, x):
# - Save the input and apply the wrapped function in the forward pass
ctx.save_for_backward(x)
return function(x)
@staticmethod
def backward(ctx, grad_output):
# - Pass the output gradient straight through to the input
grad_x = grad_output if ctx.needs_input_grad[0] else None
return grad_x
return Wrapper.apply
# - Make a passthrough version of the floor function
floor_passthrough = make_backward_passthrough(torch.floor)
round_passthrough = make_backward_passthrough(torch.round)
def int_quant(
value: Tensor,
maintain_zero: bool = True,
map_to_max: bool = False,
n_bits: int = 8,
max_n_bits: int = 8,
):
"""
Transform a tensor to a quantized space of integer values, with the range defined by ``n_bits``
Examples:
>>> int_quant(torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1]), n_bits = 3, map_to_max = False)
tensor([-3., -2., -1., 0., 1., 2., 3.])
>>> int_quant(torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1]), n_bits = 3, map_to_max = True)
tensor([-127., -64., -32., 0., 32., 64., 127.])
Args:
value (torch.Tensor): A Tensor of values to be quantized
maintain_zero (bool): If ``True``, ensure that input values of zero map to zero in the output space (Default: ``True``). If ``False``, the output range may shift zero w.r.t. the input range.
map_to_max (bool): If ``True``, map the quantized integer values onto the extremes of the ``max_n_bits`` bit-depth (Default: ``False``)
n_bits (int): The bit-depth defining the maximum integer in the quantized tensor (Default: ``8``)
max_n_bits (int): The maximum permitted bit-depth, used when ``map_to_max`` is ``True`` (Default: ``8``)
Returns:
torch.Tensor: The quantized values
"""
if not maintain_zero:
l = value.max() - value.min()
value = value - (l / 2)
max_value = torch.max(torch.abs(value))
max_value_quant = 2 ** (n_bits - 1) - 1
if max_value != 0:
scale = max_value_quant / max_value
else:
scale = 1
if map_to_max:
max_range = 2 ** (max_n_bits - 1) - 1
scale *= max_range / max_value_quant
q_value = round_passthrough(scale * value)
return q_value
def stochastic_rounding(
value: Tensor,
input_range: Optional[List] = None,
output_range: Optional[List] = None,
num_levels: int = 2**8,
maintain_zero: bool = True,
):
"""
Perform floating-point stochastic rounding on a tensor, with detailed control over quantisation levels
Stochastic rounding randomly pushes values up or down probabilistically, depending on the original value.
Values will round with greater probability to their nearest quantised level, and with lower probability to their next-nearest quantised level.
For example, if we are rounding to integers, then a value of 0.1 will round down to 0.0 with 90% probability; it will round to 1.0 with 10% probability.
If we are rounding to arbitrary floating point levels, then the same logic holds, but the quantised output values will not be round numbers.
:py:func:`stochastic_rounding` permits the input space to be re-scaled during rounding to an output space, which will be quantised over a specified number of quantisation levels (``2**8`` by default).
By default, the input and output space are defined to be the full input range from minimum to maximum.
:py:func:`stochastic_rounding` permits careful handling of symmetric or asymmetric spaces.
By default, values of zero in the input space will map to zero in the output space (i.e. ``maintain_zero = True``).
In this case the output range is defined as ``max(abs(input_range)) * [-1, 1]``.
Quantisation and rounding is always to equally-spaced levels.
Examples:
>>> stochastic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., 0., 0., 1., 1.])
>>> stochastic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., -1., 0., 0., 1.])
Quantise to round integers over a defined space, changing the input scale (0..1) to (0..10).
>>> stochastic_rounding(torch.rand(10), input_range = [0., 1.], output_range = [0., 10.], num_levels = 11)
tensor([1., 9., 2., 2., 7., 1., 7., 9., 6., 1.])
Quantise to floating point levels, without changing the scale of the values.
>>> stochastic_rounding(torch.rand(10)-.5, num_levels = 3)
tensor([ 0.0000, 0.0000, 0.0000, -0.4701, -0.4701, 0.4701, -0.4701, -0.4701, 0.0000, 0.4701])
>>> stochastic_rounding(torch.rand(10)-.5, num_levels = 3)
tensor([ 0.0000, 0.0000, -0.4316, 0.0000, 0.0000, 0.4316, -0.4316, 0.0000, 0.4316, 0.4316])
Note that the scale is defined by the observed range of the random values, in this case.
Args:
value (torch.Tensor): A Tensor of values to round
input_range (Optional[List]): If defined, a specific input range to use (``[min_value, max_value]``), as floating point numbers. If ``None`` (default), use the range of input values to define the input range.
output_range (Optional[List]): If defined, a specific output range to use (``[min_value, max_value]``), as floating point numbers. If ``None`` (default), use the range of input values to define the input range.
num_levels (int): The number of output levels to quantise to (Default: ``2**8``)
maintain_zero (bool): If ``True``, ensure that input values of zero map to zero in the output space (Default: ``True``). If ``False``, the output range may shift zero w.r.t. the input range.
Returns:
torch.Tensor: Floating point stochastically rounded values
"""
if maintain_zero:
# - By default, input range is whatever the current data range is
max_range = torch.max(torch.abs(value))
input_range = (
[-max_range, max_range] if input_range is None else list(input_range)
)
# - By default, the output range is the same as the input range
output_range = input_range if output_range is None else list(output_range)
else:
# - By default, input range is whatever the current data range is
input_range = (
[torch.min(value), torch.max(value)]
if input_range is None
else list(input_range)
)
# - By default, the output range is the same as the input range
output_range = input_range if output_range is None else list(output_range)
# - Compute input and output quanta
input_quantum = (input_range[1] - input_range[0]) / (num_levels - 1)
output_quantum = (output_range[1] - output_range[0]) / (num_levels - 1)
# - Perform quantisation
levels = (value - input_range[0]) / input_quantum
levels_floor = floor_passthrough(levels)
levels_round = levels_floor + (
(levels - levels_floor) > torch.rand(*value.shape).to(levels_floor.device)
)
output_param = levels_round * output_quantum + output_range[0]
return output_param
def stochastic_channel_rounding(
value: Tensor,
output_range: List[float],
num_levels: int = 2**8,
maintain_zero: bool = True,
):
"""
Perform stochastic rounding of a matrix, but with the input range defined automatically for each column independently
This function performs the same quantisation approach as :py:func:`stochastic_rounding`, but considers each column of a matrix independently, i.e. per-channel.
Args:
value (torch.Tensor): A tensor of values to quantise
output_range (List[float]): Defines the destination quantisation space ``[min_value, max_value]``.
num_levels (int): The number of quantisation values to round to (Default: ``2**8``)
maintain_zero (bool): If ``True`` (default), input values of zero map to zero in the output space. If ``False``, the output space may shift w.r.t. the input space. Note that the output space must be symmetric for the zero mapping to work as expected.
Returns:
torch.Tensor: The rounded values
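Examples:
Round each column of a weight matrix independently (illustrative; results vary, since rounding is stochastic):
>>> w = torch.rand(4, 2) - 0.5
>>> w_q = stochastic_channel_rounding(w, output_range = [-1., 1.], num_levels = 3)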
"""
def round_vector(vector: Tensor):
if maintain_zero:
max_range = torch.max(torch.abs(vector))
input_range = [-max_range, max_range]
else:
input_range = [torch.min(vector), torch.max(vector)]
# - Compute input and output quanta
input_quantum = (input_range[1] - input_range[0]) / (num_levels - 1)
output_quantum = (output_range[1] - output_range[0]) / (num_levels - 1)
# - Perform quantisation
levels = (vector - input_range[0]) / input_quantum
levels_floor = floor_passthrough(levels)
levels_round = levels_floor + (
(levels - levels_floor) > torch.rand(*vector.shape).to(levels_floor.device)
)
output_param = levels_round * output_quantum + output_range[0]
return output_param
for i in range(value.shape[1]):
value[:, i] = round_vector(value[:, i])
return value
def deterministic_rounding(
value: Tensor,
input_range: Optional[List] = None,
output_range: Optional[List] = None,
num_levels: int = 2**8,
maintain_zero: bool = True,
):
"""
Quantise values by shifting them to the closest quantisation level
This is a floating-point equivalent to standard integer rounding (e.g. using ``torch.round()``).
:py:func:`deterministic_rounding` provides fine control over input and output spaces, as well as numbers of levels to quantise to, and can round to floating point levels instead of round numbers.
:py:func:`deterministic_rounding` always leaves values as floating point.
For example, if we are rounding to integers, then a value of ``0.1`` will round down to ``0.``.
A value of ``0.5`` will round up to ``1.``.
A value of ``0.9`` will round up to ``1.``.
If we are rounding to arbitrary floating point levels, then the same logic holds, but the quantised output values will not be round numbers, but will be the nearest floating point quantisation level.
:py:func:`deterministic_rounding` permits the input space to be re-scaled during rounding to an output space, which will be quantised over a specified number of quantisation levels (``2**8`` by default).
By default, the input and output space are defined to be the full input range from minimum to maximum.
:py:func:`deterministic_rounding` permits careful handling of symmetric or asymmetric spaces.
By default, values of zero in the input space will map to zero in the output space (i.e. ``maintain_zero = True``).
In this case the output range is defined as ``max(abs(input_range)) * [-1, 1]``.
Quantisation and rounding is always to equally-spaced levels.
Examples:
>>> deterministic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., -1., 0., 1., 1.])
Round to integer values (-10..10)
>>> deterministic_rounding(torch.rand(10)-.5, output_range=[-10., 10.], num_levels = 21)
tensor([ 10., -3., 3., -10., 5., 0., 9., -5., 7., 8.])
Args:
value (torch.Tensor): A Tensor of values to round
input_range (Optional[List]): If defined, a specific input range to use (``[min_value, max_value]``), as floating point numbers. If ``None`` (default), use the range of input values to define the input range.
output_range (Optional[List]): If defined, a specific output range to use (``[min_value, max_value]``), as floating point numbers. If ``None`` (default), use the range of input values to define the input range.
num_levels (int): The number of output levels to quantise to (Default: ``2**8``)
maintain_zero (bool): If ``True``, ensure that input values of zero map to zero in the output space (Default: ``True``). If ``False``, the output range may shift zero w.r.t. the input range.
Returns:
torch.Tensor: Floating point rounded values
"""
if maintain_zero:
# - By default, input range is whatever the current data range is
max_range = torch.max(torch.abs(value))
input_range = (
[-max_range, max_range] if input_range is None else list(input_range)
)
# - By default, the output range is the same as the input range
output_range = input_range if output_range is None else list(output_range)
else:
# - By default, input range is whatever the current data range is
input_range = (
[torch.min(value), torch.max(value)]
if input_range is None
else list(input_range)
)
# - By default, the output range is the same as the input range
output_range = input_range if output_range is None else list(output_range)
# - Compute input and output quanta
input_quantum = (input_range[1] - input_range[0]) / (num_levels - 1)
output_quantum = (output_range[1] - output_range[0]) / (num_levels - 1)
# - Perform quantisation
levels_round = round_passthrough((value - input_range[0]) / input_quantum)
output_param = levels_round * output_quantum + output_range[0]
return output_param
def dropout(param: Tensor, dropout_prob: float = 0.5):
"""
Randomly set values of a tensor to ``0.``, with a defined probability
Dropout is used to improve the robustness of a network, by reducing the dependency of a network on any given parameter value.
This is accomplished by randomly setting parameters (usually weights) to zero during training, such that the parameters are ignored.
For ``dropout_prob = 0.8``, each parameter is randomly set to zero with 80% probability.
Examples:
>>> dropout(torch.ones(10))
tensor([0., 0., 0., 1., 0., 1., 1., 1., 1., 0.])
>>> dropout(torch.ones(10), dropout_prob = 0.8)
tensor([1., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
Args:
param (torch.Tensor): A tensor of values to dropout
dropout_prob (float): The probability of zeroing each parameter value (Default: ``0.5``, i.e. 50%)
Returns:
torch.Tensor: The tensor of values, with elements dropped out probabilistically
"""
mask = torch.rand(param.shape, device=param.device) > dropout_prob
return param * mask
class TWrapper(TorchModule):
"""
A wrapper for a Rockpool TorchModule, implementing a parameter transformation in the forward pass
This module is not designed to be user-facing; you should probably use the helper functions :py:func:`make_param_T_config` and :py:func:`make_param_T_network` to patch a Rockpool network. This will insert :py:class:`TWrapper` modules into the network architecture as required.
See Also:
:ref:`/advanced/QuantTorch.ipynb`
"""
def __init__(
self, mod: TorchModule, T_config: Optional[Tree] = None, *args, **kwargs
):
"""
Initialise a parameter transformer wrapper module
``mod`` is a Rockpool module with some set of attributes. ``T_config`` is a dictionary, with keys optionally matching the attributes of ``mod``. Each value must be a callable ``T_Fn(a) -> a`` which can transform the associated attribute ``a``.
A :py:class:`.TWrapper` module will be created, with ``mod`` as a sub-module. The :py:class:`.TWrapper` will apply the specified transformations to all the attributes of ``mod`` at the beginning of the forward-pass of evolution, then evolve ``mod`` with the transformed attributes.
Users should use the helper functions :py:func:`.make_param_T_config` and :py:func:`.make_param_T_network`.
See Also:
:ref:`/advanced/QuantTorch.ipynb`
Args:
mod (TorchModule): A Rockpool module to apply parameter transformations to
T_config (Optional[Tree]): A nested dictionary specifying which transformations to apply to specific parameters. Each transformation function must be specified as a Callable with a key identical to a parameter of ``mod``. If ``None``, do not apply any transformation to ``mod``.
"""
# - Initialise Module superclass
super().__init__(*args, shape=mod.shape, **kwargs)
# - Copy module attributes
self._name = mod._name
self._mod = mod
self._spiking_input = mod.spiking_input
self._spiking_output = mod.spiking_output
# - Default: null transformation config
self._T_config = {} if T_config is None else T_config
def forward(self, *args, **kwargs):
# - Get transformed attributes
transformed_attrs = self._T()
if self._record and not self._has_torch_api:
kwargs["record"] = self._record
# - Call module with torch functional API
out = torch.nn.utils.stateless.functional_call(
self._mod, transformed_attrs, args, kwargs
)
if not self._has_torch_api:
self._record_dict = out[2]
elif hasattr(self._mod, "_record_dict"):
self._record_dict = self._mod._record_dict
if not self._has_torch_api:
return out[0]
else:
return out
def as_graph(self) -> GraphModuleBase:
return self._mod.as_graph()
def _T(self):
# - Transform parameters
return {
k: (
T_fn(getattr(self._mod, k))
if T_fn is not None
else getattr(self._mod, k)
)
for k, T_fn in self._T_config.items()
}
def apply_T(self, inplace: bool = True) -> TorchModule:
# - Get transformed attributes
transformed_attrs = self._T()
self._mod = self._mod.set_attributes(transformed_attrs)
return self
def make_param_T_config(
net: ModuleBase, T_fn: Callable, param_family: Optional[str] = None
) -> Tree:
"""
Helper function to build parameter transformation configuration trees
This function builds a parameter transformation nested configuration tree, based on an existing network ``net``.
You can use :py:func:`.tree_utils.tree_update` to merge two configuration trees for different parameter families.
The resulting configuration tree can be passed to :py:func:`.make_param_T_network` to patch the network ``net``.
Examples:
>>> T_config = make_param_T_config(net, lambda p: p**2, 'weights')
>>> T_net = make_param_T_network(net, T_config)
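Merging configuration trees for two parameter families (a sketch, assuming :py:func:`.tree_utils.tree_update` merges the second tree into the first; ``w_fn`` and ``tau_fn`` stand for your transformation functions):
>>> import rockpool.utilities.tree_utils as tu
>>> T_config = make_param_T_config(net, w_fn, 'weights')
>>> tu.tree_update(T_config, make_param_T_config(net, tau_fn, 'taus'))
>>> T_net = make_param_T_network(net, T_config)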
Args:
net (Module): A Rockpool network to use as a template for the transformation configuration tree
T_fn (Callable): A transformation function to apply to a parameter. Must have the signature ``T_fn(x) -> x``.
param_family (Optional[str]): An optional argument to specify a parameter family. Only parameters matching this family within ``net`` will be specified in the configuration tree.
Returns:
Tree: A parameter transformation configuration tree, to pass to :py:func:`.make_param_T_network`
"""
return tu.tree_map(net.parameters(param_family), lambda _: T_fn)
def make_param_T_network(
net: ModuleBase, T_config_tree: Tree, inplace: bool = False
) -> TorchModule:
"""
Patch a Rockpool network to apply parameter transformations in the forward pass
This helper function inserts :py:class:`.TWrapper` modules into the network tree, where required, to apply transformations to each module as defined by a configuration tree ``T_config_tree``. Use the helper function :py:func:`.make_param_T_config` to build configuration trees.
The resulting network will have analogous structure and behaviour to the original network, but the transformations will be applied before the forward pass of each module.
Network parameters will remain "held" by the original modules, un-transformed.
You can use the :py:func:`.remove_T_net` function to undo this patching behaviour, restoring the original network structure but keeping any parameter modifications (e.g. training) in place.
You can use the :py:func:`.apply_T` function to "burn in" the parameter transformation.
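Examples:
Patch a network with a weight transformation (as in the module example above; ``T_fn`` is your transformation function):
>>> T_config = make_param_T_config(net, T_fn, 'weights')
>>> T_net = make_param_T_network(net, T_config)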
Args:
net (Module): A Rockpool network to use as a template for the transformation configuration tree
T_config_tree (Tree): A nested dictionary, mimicking the structure of ``net``, specifying which parameters should be transformed and which transformation function to apply to each parameter.
inplace (bool): If ``False`` (default), a deep copy of ``net`` will be created, transformed and returned. If ``True``, the network will be patched in place.
Returns:
TorchModule: The patched network, with :py:class:`.TWrapper` modules inserted as required
"""
if not inplace:
net = copy.deepcopy(net)
if len(net.modules()) == 0:
# - Patch a single module
net = TWrapper(net, T_config_tree)
else:
# - Patch a network tree or sub-tree
# Get a list of sub-modules
_, modules = net._get_attribute_registry()
for k, mod in modules.items():
# - If there are transformations specified for this module
if k in T_config_tree:
# - Then recurse to patch the module
setattr(
net,
k,
make_param_T_network(mod[0], T_config_tree[k], inplace=inplace),
)
return net
def apply_T(T_net: TorchModule, inplace: bool = False) -> TorchModule:
"""
"Burn in" a set of parameter transformations, applying each transformation and storing the resulting transformed parameters
This function takes a transformer-patched network ``net``, and applies the pre-specified transformations to each parameter. The resulting transformed parameters are then stored within the parameters of the network.
This is a useful step **after** training, as part of extracting the transformed parameters from the trained network.
The helper function :py:func:`.remove_T_net` can be used afterwards to remove the transformer patches from the network.
Examples:
>>> T_net = make_param_T_network(net, T_config)
At this point, ``T_net.parameters()`` will return un-transformed parameters.
>>> T_net = apply_T(T_net)
Now ``T_net.parameters()`` will contain the results of applying the transformation to each parameter.
Args:
T_net (TorchModule): A transformer-patched network, obtained with :py:func:`make_param_T_network`
inplace (bool): If ``False`` (default), a deep copy of ``net`` will be created, transformed and returned. If ``True``, the network will be transformed in place.
Returns:
TorchModule: The network, with the transformations "burned in" to its parameters
"""
if not inplace:
T_net = copy.deepcopy(T_net)
if isinstance(T_net, TWrapper):
T_net = T_net.apply_T(inplace=inplace)
_, modules = T_net._get_attribute_registry()
if len(T_net.modules()) > 0:
for k, mod in modules.items():
setattr(T_net, k, apply_T(mod[0], inplace=inplace))
return T_net
def remove_T_net(T_net: TorchModule, inplace: bool = False) -> TorchModule:
"""
Un-patch a transformer-patched network
This function will iterate through a network patched using :py:func:`make_param_T_network`, and remove all patching modules. The resulting network should have the same network architecture as the original un-patched network.
Any parameter values applied within ``T_net`` will be retained in the unpatched network.
Args:
T_net (TorchModule): A transformer-patched network, obtained with :py:func:`make_param_T_network`
inplace (bool): If ``False`` (default), a deep copy of ``net`` will be created, transformed and returned. If ``True``, the network will be un-patched in place. Warning: in-place operation cannot work for single instances of :py:class:`TWrapper`
Returns:
TorchModule: A network matching ``T_net``, but with transformers removed.
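Examples:
Burn in the trained transformations, then restore the original network structure (as in the module example above):
>>> burned_in_net = apply_T(T_net)
>>> unpatched_net = remove_T_net(burned_in_net)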
"""
if not inplace:
T_net = copy.deepcopy(T_net)
if isinstance(T_net, (TWrapper, ActWrapper)):
T_net._mod._name = T_net.name
T_net = T_net._mod
else:
_, modules = T_net._get_attribute_registry()
for k, mod in modules.items():
setattr(T_net, k, remove_T_net(mod[0], inplace=inplace))
return T_net
class ActWrapper(TorchModule):
"""
A wrapper module that applies an output activity transformation after evolution
This module is not designed to be user-facing. Users should use the helper functions :py:func:`make_act_T_config` and :py:func:`make_act_T_network`. This approach will insert :py:class:`ActWrapper` modules into the network as required.
See Also:
:ref:`/advanced/QuantTorch.ipynb`
"""
def __init__(
self,
mod: TorchModule,
trans_Fn: Optional[Callable] = None,
*args,
**kwargs,
):
"""
Instantiate an ActWrapper object
``mod`` is a Rockpool module. An :py:class:`ActWrapper` will be created to wrap ``mod``. The transformation function ``trans_Fn`` will be applied to the outputs of ``mod`` during evolution.
See Also:
:ref:`/advanced/QuantTorch.ipynb`
Args:
mod (TorchModule): A module to patch
trans_Fn (Optional[Callable]): A transformation function to apply to the outputs of ``mod``. If ``None``, no transformation will be applied.
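Examples:
A minimal sketch, wrapping a single module directly (users should normally prefer :py:func:`make_act_T_network`; ``LinearTorch`` is used here purely for illustration):
>>> from rockpool.nn.modules import LinearTorch
>>> mod = LinearTorch((2, 2))
>>> wrapped = ActWrapper(mod, lambda x: deterministic_rounding(x, num_levels = 2**4))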
"""
# - Initialise superclass
super().__init__(*args, shape=mod.shape, **kwargs)
# - Record module attributes
self._name = mod._name
self._mod = mod
self._spiking_input = mod.spiking_input
self._spiking_output = mod.spiking_output
# - Record transformation function
self._trans_Fn = (lambda x: x) if trans_Fn is None else trans_Fn
def forward(self, *args, **kwargs):
out = self._mod(*args, **kwargs)
if self._mod._has_torch_api:
return self._trans_Fn(out)
else:
return self._trans_Fn(out[0])
def as_graph(self) -> GraphModuleBase:
return self._mod.as_graph()
def make_act_T_config(
net: TorchModule,
T_fn: Optional[Callable] = None,
ModuleClass: Optional[type] = None,
) -> Tree:
"""
Create an activity transformation configuration for a network
This helper function assists in defining an activity transformation configuration tree. It allows you to search a predefined network ``net`` for modules of a chosen class, and to specify an activity transformation for those modules.
``net`` is a pre-defined Rockpool network.
``T_fn`` is a `Callable` with signature ``f(x) -> x``, transforming the output ``x`` of a module.
``ModuleClass`` optionally specifies the class of module to search for in ``net``.
You can use :py:func:`.tree_utils.tree_update` to merge two configuration trees for different parameter families.
The resulting configuration tree will match the structure of ``net`` (or will be a sub-tree of ``net``, including modules of type ``ModuleClass``). You can pass this configuration tree to :py:func:`.make_act_T_network` to build an activity transformer tree.
Args:
net (TorchModule): A Rockpool network to build a configuration tree for
T_fn (Callable): A function ``f(x) -> x`` to apply as a transformation to module output. If ``None``, no transformation will be applied.
ModuleClass (Optional[type]): A :py:class:`~.modules.Module` subclass to search for. The configuration tree will include only modules matching this class.
Returns:
Tree: An activity transformer configuration tree, to pass to :py:func:`.make_act_T_network`
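Examples:
Quantise the output of every ``LIFTorch`` layer in a network (a sketch, assuming ``net`` contains :py:class:`.LIFTorch` modules):
>>> from rockpool.nn.modules import LIFTorch
>>> act_fn = lambda x: deterministic_rounding(x, num_levels = 2**8)
>>> act_T_config = make_act_T_config(net, act_fn, LIFTorch)
>>> act_T_net = make_act_T_network(net, act_T_config)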
"""
# - Define a transformation function for this module, optionally matching a Module class
if ModuleClass is not None:
act_T_config = {"": T_fn} if isinstance(net, ModuleClass) else {"": None}
else:
act_T_config = {"": T_fn}
if len(net.modules()) > 0:
for k, mod in net.modules().items():
act_T_config[k] = make_act_T_config(mod, T_fn, ModuleClass)
return act_T_config
def make_act_T_network(
net: TorchModule, act_T_config: Tree, inplace: bool = False
) -> TorchModule:
"""
Patch a Rockpool network with activity transformers
This helper function inserts :py:class:`ActWrapper` modules into a pre-defined network ``net``, to apply an activity transformation configuration ``act_T_config``.
Args:
net (TorchModule): A Rockpool network to patch
act_T_config (Tree): A configuration tree from :py:func:`make_act_T_config`
inplace (bool): If ``False`` (default), create a deep copy of ``net`` to patch. If ``True``, patch the network in place. Note that in-place operation does not work when patching a single module.
Returns:
TorchModule: The patched network
"""
if not inplace:
net = copy.deepcopy(net)
if len(net.modules()) == 0:
if "" in act_T_config and act_T_config[""] is not None:
net = ActWrapper(net, act_T_config[""])
else:
_, modules = net._get_attribute_registry()
for k, mod in modules.items():
if k in act_T_config:
setattr(net, k, make_act_T_network(mod[0], act_T_config[k]))
return net
class class_calc_q_decay:
"""
Callable class used to calculate the bitshift equivalent of a decay factor ``exp(-dt/tau)``
Args:
decay (torch.Tensor): ``alpha`` and ``beta`` parameters from the decay parameter family of LIF neurons
dt (float): Euler simulator time-step in seconds
Returns:
torch.Tensor: The bitshift-equivalent decay factors ``q_decay``
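Examples:
A decay factor that is exactly representable as a bitshift is returned unchanged (illustrative):
>>> class_calc_q_decay(dt = 1e-3)(torch.tensor([0.9375]))
tensor([0.9375])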
"""
def __init__(self, dt):
self.dt = dt
def calc_bitshift_decay(self, decay):
tau = -self.dt / torch.log(decay)
bitsh = torch.round(torch.log2(tau / self.dt)).int()
bitsh[bitsh < 0] = 0
q_alpha = 1 - (1 / (2**bitsh))
return q_alpha
def __call__(self, decay):
return self.calc_bitshift_decay(decay)
def t_decay(decay: Tensor, dt: float = 1e-3):
"""
Quantize the decay factors (``exp(-dt/tau)``) of LIF neurons: ``alpha`` and ``beta``, for ``Vmem`` and ``Isyn`` respectively
Quantization is performed by converting each decay factor to its closest bitshift-subtraction equivalent, then reconstructing the decay from that bitshift.
The transformation is wrapped with :py:func:`make_backward_passthrough`, so it is applied in the forward pass of a Torch module while gradients pass through unchanged.
Note: this transform can only be used if ``decay_training = True`` is set for at least one of the LIF modules in the network.
Examples:
Applying to a tensor:
>>> alpha = torch.rand(10)
>>> alpha
tensor([0.4156, 0.8306, 0.1557, 0.9475, 0.4532, 0.3585, 0.4014, 0.9380, 0.9426, 0.7212])
>>> tt.t_decay(alpha)
tensor([0.0000, 0.7500, 0.0000, 0.9375, 0.0000, 0.0000, 0.0000, 0.9375, 0.9375,
0.7500])
Applying to the decay parameter family of a network:
>>> net = Sequential(LinearTorch((2, 2)), LIFTorch(2, decay_training = True))
>>> T_config = tt.make_param_T_config(net, lambda p: tt.t_decay(p), 'decays')
>>> T_config
{'1_LIFTorch': {'alpha': <function <lambda> at 0x7fb32efbb280>, 'beta': <function <lambda> at 0x7fb32efbb280>}}
>>> T_net = make_param_T_network(net, T_config)
>>> T_net
TorchSequential with shape (2, 2) {
LinearTorch '0_LinearTorch' with shape (2, 2)
TWrapper '1_LIFTorch' with shape (2, 2) {
LIFTorch '_mod' with shape (2, 2) }}
Args:
decay (torch.Tensor) : alpha and beta parameters from decay parameter family of LIF neurons
dt (float) : Euler simulator time-step in seconds
Returns:
torch.Tensor: The quantized decay factors ``q_decay``
"""
fn = class_calc_q_decay(dt=dt)
decay_passthrough = make_backward_passthrough(fn)
q_decay = decay_passthrough(decay)
return q_decay