Module transform.torch_transform

Defines the parameter and activation transformation-in-training pipeline for TorchModules

Examples

Construct a network, and patch it to round each weight parameter:

>>> net = Sequential(...)
>>> T_fn = lambda p: stochastic_rounding(p, num_levels = 2**num_bits)
>>> T_config = make_param_T_config(net, T_fn, 'weights')
>>> T_net = make_param_T_network(net, T_config)

Train here. To burn-in and remove the transformations:

>>> burned_in_net = apply_T(T_net)
>>> unpatched_net = remove_T_net(burned_in_net)

Functions overview

apply_T(T_net[, inplace])

"Burn in" a set of parameter transformations, applying each transformation and storing the resulting transformed parameters

deterministic_rounding(value[, input_range, ...])

Quantise values by shifting them to the closest quantisation level

dropout(param[, dropout_prob])

Randomly set values of a tensor to 0., with a defined probability

int_quant(value[, maintain_zero, ...])

Transforms a tensor to a quantized space with a range of integer values defined by n_bits

make_act_T_config(net[, T_fn, ModuleClass])

Create an activity transformation configuration for a network

make_act_T_network(net, act_T_config[, inplace])

Patch a Rockpool network with activity transformers

make_backward_passthrough(function)

Wrap a function to pass the gradient directly through in the backward pass

make_param_T_config(net, T_fn[, param_family])

Helper function to build parameter transformation configuration trees

make_param_T_network(net, T_config_tree[, ...])

Patch a Rockpool network to apply parameter transformations in the forward pass

remove_T_net(T_net[, inplace])

Un-patch a transformer-patched network

stochastic_channel_rounding(value, output_range)

Perform stochastic rounding of a matrix, but with the input range defined automatically for each column independently

stochastic_rounding(value[, input_range, ...])

Perform floating-point stochastic rounding on a tensor, with detailed control over quantisation levels

t_decay(decay[, dt])

Quantizes the decay factors (exp(-dt/tau)) of LIF neurons (alpha for Vmem, beta for Isyn) by converting each decay to a bitshift subtraction and reconstructing the decay from it

Classes overview

ActWrapper(*args, **kwargs)

A wrapper module that applies an output activity transformation after evolution

TWrapper(*args, **kwargs)

A wrapper for a Rockpool TorchModule, implementing a parameter transformation in the forward pass

class_calc_q_decay(dt)

function used to calculate bitshift equivalent of decay (exp(-dt/tau))

Functions

transform.torch_transform.apply_T(T_net: TorchModule, inplace: bool = False) TorchModule[source]

“Burn in” a set of parameter transformations, applying each transformation and storing the resulting transformed parameters

This function takes a transformer-patched network T_net, and applies the pre-specified transformations to each parameter. The resulting transformed parameters are then stored within the parameters of the network.

This is a useful step after training, as part of extracting the transformed parameters from the trained network.

The helper function remove_T_net() can be used afterwards to remove the transformer patches from the network.

Examples

>>> T_net = make_param_T_network(net, T_config)

At this point, T_net.parameters() will return un-transformed parameters.

>>> T_net = apply_T(T_net)

Now T_net.parameters() will contain the results of applying the transformation to each parameter.

Parameters:
  • T_net (TorchModule) – A transformer-patched network, obtained with make_param_T_network()

  • inplace (bool) – If False (default), a deep copy of T_net will be created, transformed and returned. If True, the network will be transformed in place.

transform.torch_transform.deterministic_rounding(value: Tensor, input_range: List | None = None, output_range: List | None = None, num_levels: int = 256, maintain_zero: bool = True)[source]

Quantise values by shifting them to the closest quantisation level

This is a floating-point equivalent to standard integer rounding (e.g. using torch.round()). deterministic_rounding() provides fine control over input and output spaces, as well as numbers of levels to quantise to, and can round to floating point levels instead of round numbers. deterministic_rounding() always leaves values as floating point.

For example, if we are rounding to integers, then a value of 0.1 will round down to 0.0, a value of 0.5 will round up to 1.0, and a value of 0.9 will round up to 1.0.

If we are rounding to arbitrary floating point levels, then the same logic holds, but the quantised output values will not be round numbers, but will be the nearest floating point quantisation level.

deterministic_rounding() permits the input space to be re-scaled during rounding to an output space, which will be quantised over a specified number of quantisation levels (2**8 by default). By default, the input and output space are defined to be the full input range from minimum to maximum.

deterministic_rounding() permits careful handling of symmetric or asymmetric spaces. By default, values of zero in the input space will map to zero in the output space (i.e. maintain_zero = True). In this case the output range is defined as max(abs(input_range)) * [-1, 1].

Quantisation and rounding is always to equally-spaced levels.

Examples

>>> deterministic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., -1.,  0.,  1.,  1.])

Round to integer values (-10..10)

>>> deterministic_rounding(torch.rand(10)-.5, output_range=[-10., 10.], num_levels = 21)
tensor([ 10.,  -3.,   3., -10.,   5.,   0.,   9.,  -5.,   7.,   8.])
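The level arithmetic described above can be sketched in plain Python. This is only an illustration under stated assumptions, not Rockpool's implementation: the helper name `round_to_levels` is hypothetical, it works on a list of floats rather than a tensor, and it always uses a symmetric input range taken from the largest input magnitude (the maintain_zero = True case). Ties are resolved by Python's built-in round(), which rounds halves to the nearest even level.

```python
def round_to_levels(values, output_range=None, num_levels=256):
    """Round floats to equally spaced levels (hypothetical helper, not Rockpool code)."""
    # Assumption: symmetric input range, so zero sits on a level when num_levels is odd.
    # Requires at least one non-zero value, otherwise the step width would be zero.
    max_abs = max(abs(v) for v in values)
    input_range = (-max_abs, max_abs)
    if output_range is None:
        output_range = input_range

    # Width of one quantisation step in the input and output spaces
    in_quantum = (input_range[1] - input_range[0]) / (num_levels - 1)
    out_quantum = (output_range[1] - output_range[0]) / (num_levels - 1)

    rounded = []
    for v in values:
        # Index of the nearest level, counted from the bottom of the input range
        level = round((v - input_range[0]) / in_quantum)
        # Map the level index into the output space
        rounded.append(level * out_quantum + output_range[0])
    return rounded


# Three levels over [-1, 1]: the available levels are -1, 0 and 1
three_levels = round_to_levels([-1.0, -0.2, 0.0, 0.3, 1.0], num_levels=3)
print(three_levels)

# Rescale [-1, 1] onto the integers -10..10
integers = round_to_levels(
    [-1.0, -0.26, 0.0, 0.33, 1.0], output_range=(-10.0, 10.0), num_levels=21
)
print(integers)
```

The input and output step widths are computed separately, which is what allows the same level index to be read back in a rescaled output space.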

Parameters:
  • value (torch.Tensor) – A Tensor of values to round

  • input_range (Optional[List]) – If defined, a specific input range to use ([min_value, max_value]), as floating point numbers. If None (default), use the range of input values to define the input range.

  • output_range (Optional[List]) – If defined, a specific output range to use ([min_value, max_value]), as floating point numbers. If None (default), use the range of input values to define the output range.

  • num_levels (int) – The number of output levels to quantise to (Default: 2**8)

  • maintain_zero (bool) – If True, ensure that input values of zero map to zero in the output space (Default: True). If False, the output range may shift zero w.r.t. the input range.

Returns:

Floating point rounded values

Return type:

torch.Tensor

transform.torch_transform.dropout(param: Tensor, dropout_prob: float = 0.5)[source]

Randomly set values of a tensor to 0., with a defined probability

Dropout is used to improve the robustness of a network, by reducing the dependency of a network on any given parameter value. This is accomplished by randomly setting parameters (usually weights) to zero during training, such that the parameters are ignored.

For dropout_prob = 0.8, each parameter is randomly set to zero with 80% probability.

Examples

>>> dropout(torch.ones(10))
tensor([0., 0., 0., 1., 0., 1., 1., 1., 1., 0.])
>>> dropout(torch.ones(10), dropout_prob = 0.8)
tensor([1., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
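The masking behaviour can be illustrated without torch. The sketch below is a plain-Python stand-in, not the library code: `dropout_list` is a hypothetical name, it works on a list of floats, and it leaves surviving values unscaled, which matches the example outputs above where surviving ones remain 1.

```python
import random


def dropout_list(values, dropout_prob=0.5, rng=random):
    """Zero each element independently with probability dropout_prob (hypothetical helper)."""
    return [0.0 if rng.random() < dropout_prob else v for v in values]


rng = random.Random(1)  # seeded only so the sketch is repeatable
dropped = dropout_list([1.0] * 10_000, dropout_prob=0.8, rng=rng)

# Fraction of elements that were zeroed; this should be close to dropout_prob
zero_fraction = sum(1 for v in dropped if v == 0.0) / len(dropped)
print(zero_fraction)
```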
Parameters:
  • param (torch.Tensor) – A tensor of values to dropout

  • dropout_prob (float) – The probability of zeroing each parameter value (Default: 0.5, 50%)

Returns:

The tensor of values, with elements dropped out probabilistically

Return type:

torch.Tensor

transform.torch_transform.int_quant(value: Tensor, maintain_zero: bool = True, map_to_max: bool = False, n_bits: int = 8, max_n_bits: int = 8)[source]

Transforms a tensor to a quantized space with a range of integer values defined by n_bits

Examples

>>> int_quant(torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1]), n_bits = 3, map_to_max = False)
tensor([-3., -2., -1.,  0.,  1.,  2.,  3.])
>>> int_quant(torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1]), n_bits = 3, map_to_max = True)
tensor([-127.,  -64.,  -32.,    0.,   32.,   64.,  127.])
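The 3-bit example above is consistent with a symmetric scheme in which the largest input magnitude maps to 2**(n_bits - 1) - 1 and zero stays at zero. The sketch below reproduces that arithmetic in plain Python. It is an assumption-laden illustration rather than the library code: `int_quant_sketch` is a hypothetical name, and the choice to use max_n_bits as the bit depth when map_to_max is set is inferred from the documented output range of ±127.

```python
def int_quant_sketch(values, n_bits=8, map_to_max=False, max_n_bits=8):
    """Symmetric integer quantisation of a list of floats (hypothetical helper)."""
    # Assumption: the target bit depth is max_n_bits when map_to_max is set
    bits = max_n_bits if map_to_max else n_bits
    max_int = 2 ** (bits - 1) - 1

    # Scale so the largest magnitude lands on max_int; zero remains zero.
    # Requires at least one non-zero value.
    max_abs = max(abs(v) for v in values)
    scale = max_int / max_abs

    # Python's round() resolves ties to the nearest even integer
    return [float(round(v * scale)) for v in values]


values = [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0]
q_small = int_quant_sketch(values, n_bits=3)
q_full = int_quant_sketch(values, n_bits=3, map_to_max=True)
print(q_small)
print(q_full)
```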
Parameters:
  • value (torch.Tensor) – A Tensor of values to be quantized

  • maintain_zero (bool) – If True, ensure that input values of zero map to zero in the output space (Default: True). If False, the output range may shift zero w.r.t. the input range.

  • map_to_max (bool) – If True, the integer values in the quantized output are mapped to the extremes of the bit depth (Default: False)

  • n_bits (int) – Defines the maximum integer in the quantized tensor (Default: 8)

  • max_n_bits (int) – Maximum allowed bit depth (Default: 8)

Returns:

Quantized values (q_value)

Return type:

torch.Tensor

transform.torch_transform.make_act_T_config(net: TorchModule, T_fn: Callable | None = None, ModuleClass: type | None = None) Iterable | MutableMapping | Mapping[source]

Create an activity transformation configuration for a network

This helper function assists in defining an activity transformation configuration tree. It allows you to search a predefined network net to find modules of a chosen class, and specify an activity transformation for those modules.

net is a pre-defined Rockpool network.

T_fn is a Callable with signature f(x) -> x, transforming the output x of a module.

ModuleClass optionally specifies the class of module to search for in net.

You can use tree_utils.tree_update() to merge two configuration trees for different parameter families.

The resulting configuration tree will match the structure of net (or will be a sub-tree of net, including modules of type ModuleClass). You can pass this configuration tree to make_act_T_network() to build an activity transformer tree.

Parameters:
  • net (TorchModule) – A Rockpool network to build a configuration tree for

  • T_fn (Callable) – A function f(x) -> x to apply as a transformation to module output. If None, no transformation will be applied.

  • ModuleClass (Optional[type]) – A Module subclass to search for. The configuration tree will include only modules matching this class.

Returns:

An activity transformer configuration tree, to pass to make_act_T_network()

Return type:

Tree

transform.torch_transform.make_act_T_network(net: TorchModule, act_T_config: Iterable | MutableMapping | Mapping, inplace: bool = False) TorchModule[source]

Patch a Rockpool network with activity transformers

This helper function inserts ActWrapper modules into a pre-defined network net, to apply an activity transformation configuration act_T_config.

Parameters:
  • net (TorchModule) – A Rockpool network to patch

  • act_T_config (Tree) – A configuration tree from make_act_T_config()

  • inplace (bool) – If False (default), create a deep copy of net to patch. If True, patch the network in place. This in place operation does not work when patching single modules.

Returns:

The patched network

Return type:

TorchModule

transform.torch_transform.make_backward_passthrough(function: Callable) Callable[source]

Wrap a function to pass the gradient directly through in the backward pass

Parameters:

function (Callable) – A function to wrap

Returns:

The wrapped function, whose gradient is passed directly through in the backward pass

Return type:

Callable
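Passing the gradient directly through matters because transformations such as rounding have zero gradient almost everywhere, so a plain backward pass would never update the underlying parameter. The technique is commonly called a straight-through estimator. The scalar sketch below illustrates the idea with a hand-derived gradient and no autograd; the function names are hypothetical and nothing here uses Rockpool or torch.

```python
def quantise(w):
    """Forward-pass transformation: round to the nearest integer."""
    return float(round(w))


def passthrough_grad(w, x, target):
    """Gradient of (quantise(w) * x - target)**2 w.r.t. w, taking d quantise/dw as 1."""
    # The forward pass uses the quantised weight ...
    error = quantise(w) * x - target
    # ... but the backward pass treats quantise() as the identity
    return 2.0 * error * x


w, x, target = 0.2, 1.0, 3.0
learning_rate = 0.1
for _ in range(50):
    w -= learning_rate * passthrough_grad(w, x, target)

# The stored weight stays floating point; only its quantised value is used forward
print(w, quantise(w))
```

With the true gradient of round() (zero almost everywhere) the weight would stay at its initial value; with the pass-through gradient it moves until the quantised weight reproduces the target.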

transform.torch_transform.make_param_T_config(net: ModuleBase, T_fn: Callable, param_family: str | None = None) Iterable | MutableMapping | Mapping[source]

Helper function to build parameter transformation configuration trees

This function builds a parameter transformation nested configuration tree, based on an existing network net.

You can use tree_utils.tree_update() to merge two configuration trees for different parameter families.

The resulting configuration tree can be passed to make_param_T_network() to patch the network net.

Examples

>>> T_config = make_param_T_config(net, lambda p: p**2, 'weights')
>>> T_net = make_param_T_network(net, T_config)
Parameters:
  • net (Module) – A Rockpool network to use as a template for the transformation configuration tree

  • T_fn (Callable) – A transformation function to apply to a parameter. Must have the signature T_fn(x) -> x.

  • param_family (Optional[str]) – An optional argument to specify a parameter family. Only parameters matching this family within net will be specified in the configuration tree.
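A configuration tree is a nested dictionary that mirrors the network: module names map to dictionaries of parameter names, which map to transformation callables. The sketch below shows what merging two such trees for different parameter families amounts to. It is a stand-alone illustration: `merge_trees` is a hypothetical helper, not tree_utils.tree_update(), and the module and parameter names are placeholders.

```python
def merge_trees(base, update):
    """Recursively merge nested dicts, returning a new tree (hypothetical helper)."""
    merged = dict(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides have a sub-tree here, so merge them recursively
            merged[key] = merge_trees(merged[key], value)
        else:
            merged[key] = value
    return merged


def round_T(p):
    """Example transformation for a weights family."""
    return round(p)


def clip_T(p):
    """Example transformation for a decays family."""
    return min(max(p, 0.0), 1.0)


# Placeholder trees: module name -> {parameter name: transformation}
weights_config = {"0_LinearTorch": {"weight": round_T}}
decays_config = {"1_LIFTorch": {"alpha": clip_T, "beta": clip_T}}

T_config = merge_trees(weights_config, decays_config)
print(sorted(T_config))
```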

transform.torch_transform.make_param_T_network(net: ModuleBase, T_config_tree: Iterable | MutableMapping | Mapping, inplace: bool = False) TorchModule[source]

Patch a Rockpool network to apply parameter transformations in the forward pass

This helper function inserts TWrapper modules into the network tree, where required, to apply transformations to each module as defined by a configuration tree T_config_tree. Use the helper function make_param_T_config() to build configuration trees.

The resulting network will have analogous structure and behaviour to the original network, but the transformations will be applied before the forward pass of each module.

Network parameters will remain “held” by the original modules, un-transformed.

You can use the remove_T_net() function to undo this patching behaviour, restoring the original network structure but keeping any parameter modifications (e.g. training) in place.

You can use the apply_T() function to “burn in” the parameter transformation.

Parameters:
  • net (Module) – A Rockpool network to patch with parameter transformations

  • T_config_tree (Tree) – A nested dictionary, mimicking the structure of net, specifying which parameters should be transformed and which transformation function to apply to each parameter.

  • inplace (bool) – If False (default), a deep copy of net will be created, transformed and returned. If True, the network will be patched in place.

transform.torch_transform.remove_T_net(T_net: TorchModule, inplace: bool = False) TorchModule[source]

Un-patch a transformer-patched network

This function will iterate through a network patched using make_param_T_network(), and remove all patching modules. The resulting network should have the same network architecture as the original un-patched network.

Any parameter values applied within T_net will be retained in the unpatched network.

Parameters:
  • T_net (TorchModule) – A transformer-patched network, obtained with make_param_T_network()

  • inplace (bool) – If False (default), a deep copy of T_net will be created, un-patched and returned. If True, the network will be un-patched in place. Warning: in-place operation cannot work for single instances of TWrapper.

Returns:

A network matching T_net, but with transformers removed.

Return type:

TorchModule

transform.torch_transform.stochastic_channel_rounding(value: Tensor, output_range: List[float], num_levels: int = 256, maintain_zero: bool = True)[source]

Perform stochastic rounding of a matrix, but with the input range defined automatically for each column independently

This function performs the same quantisation approach as stochastic_rounding(), but considers each column of a matrix independently, i.e. per channel.

Parameters:
  • value (torch.Tensor) – A tensor of values to quantise

  • output_range (List[float]) – Defines the destination quantisation space [min_value, max_value].

  • num_levels (int) – The number of quantisation values to round to (Default: 2**8)

  • maintain_zero (bool) – If True (default), input values of zero map to zero in the output space. If False, the output space may shift w.r.t. the input space. Note that the output space must be symmetric for the zero mapping to work as expected.

Returns:

The rounded values

Return type:

torch.Tensor
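Per-channel rounding lets a column of small values use the full output range, instead of being squeezed by a larger column elsewhere in the matrix. The sketch below illustrates this on a list of rows in plain Python. It is not the library implementation: `channel_round` and `stochastic_round` are hypothetical helpers, and a symmetric per-column input range with a symmetric output range is assumed.

```python
import math
import random


def stochastic_round(x, rng):
    """Round x up with probability equal to its fractional part, otherwise down."""
    lower = math.floor(x)
    return lower + (1 if rng.random() < x - lower else 0)


def channel_round(matrix, output_range, num_levels, rng):
    """Stochastically round each column using that column's own input range."""
    out_min, out_max = output_range
    out_quantum = (out_max - out_min) / (num_levels - 1)
    n_rows, n_cols = len(matrix), len(matrix[0])
    result = [[0.0] * n_cols for _ in range(n_rows)]

    for c in range(n_cols):
        # Assumption: symmetric input range [-col_max, col_max] for each column
        col_max = max(abs(matrix[r][c]) for r in range(n_rows))
        if col_max == 0.0:
            # All-zero column: left at zero (assumes a symmetric output range)
            continue
        in_quantum = 2.0 * col_max / (num_levels - 1)
        for r in range(n_rows):
            level = stochastic_round((matrix[r][c] + col_max) / in_quantum, rng)
            result[r][c] = level * out_quantum + out_min
    return result


rng = random.Random(0)  # seeded only so the sketch is repeatable
weights = [[0.5, 8.0], [-0.25, -2.0], [0.0, 0.0]]
rounded = channel_round(weights, output_range=(-2.0, 2.0), num_levels=5, rng=rng)
print(rounded)
```

Both columns reach the top of the output range in their first row, even though their input magnitudes differ by a factor of 16.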

transform.torch_transform.stochastic_rounding(value: Tensor, input_range: List | None = None, output_range: List | None = None, num_levels: int = 256, maintain_zero: bool = True)[source]

Perform floating-point stochastic rounding on a tensor, with detailed control over quantisation levels

Stochastic rounding randomly pushes values up or down probabilistically, depending on the original value. Values will round with greater probability to their nearest quantised level, and with lower probability to their next-nearest quantised level.

For example, if we are rounding to integers, then a value of 0.1 will round down to 0.0 with 90% probability; it will round to 1.0 with 10% probability.

If we are rounding to arbitrary floating point levels, then the same logic holds, but the quantised output values will not be round numbers.

stochastic_rounding() permits the input space to be re-scaled during rounding to an output space, which will be quantised over a specified number of quantisation levels (2**8 by default). By default, the input and output space are defined to be the full input range from minimum to maximum.

stochastic_rounding() permits careful handling of symmetric or asymmetric spaces. By default, values of zero in the input space will map to zero in the output space (i.e. maintain_zero = True). In this case the output range is defined as max(abs(input_range)) * [-1, 1].

Quantisation and rounding is always to equally-spaced levels.

Examples

>>> stochastic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1.,  0.,  0.,  1.,  1.])
>>> stochastic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., -1.,  0.,  0.,  1.])

Quantise to round integers over a defined space, changing the input scale (0..1) to (0..10).

>>> stochastic_rounding(torch.rand(10), input_range = [0., 1.], output_range = [0., 10.], num_levels = 11)
tensor([1., 9., 2., 2., 7., 1., 7., 9., 6., 1.])

Quantise to floating point levels, without changing the scale of the values.

>>> stochastic_rounding(torch.rand(10)-.5, num_levels = 3)
tensor([ 0.0000,  0.0000,  0.0000, -0.4701, -0.4701,  0.4701, -0.4701, -0.4701, 0.0000,  0.4701])
>>> stochastic_rounding(torch.rand(10)-.5, num_levels = 3)
tensor([ 0.0000,  0.0000, -0.4316,  0.0000,  0.0000,  0.4316, -0.4316,  0.0000, 0.4316,  0.4316])

Note that the scale is defined by the observed range of the random values, in this case.
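The core rule, rounding up with probability equal to the fractional distance to the lower level, can be checked with a short plain-Python sketch. The helper name `stochastic_round` is hypothetical and works on a single float with integer levels. It illustrates the rule only, not the range handling of stochastic_rounding().

```python
import math
import random


def stochastic_round(x, rng):
    """Round x to a neighbouring integer, rounding up with probability frac(x)."""
    lower = math.floor(x)
    frac = x - lower
    return lower + (1 if rng.random() < frac else 0)


rng = random.Random(0)  # seeded only so the sketch is repeatable
samples = [stochastic_round(0.1, rng) for _ in range(10_000)]

# Every sample is one of the two neighbouring levels ...
print(sorted(set(samples)))
# ... and the sample mean stays close to the original value of 0.1
mean = sum(samples) / len(samples)
print(mean)
```

Because the expected value of the rounded result equals the input, rounding errors tend to average out over many parameters rather than accumulating as a bias.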

Parameters:
  • value (torch.Tensor) – A Tensor of values to round

  • input_range (Optional[List]) – If defined, a specific input range to use ([min_value, max_value]), as floating point numbers. If None (default), use the range of input values to define the input range.

  • output_range (Optional[List]) – If defined, a specific output range to use ([min_value, max_value]), as floating point numbers. If None (default), use the range of input values to define the output range.

  • num_levels (int) – The number of output levels to quantise to (Default: 2**8)

  • maintain_zero (bool) – If True, ensure that input values of zero map to zero in the output space (Default: True). If False, the output range may shift zero w.r.t. the input range.

Returns:

Floating point stochastically rounded values

Return type:

torch.Tensor

transform.torch_transform.t_decay(decay: Tensor, dt: float = 0.001)[source]

Quantizes the decay factors (exp(-dt/tau)) of LIF neurons: alpha for Vmem and beta for Isyn. Each decay is converted to a bitshift subtraction, and the decay is then reconstructed from that bitshift. The transformation is passed to make_backward_passthrough(), which applies it in the forward pass of the Torch module.

Note: this transform can be used only if decay_training = True for at least one of the LIF modules in the network.

Examples

Applying to tensors:

>>> alpha = torch.rand(10)
>>> alpha
tensor([0.4156, 0.8306, 0.1557, 0.9475, 0.4532, 0.3585, 0.4014, 0.9380, 0.9426, 0.7212])
>>> tt.t_decay(alpha)
tensor([0.0000, 0.7500, 0.0000, 0.9375, 0.0000, 0.0000, 0.0000, 0.9375, 0.9375, 0.7500])

Applying to the decay parameter family of a network:

>>> net = Sequential(LinearTorch((2, 2)), LIFTorch(2, decay_training=True))
>>> T_config = tt.make_param_T_config(net, lambda p: tt.t_decay(p), 'decays')
>>> T_config
{'1_LIFTorch': {'alpha': <function <lambda> at 0x7fb32efbb280>, 'beta': <function <lambda> at 0x7fb32efbb280>}}
>>> T_net = tt.make_param_T_network(net, T_config)
>>> T_net
TorchSequential  with shape (2, 2) {
    LinearTorch '0_LinearTorch' with shape (2, 2)
    TWrapper '1_LIFTorch' with shape (2, 2) {
        LIFTorch '_mod' with shape (2, 2) }}

Parameters:
  • decay (torch.Tensor) – alpha and beta parameters from the decay parameter family of LIF neurons

  • dt (float) – Euler simulator time-step in seconds

Returns:

q_decay, the quantized decay factors

Return type:

torch.Tensor
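One way to read "converting the decay to a bitshift subtraction" is that the decay is approximated as 1 - 2**(-bitshift) for a non-negative integer bitshift. The plain-Python sketch below follows that reading. It is an assumption rather than the library code, and `bitshift_decay` is a hypothetical name, but it does reproduce the tensor example values documented for t_decay() (for instance 0.8306 becomes 0.75 and 0.9475 becomes 0.9375).

```python
import math


def bitshift_decay(decay, dt=1e-3):
    """Approximate a decay factor in (0, 1) as 1 - 2**-bitshift (hypothetical helper)."""
    # Time constant implied by decay = exp(-dt / tau)
    tau = -dt / math.log(decay)
    # Nearest whole number of bits for tau / dt, never negative
    bitshift = max(round(math.log2(tau / dt)), 0)
    # Reconstruct the decay that the bitshift subtraction implements
    return 1.0 - 2.0 ** -bitshift


for decay in (0.8306, 0.9475, 0.4156):
    print(decay, bitshift_decay(decay))
```

Under this reading, decays whose time constant is close to or below one time-step collapse to a bitshift of zero, which corresponds to a quantized decay of 0.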

Classes

class transform.torch_transform.ActWrapper(*args, **kwargs)[source]

A wrapper module that applies an output activity transformation after evolution

This module is not designed to be user-facing. Users should use the helper functions make_act_T_config() and make_act_T_network(). This approach will insert ActWrapper modules into the network as required.

__init__(mod: TorchModule, trans_Fn: Callable | None = None, *args, **kwargs)[source]

Instantiate an ActWrapper object

mod is a Rockpool module. An ActWrapper will be created to wrap mod. The transformation function trans_Fn will be applied to the outputs of mod during evolution.

Parameters:
  • mod (TorchModule) – A module to patch

  • trans_Fn (Optional[Callable]) – A transformation function to apply to the outputs of mod. If None, no transformation will be applied.

as_graph() GraphModuleBase[source]

Convert this module to a computational graph

Returns:

The computational graph corresponding to this module

Return type:

GraphModuleBase

Raises:

NotImplementedError – If as_graph() is not implemented for this subclass

forward(*args, **kwargs)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transform.torch_transform.TWrapper(*args, **kwargs)[source]

A wrapper for a Rockpool TorchModule, implementing a parameter transformation in the forward pass

This module is not designed to be user-facing; you should probably use the helper functions make_param_T_config() and make_param_T_network() to patch a Rockpool network. This will insert TWrapper modules into the network architecture as required.

__init__(mod: TorchModule, T_config: Iterable | MutableMapping | Mapping | None = None, *args, **kwargs)[source]

Initialise a parameter transformer wrapper module

mod is a Rockpool module with some set of attributes. T_config is a dictionary, with keys optionally matching the attributes of mod. Each value must be a callable T_Fn(a) -> a which can transform the associated attribute a.

A TWrapper module will be created, with mod as a sub-module. The TWrapper will apply the specified transformations to all the attributes of mod at the beginning of the forward-pass of evolution, then evolve mod with the transformed attributes.

Users should use the helper functions make_param_T_config() and make_param_T_network().

Parameters:
  • mod (TorchModule) – A Rockpool module to apply parameter transformations to

  • T_config (Optional[Tree]) – A nested dictionary specifying which transformations to apply to specific parameters. Each transformation function must be specified as a Callable with a key identical to a parameter of mod. If None, do not apply any transformation to mod.
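The wrapping pattern described above (transform the attributes, run the wrapped module, leave the held parameters un-transformed) can be sketched with toy classes in plain Python. `ToyModule`, `ToyTWrapper` and `burn_in` are hypothetical names used only for illustration. They are not Rockpool classes and do not show how TWrapper is actually implemented.

```python
class ToyModule:
    """Stand-in for a module with one parameter (hypothetical, not a Rockpool class)."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return [w * x for w in self.weight]


class ToyTWrapper:
    """Applies per-attribute transformations around the wrapped module's forward pass."""

    def __init__(self, mod, T_config=None):
        self._mod = mod
        self._T_config = T_config or {}

    def forward(self, x):
        # Keep the un-transformed values so they can be restored afterwards
        held = {name: getattr(self._mod, name) for name in self._T_config}
        try:
            for name, T_fn in self._T_config.items():
                setattr(self._mod, name, T_fn(held[name]))
            return self._mod.forward(x)
        finally:
            for name, value in held.items():
                setattr(self._mod, name, value)

    def burn_in(self):
        # Store the transformed values permanently, in the spirit of "burning in"
        for name, T_fn in self._T_config.items():
            setattr(self._mod, name, T_fn(getattr(self._mod, name)))


mod = ToyModule([0.4, 1.6])
wrapped = ToyTWrapper(mod, {"weight": lambda ws: [float(round(w)) for w in ws]})

output = wrapped.forward(2.0)           # computed with the rounded weights
held_after_forward = list(mod.weight)   # the module still holds the original values
wrapped.burn_in()
burned_in = list(mod.weight)            # now the rounded values are stored
print(output, held_after_forward, burned_in)
```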

as_graph() GraphModuleBase[source]

Convert this module to a computational graph

Returns:

The computational graph corresponding to this module

Return type:

GraphModuleBase

Raises:

NotImplementedError – If as_graph() is not implemented for this subclass

forward(*args, **kwargs)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transform.torch_transform.class_calc_q_decay(dt)[source]

Function used to calculate the bitshift equivalent of decay (exp(-dt/tau))

Parameters:
  • decay (torch.Tensor) – alpha and beta parameters from the decay parameter family of LIF neurons

  • dt (float) – Euler simulator time-step in seconds

Returns:

q_decay, the quantized decay factors

Return type:

torch.Tensor

__init__(dt)[source]