Module transform.torch_transform
Defines the parameter and activation transformation-in-training pipeline for TorchModules
Examples
Construct a network, and patch it to round each weight parameter stochastically to 2**num_bits levels (where num_bits is the chosen bit depth):
>>> net = Sequential(...)
>>> T_fn = lambda p: stochastic_rounding(p, num_levels = 2**num_bits)
>>> T_config = make_param_T_config(net, T_fn, 'weights')
>>> T_net = make_param_T_network(net, T_config)
Train the patched network here. Then, to burn in and remove the transformations:
>>> burned_in_net = apply_T(T_net)
>>> unpatched_net = remove_T_net(burned_in_net)
Functions overview
apply_T | “Burn in” a set of parameter transformations, applying each transformation and storing the resulting transformed parameters
deterministic_rounding | Quantise values by shifting them to the closest quantisation level
dropout | Randomly set values of a tensor to 0., with a defined probability
int_quant | Transform a tensor to a quantized space with a range of integer values defined by n_bits
make_act_T_config | Create an activity transformation configuration for a network
make_act_T_network | Patch a Rockpool network with activity transformers
make_backward_passthrough | Wrap a function to pass the gradient directly through in the backward pass
make_param_T_config | Helper function to build parameter transformation configuration trees
make_param_T_network | Patch a Rockpool network to apply parameter transformations in the forward pass
remove_T_net | Un-patch a transformer-patched network
stochastic_channel_rounding | Perform stochastic rounding of a matrix, but with the input range defined automatically for each column independently
stochastic_rounding | Perform floating-point stochastic rounding on a tensor, with detailed control over quantisation levels
t_decay | Quantise the decay factor exp(-dt/tau) of LIF neurons (alpha for Vmem, beta for Isyn), by converting each decay to a bit-shift subtraction and reconstructing the decay
Classes overview
ActWrapper | A wrapper module that applies an output activity transformation after evolution
TWrapper | A wrapper for a Rockpool TorchModule, implementing a parameter transformation in the forward pass
class_calc_q_decay | Function used to calculate the bit-shift equivalent of the decay exp(-dt/tau)
Functions
- transform.torch_transform.apply_T(T_net: TorchModule, inplace: bool = False) → TorchModule
“Burn in” a set of parameter transformations, applying each transformation and storing the resulting transformed parameters
This function takes a transformer-patched network T_net, and applies the pre-specified transformations to each parameter. The resulting transformed parameters are then stored within the parameters of the network. This is a useful step after training, as part of extracting the transformed parameters from the trained network.
The helper function remove_T_net() can be used afterwards to remove the transformer patches from the network.
Examples
>>> T_net = make_param_T_network(net, T_config)
At this point, T_net.parameters() will return un-transformed parameters.
>>> T_net = apply_T(T_net)
Now T_net.parameters() will contain the results of applying the transformation to each parameter.
- Parameters:
T_net (TorchModule) – A transformer-patched network, obtained with make_param_T_network()
inplace (bool) – If False (default), a deep copy of T_net will be created, transformed and returned. If True, the network will be transformed in place.
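The burn-in idea can be illustrated on a plain PyTorch module. The following is a minimal sketch, not Rockpool's implementation: `burn_in_sketch` is a hypothetical name, and it assumes a flat dict mapping parameter names to callables in place of the nested configuration tree. Each configured parameter is overwritten with its transformed value, outside the autograd graph.

```python
import copy
from typing import Callable, Dict

import torch
from torch import nn


def burn_in_sketch(
    net: nn.Module, T_config: Dict[str, Callable], inplace: bool = False
) -> nn.Module:
    """Hypothetical sketch: store T(p) in place of each configured parameter p."""
    if not inplace:
        # Mirror the inplace=False default: leave the input network untouched.
        net = copy.deepcopy(net)
    params = dict(net.named_parameters())
    # Burning in is a one-off update, so keep it out of the autograd graph.
    with torch.no_grad():
        for name, T_fn in T_config.items():
            params[name].copy_(T_fn(params[name]))
    return net


net = nn.Linear(3, 2)
burned = burn_in_sketch(net, {"weight": torch.round})
print(torch.equal(burned.weight, torch.round(net.weight)))  # True
```

Because the transformed values are written into the parameters themselves, the wrapper that applied the transformation on the fly is no longer needed afterwards.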
- transform.torch_transform.deterministic_rounding(value: Tensor, input_range: List | None = None, output_range: List | None = None, num_levels: int = 256, maintain_zero: bool = True)
Quantise values by shifting them to the closest quantisation level
This is a floating-point equivalent to standard integer rounding (e.g. using torch.round()). deterministic_rounding() provides fine control over input and output spaces, as well as the number of levels to quantise to, and can round to floating-point levels instead of round numbers. deterministic_rounding() always leaves values as floating point.
For example, if we are rounding to integers, then a value of 0.1 will round down to 0.0, a value of 0.5 will round up to 1.0, and a value of 0.9 will round up to 1.0.
If we are rounding to arbitrary floating-point levels, then the same logic holds, but the quantised output values will not be round numbers; they will be the nearest floating-point quantisation level.
deterministic_rounding() permits the input space to be re-scaled during rounding to an output space, which will be quantised over a specified number of quantisation levels (2**8 by default). By default, the input and output spaces are both defined to be the full input range from minimum to maximum.
deterministic_rounding() permits careful handling of symmetric or asymmetric spaces. By default, values of zero in the input space will map to zero in the output space (i.e. maintain_zero = True). In this case the output range is defined as max(abs(input_range)) * [-1, 1].
Quantisation and rounding is always to equally-spaced levels.
Examples
>>> deterministic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., -1., 0., 1., 1.])
Round to integer values (-10..10)
>>> deterministic_rounding(torch.rand(10)-.5, output_range=[-10., 10.], num_levels = 21)
tensor([ 10., -3., 3., -10., 5., 0., 9., -5., 7., 8.])
- Parameters:
value (torch.Tensor) – A Tensor of values to round
input_range (Optional[List]) – If defined, a specific input range to use ([min_value, max_value]), as floating point numbers. If None (default), use the range of input values to define the input range.
output_range (Optional[List]) – If defined, a specific output range to use ([min_value, max_value]), as floating point numbers. If None (default), the output range is the same as the input range.
num_levels (int) – The number of output levels to quantise to (Default: 2**8)
maintain_zero (bool) – If True, ensure that input values of zero map to zero in the output space (Default: True). If False, the output range may shift zero w.r.t. the input range.
- Returns:
Floating point rounded values
- Return type:
torch.Tensor
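The rescale, round, rescale mapping described above can be sketched in a few lines of PyTorch. This is an illustrative reconstruction, not the library source: `deterministic_rounding_sketch` is a hypothetical name, and it assumes that maintain_zero makes both the input and the output range symmetric about zero. It uses torch.round(), so an exact half-way level index rounds to the nearest even index.

```python
from typing import List, Optional

import torch


def deterministic_rounding_sketch(
    value: torch.Tensor,
    input_range: Optional[List[float]] = None,
    output_range: Optional[List[float]] = None,
    num_levels: int = 2**8,
    maintain_zero: bool = True,
) -> torch.Tensor:
    """Hypothetical sketch of deterministic rounding; not the library source."""
    # Default ranges: the observed input range, reused for the output.
    if input_range is None:
        input_range = [value.min().item(), value.max().item()]
    if output_range is None:
        output_range = list(input_range)
    # Assumption: maintain_zero makes both ranges symmetric about zero.
    if maintain_zero:
        in_max = max(abs(r) for r in input_range)
        out_max = max(abs(r) for r in output_range)
        input_range, output_range = [-in_max, in_max], [-out_max, out_max]
    # Map the input range onto level indices 0 .. num_levels - 1, then round.
    levels = (value - input_range[0]) * (num_levels - 1) / (input_range[1] - input_range[0])
    levels = torch.round(levels)
    # Map level indices back onto equally spaced output values.
    step = (output_range[1] - output_range[0]) / (num_levels - 1)
    return levels * step + output_range[0]


print(deterministic_rounding_sketch(torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]), num_levels=3))
# values: -1, -1, 0, 1, 1 (the same as the first example above)
```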
- transform.torch_transform.dropout(param: Tensor, dropout_prob: float = 0.5)
Randomly set values of a tensor to 0., with a defined probability
Dropout is used to improve the robustness of a network, by reducing the dependency of a network on any given parameter value. This is accomplished by randomly setting parameters (usually weights) to zero during training, such that the parameters are ignored.
For dropout_prob = 0.8, each parameter is randomly set to zero with 80% probability.
Examples
>>> dropout(torch.ones(10))
tensor([0., 0., 0., 1., 0., 1., 1., 1., 1., 0.])
>>> dropout(torch.ones(10), dropout_prob = 0.8)
tensor([1., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
- Parameters:
param (torch.Tensor) – A tensor of values to drop out
dropout_prob (float) – The probability of zeroing each parameter value (Default: 0.5, i.e. 50%)
- Returns:
The tensor of values, with elements dropped out probabilistically
- Return type:
torch.Tensor
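A minimal sketch of this behaviour, assuming each element is zeroed independently with probability dropout_prob; `dropout_sketch` is a hypothetical name. Unlike torch.nn.functional.dropout(), which rescales the surviving elements by 1 / (1 - p) during training, the examples above leave surviving values unscaled, so the sketch does not rescale either.

```python
import torch


def dropout_sketch(param: torch.Tensor, dropout_prob: float = 0.5) -> torch.Tensor:
    """Hypothetical sketch: zero each element with probability dropout_prob."""
    # torch.rand_like draws from [0, 1), so each element is kept with probability 1 - dropout_prob.
    keep = torch.rand_like(param) >= dropout_prob
    # No 1 / (1 - p) rescaling, matching the 0/1 outputs in the examples above.
    return param * keep


torch.manual_seed(0)
print(dropout_sketch(torch.ones(10), dropout_prob=0.8))
```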
- transform.torch_transform.int_quant(value: Tensor, maintain_zero: bool = True, map_to_max: bool = False, n_bits: int = 8, max_n_bits: int = 8)
Transform a tensor to a quantized space with a range of integer values defined by n_bits
Examples
>>> int_quant(torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1]), n_bits = 3, map_to_max = False)
tensor([-3., -2., -1., 0., 1., 2., 3.])
>>> int_quant(torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1]), n_bits = 3, map_to_max = True)
tensor([-127., -64., -32., 0., 32., 64., 127.])
- Parameters:
value (torch.Tensor) – A Tensor of values to be quantized
maintain_zero (bool) – If True, ensure that input values of zero map to zero in the output space (Default: True). If False, the output range may shift zero w.r.t. the input range.
map_to_max (bool) – If True, the integer values in the quantized output are mapped to the extremes of the bit depth
n_bits (int) – Defines the maximum integer in the quantized tensor
max_n_bits (int) – Maximum allowed bit depth
- Returns:
Quantized values (q_value)
- Return type:
torch.Tensor
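Both examples above are consistent with a simple symmetric scheme: scale the tensor so that its largest magnitude lands on the largest signed integer 2**(bits - 1) - 1, then round. The sketch below is a hypothetical reconstruction (`int_quant_sketch` is not a library function) that reproduces both examples. It covers only the maintain_zero = True case, and using max_n_bits as the bit depth when map_to_max = True is an assumption inferred from the ±127 output.

```python
import torch


def int_quant_sketch(
    value: torch.Tensor, map_to_max: bool = False, n_bits: int = 8, max_n_bits: int = 8
) -> torch.Tensor:
    """Hypothetical sketch of symmetric integer quantisation (maintain_zero = True only)."""
    # Assumption: map_to_max uses the maximum allowed bit depth instead of n_bits.
    bits = max_n_bits if map_to_max else n_bits
    # Largest signed integer for this bit depth, e.g. 3 for 3 bits, 127 for 8 bits.
    max_int = 2 ** (bits - 1) - 1
    # Symmetric scaling keeps zero at zero and sends the largest magnitude to +/- max_int.
    scale = max_int / value.abs().max()
    return torch.round(value * scale)


x = torch.tensor([-1, -0.50, -0.25, 0, 0.25, 0.5, 1])
print(int_quant_sketch(x, n_bits=3))                   # values: -3, -2, -1, 0, 1, 2, 3
print(int_quant_sketch(x, n_bits=3, map_to_max=True))  # values: -127, -64, -32, 0, 32, 64, 127
```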
- transform.torch_transform.make_act_T_config(net: TorchModule, T_fn: Callable | None = None, ModuleClass: type | None = None) → Iterable | MutableMapping | Mapping
Create an activity transformation configuration for a network
This helper function assists in defining an activity transformation configuration tree. It allows you to search a predefined network net to find modules of a chosen class, and specify an activity transformation for those modules.
net is a pre-defined Rockpool network. T_fn is a Callable with signature f(x) -> x, transforming the output x of a module. ModuleClass optionally specifies the class of module to search for in net.
You can use tree_utils.tree_update() to merge two configuration trees for different parameter families.
The resulting configuration tree will match the structure of net (or will be a sub-tree of net, including only modules of type ModuleClass). You can pass this configuration tree to make_act_T_network() to build an activity transformer tree.
- Parameters:
net (TorchModule) – A Rockpool network to build a configuration tree for
T_fn (Callable) – A function f(x) -> x to apply as a transformation to module output. If None, no transformation will be applied.
ModuleClass (Optional[type]) – A Module subclass to search for. The configuration tree will include only modules matching this class.
- Returns:
An activity transformer configuration tree, to pass to make_act_T_network()
- Return type:
Tree
- transform.torch_transform.make_act_T_network(net: TorchModule, act_T_config: Iterable | MutableMapping | Mapping, inplace: bool = False) → TorchModule
Patch a Rockpool network with activity transformers
This helper function inserts ActWrapper modules into a pre-defined network net, to apply an activity transformation configuration act_T_config.
- Parameters:
net (TorchModule) – A Rockpool network to patch
act_T_config (Tree) – A configuration tree from make_act_T_config()
inplace (bool) – If False (default), create a deep copy of net to patch. If True, patch the network in place. This in-place operation does not work when patching single modules.
- Returns:
The patched network
- Return type:
TorchModule
- transform.torch_transform.make_backward_passthrough(function: Callable) → Callable
Wrap a function to pass the gradient directly through in the backward pass
- Parameters:
function (Callable) – A function to wrap
- Returns:
The wrapped function, which applies function in the forward pass and passes the incoming gradient through unchanged in the backward pass
- Return type:
Callable
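This pattern is commonly called a straight-through estimator. Below is a minimal sketch using torch.autograd.Function; `make_backward_passthrough_sketch` is a hypothetical name, and restricting it to functions of a single tensor is an assumption. The forward pass applies the wrapped function, and the backward pass returns the incoming gradient unchanged, so a non-differentiable step such as torch.round() still lets gradients flow.

```python
from typing import Callable

import torch


def make_backward_passthrough_sketch(function: Callable) -> Callable:
    """Hypothetical sketch of a straight-through wrapper for single-tensor functions."""

    class _Passthrough(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Forward pass: apply the (possibly non-differentiable) function.
            return function(x)

        @staticmethod
        def backward(ctx, grad_output):
            # Backward pass: treat the function as the identity.
            return grad_output

    return _Passthrough.apply


round_pt = make_backward_passthrough_sketch(torch.round)
x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)
y = round_pt(x)
y.sum().backward()
print(y.detach())  # tensor([0., 1., 1.])
print(x.grad)      # tensor([1., 1., 1.]), although d(round)/dx is zero almost everywhere
```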
- transform.torch_transform.make_param_T_config(net: ModuleBase, T_fn: Callable, param_family: str | None = None) → Iterable | MutableMapping | Mapping
Helper function to build parameter transformation configuration trees
This function builds a parameter transformation nested configuration tree, based on an existing network net.
You can use tree_utils.tree_update() to merge two configuration trees for different parameter families.
The resulting configuration tree can be passed to make_param_T_network() to patch the network net.
Examples
>>> T_config = make_param_T_config(net, lambda p: p**2, 'weights')
>>> T_net = make_param_T_network(net, T_config)
- Parameters:
net (Module) – A Rockpool network to use as a template for the transformation configuration tree
T_fn (Callable) – A transformation function to apply to a parameter. Must have the signature T_fn(x) -> x.
param_family (Optional[str]) – An optional argument to specify a parameter family. Only parameters matching this family within net will be specified in the configuration tree.
- transform.torch_transform.make_param_T_network(net: ModuleBase, T_config_tree: Iterable | MutableMapping | Mapping, inplace: bool = False) → TorchModule
Patch a Rockpool network to apply parameter transformations in the forward pass
This helper function inserts TWrapper modules into the network tree, where required, to apply transformations to each module as defined by a configuration tree T_config_tree. Use the helper function make_param_T_config() to build configuration trees.
The resulting network will have analogous structure and behaviour to the original network, but the transformations will be applied before the forward pass of each module.
Network parameters will remain “held” by the original modules, un-transformed.
You can use the remove_T_net() function to undo this patching behaviour, restoring the original network structure but keeping any parameter modifications (e.g. training) in place. You can use the apply_T() function to “burn in” the parameter transformation.
- Parameters:
net (Module) – A Rockpool network to patch
T_config_tree (Tree) – A nested dictionary, mimicking the structure of net, specifying which parameters should be transformed and which transformation function to apply to each parameter.
inplace (bool) – If False (default), a deep copy of net will be created, transformed and returned. If True, the network will be patched in place.
- transform.torch_transform.remove_T_net(T_net: TorchModule, inplace: bool = False) → TorchModule
Un-patch a transformer-patched network
This function will iterate through a network patched using make_param_T_network(), and remove all patching modules. The resulting network should have the same network architecture as the original un-patched network.
Any parameter values applied within T_net will be retained in the unpatched network.
- Parameters:
T_net (TorchModule) – A transformer-patched network, obtained with make_param_T_network()
inplace (bool) – If False (default), a deep copy of T_net will be created, un-patched and returned. If True, the network will be un-patched in place. Warning: in-place operation cannot work for single instances of TWrapper.
- Returns:
A network matching T_net, but with transformers removed.
- Return type:
TorchModule
- transform.torch_transform.stochastic_channel_rounding(value: Tensor, output_range: List[float], num_levels: int = 256, maintain_zero: bool = True)
Perform stochastic rounding of a matrix, but with the input range defined automatically for each column independently
This function performs the same quantisation approach as stochastic_rounding(), but considers each column of a matrix independently, i.e. per channel.
- Parameters:
value (torch.Tensor) – A tensor of values to quantise
output_range (List[float]) – Defines the destination quantisation space [min_value, max_value].
num_levels (int) – The number of quantisation values to round to (Default: 2**8)
maintain_zero (bool) – If True (default), input values of zero map to zero in the output space. If False, the output space may shift w.r.t. the input space. Note that the output space must be symmetric for the zero mapping to work as expected.
- Returns:
The rounded values
- Return type:
torch.Tensor
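The per-column behaviour can be sketched as follows. This is a hypothetical reconstruction (`stochastic_channel_rounding_sketch` is not a library function) that assumes the maintain_zero = True case: each column's input range is taken as ±max|column|, output_range is treated as symmetric, and stochastic rounding rounds up with probability equal to the fractional part of the level index.

```python
from typing import List

import torch


def stochastic_channel_rounding_sketch(
    value: torch.Tensor, output_range: List[float], num_levels: int = 2**8
) -> torch.Tensor:
    """Hypothetical per-column sketch (maintain_zero = True case only)."""
    # Per-column symmetric input range; clamp avoids dividing by zero for all-zero columns.
    col_max = value.abs().amax(dim=0, keepdim=True).clamp_min(1e-12)
    out_max = max(abs(r) for r in output_range)
    # Continuous level index in [0, num_levels - 1], computed column by column.
    levels = (value + col_max) * (num_levels - 1) / (2 * col_max)
    lower = torch.floor(levels)
    # Stochastic rounding: round up with probability equal to the fractional part.
    round_up = torch.rand_like(levels) < (levels - lower)
    levels = (lower + round_up.to(levels.dtype)).clamp(0, num_levels - 1)
    # Equally spaced levels over the symmetric output range [-out_max, out_max].
    step = 2 * out_max / (num_levels - 1)
    return levels * step - out_max


w = torch.tensor([[1.0, -4.0], [-0.5, 2.0]])
# Column 0 is scaled by its max magnitude 1.0, column 1 by 4.0; both map onto [-1, 1].
print(stochastic_channel_rounding_sketch(w, output_range=[-1.0, 1.0], num_levels=3))
```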
- transform.torch_transform.stochastic_rounding(value: Tensor, input_range: List | None = None, output_range: List | None = None, num_levels: int = 256, maintain_zero: bool = True)
Perform floating-point stochastic rounding on a tensor, with detailed control over quantisation levels
Stochastic rounding randomly pushes values up or down probabilistically, depending on the original value. Values will round with greater probability to their nearest quantised level, and with lower probability to their next-nearest quantised level.
For example, if we are rounding to integers, then a value of 0.1 will round down to 0.0 with 90% probability; it will round up to 1.0 with 10% probability.
If we are rounding to arbitrary floating-point levels, then the same logic holds, but the quantised output values will not be round numbers.
stochastic_rounding() permits the input space to be re-scaled during rounding to an output space, which will be quantised over a specified number of quantisation levels (2**8 by default). By default, the input and output spaces are both defined to be the full input range from minimum to maximum.
stochastic_rounding() permits careful handling of symmetric or asymmetric spaces. By default, values of zero in the input space will map to zero in the output space (i.e. maintain_zero = True). In this case the output range is defined as max(abs(input_range)) * [-1, 1].
Quantisation and rounding is always to equally-spaced levels.
Examples
>>> stochastic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., 0., 0., 1., 1.])
>>> stochastic_rounding(torch.tensor([-1., -0.5, 0., 0.5, 1.]), num_levels = 3)
tensor([-1., -1., 0., 0., 1.])
Quantise to round integers over a defined space, changing the input scale (0..1) to (0..10).
>>> stochastic_rounding(torch.rand(10), input_range = [0., 1.], output_range = [0., 10.], num_levels = 11)
tensor([1., 9., 2., 2., 7., 1., 7., 9., 6., 1.])
Quantise to floating-point levels, without changing the scale of the values.
>>> stochastic_rounding(torch.rand(10)-.5, num_levels = 3)
tensor([ 0.0000, 0.0000, 0.0000, -0.4701, -0.4701, 0.4701, -0.4701, -0.4701, 0.0000, 0.4701])
>>> stochastic_rounding(torch.rand(10)-.5, num_levels = 3)
tensor([ 0.0000, 0.0000, -0.4316, 0.0000, 0.0000, 0.4316, -0.4316, 0.0000, 0.4316, 0.4316])
Note that, in this case, the scale is defined by the observed range of the random values.
- Parameters:
value (torch.Tensor) – A Tensor of values to round
input_range (Optional[List]) – If defined, a specific input range to use ([min_value, max_value]), as floating point numbers. If None (default), use the range of input values to define the input range.
output_range (Optional[List]) – If defined, a specific output range to use ([min_value, max_value]), as floating point numbers. If None (default), the output range is the same as the input range.
num_levels (int) – The number of output levels to quantise to (Default: 2**8)
maintain_zero (bool) – If True, ensure that input values of zero map to zero in the output space (Default: True). If False, the output range may shift zero w.r.t. the input range.
- Returns:
Floating point stochastically rounded values
- Return type:
torch.Tensor
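A minimal sketch of the rounding rule, under stated assumptions: `stochastic_rounding_sketch` is a hypothetical name, maintain_zero is assumed to make both ranges symmetric about zero (the library may handle explicit ranges differently), and each value rounds up to the next level with probability equal to its fractional distance above the lower level. The usage lines check the property described above: on an integer grid, 0.1 rounds up about 10% of the time, so the expected value is preserved.

```python
from typing import List, Optional

import torch


def stochastic_rounding_sketch(
    value: torch.Tensor,
    input_range: Optional[List[float]] = None,
    output_range: Optional[List[float]] = None,
    num_levels: int = 2**8,
    maintain_zero: bool = True,
) -> torch.Tensor:
    """Hypothetical sketch of stochastic rounding; not the library source."""
    if input_range is None:
        input_range = [value.min().item(), value.max().item()]
    if output_range is None:
        output_range = list(input_range)
    if maintain_zero:
        # Assumption: both ranges become symmetric about zero.
        in_max = max(abs(r) for r in input_range)
        out_max = max(abs(r) for r in output_range)
        input_range, output_range = [-in_max, in_max], [-out_max, out_max]
    # Continuous level index in [0, num_levels - 1].
    levels = (value - input_range[0]) * (num_levels - 1) / (input_range[1] - input_range[0])
    lower = torch.floor(levels)
    # Round up with probability equal to the fractional part, otherwise round down.
    round_up = torch.rand_like(levels) < (levels - lower)
    levels = (lower + round_up.to(levels.dtype)).clamp(0, num_levels - 1)
    step = (output_range[1] - output_range[0]) / (num_levels - 1)
    return levels * step + output_range[0]


torch.manual_seed(0)
x = torch.full((100_000,), 0.1)
y = stochastic_rounding_sketch(
    x, input_range=[0.0, 1.0], output_range=[0.0, 1.0], num_levels=2, maintain_zero=False
)
print(y.mean())  # close to 0.1: each value becomes 1.0 with probability 0.1, else 0.0
```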
- transform.torch_transform.t_decay(decay: Tensor, dt: float = 0.001)
Quantise the decay factor exp(-dt/tau) of LIF neurons: alpha for Vmem and beta for Isyn. The quantisation converts each decay to a bit-shift subtraction, then reconstructs the decay from it. The transformation is passed to make_backward_passthrough(), which applies it in the forward pass of the Torch module.
Note: this transform can only be used if decay_training = True for at least one of the LIF modules in the network.
Examples
Applying to tensors:
>>> alpha = torch.rand(10)
>>> alpha
tensor([0.4156, 0.8306, 0.1557, 0.9475, 0.4532, 0.3585, 0.4014, 0.9380, 0.9426, 0.7212])
>>> tt.t_decay(alpha)
tensor([0.0000, 0.7500, 0.0000, 0.9375, 0.0000, 0.0000, 0.0000, 0.9375, 0.9375, 0.7500])
Applying to the decay parameter family of a network:
>>> net = Sequential(LinearTorch((2,2)), LIFTorch(2, decay_training=True))
>>> T_config = tt.make_param_T_config(net, lambda p : tt.t_decay(p), 'decays')
>>> T_config
{'1_LIFTorch': {'alpha': <function <lambda> at 0x7fb32efbb280>, 'beta': <function <lambda> at 0x7fb32efbb280>}}
>>> T_net = tt.make_param_T_network(net, T_config)
>>> T_net
TorchSequential with shape (2, 2) { LinearTorch '0_LinearTorch' with shape (2, 2) TWrapper '1_LIFTorch' with shape (2, 2) { LIFTorch '_mod' with shape (2, 2) }}
- Parameters:
decay (torch.Tensor) – alpha and beta parameters from the decay parameter family of LIF neurons
dt (float) – Euler simulator time-step in seconds
- Returns:
q_decay, the quantised decay values
- Return type:
torch.Tensor
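The example outputs above are reproduced by the following reconstruction, which is a hypothetical sketch inferred from those numbers, not the library code (`t_decay_sketch` is not a library function, and the passthrough wrapping is omitted). It recovers tau from the decay, takes the nearest integer bit shift n = round(log2(tau / dt)), and rebuilds the decay as 1 - 2**-n, the factor applied by the subtraction v - (v >> n). Clamping n at zero is an assumption needed to match the 0.1557 → 0 case. In this sketch dt cancels out, because tau / dt = -1 / ln(decay).

```python
import torch


def t_decay_sketch(decay: torch.Tensor, dt: float = 1e-3) -> torch.Tensor:
    """Hypothetical bit-shift quantisation of decays in (0, 1); not the library code."""
    # Recover the time constant from decay = exp(-dt / tau).
    tau = -dt / torch.log(decay)
    # Nearest integer bit shift; assumption: negative shifts are clamped to zero.
    shift = torch.clamp(torch.round(torch.log2(tau / dt)), min=0)
    # v - (v >> shift) scales v by (1 - 2**-shift), so that is the quantised decay.
    return 1.0 - 2.0 ** (-shift)


alpha = torch.tensor([0.4156, 0.8306, 0.1557, 0.9475, 0.4532, 0.3585, 0.4014, 0.9380, 0.9426, 0.7212])
print(t_decay_sketch(alpha))  # values: 0, 0.75, 0, 0.9375, 0, 0, 0, 0.9375, 0.9375, 0.75
```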
Classes
- class transform.torch_transform.ActWrapper(*args, **kwargs)
A wrapper module that applies an output activity transformation after evolution
This module is not designed to be user-facing. Users should use the helper functions make_act_T_config() and make_act_T_network(). This approach will insert ActWrapper modules into the network as required.
- __init__(mod: TorchModule, trans_Fn: Callable | None = None, *args, **kwargs)
Instantiate an ActWrapper object
mod is a Rockpool module. An ActWrapper will be created to wrap mod. The transformation function trans_Fn will be applied to the outputs of mod during evolution.
- Parameters:
mod (TorchModule) – A module to patch
trans_Fn (Optional[Callable]) – A transformation function to apply to the outputs of mod. If None, no transformation will be applied.
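The wrapping pattern can be sketched with a plain torch.nn.Module. This is a simplified stand-in, not ActWrapper itself: `ActWrapperSketch` is a hypothetical name, and it assumes the wrapped module returns a single tensor, which may not hold for Rockpool modules.

```python
from typing import Callable, Optional

import torch
from torch import nn


class ActWrapperSketch(nn.Module):
    """Hypothetical sketch: apply trans_Fn to the output of a wrapped module."""

    def __init__(self, mod: nn.Module, trans_Fn: Optional[Callable] = None):
        super().__init__()
        self._mod = mod
        # None means "no transformation", matching the documented default.
        self._trans_Fn = trans_Fn if trans_Fn is not None else (lambda x: x)

    def forward(self, *args, **kwargs):
        # Evolve the wrapped module first, then transform its output.
        return self._trans_Fn(self._mod(*args, **kwargs))


wrapped = ActWrapperSketch(nn.Identity(), trans_Fn=torch.round)
print(wrapped(torch.tensor([0.2, 1.7])))  # tensor([0., 2.])
```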
- as_graph() → GraphModuleBase
Convert this module to a computational graph
- Returns:
The computational graph corresponding to this module
- Return type:
GraphModuleBase
- Raises:
NotImplementedError – If as_graph() is not implemented for this subclass
- forward(*args, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class transform.torch_transform.TWrapper(*args, **kwargs)
A wrapper for a Rockpool TorchModule, implementing a parameter transformation in the forward pass
This module is not designed to be user-facing; you should probably use the helper functions make_param_T_config() and make_param_T_network() to patch a Rockpool network. This will insert TWrapper modules into the network architecture as required.
- __init__(mod: TorchModule, T_config: Iterable | MutableMapping | Mapping | None = None, *args, **kwargs)
Initialise a parameter transformer wrapper module
mod is a Rockpool module with some set of attributes. T_config is a dictionary, with keys optionally matching the attributes of mod. Each value must be a callable T_Fn(a) -> a which can transform the associated attribute a.
A TWrapper module will be created, with mod as a sub-module. The TWrapper will apply the specified transformations to all the attributes of mod at the beginning of the forward pass of evolution, then evolve mod with the transformed attributes.
Users should use the helper functions make_param_T_config() and make_param_T_network().
- Parameters:
mod (TorchModule) – A Rockpool module to apply parameter transformations to
T_config (Optional[Tree]) – A nested dictionary specifying which transformations to apply to specific parameters. Each transformation function must be specified as a Callable with a key identical to a parameter of mod. If None, do not apply any transformation to mod.
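The forward-pass behaviour described above (transform the parameters on the fly, keep the stored parameters untouched) can be sketched in plain PyTorch with torch.func.functional_call() (PyTorch 2.0 or later). This is an illustrative stand-in, not TWrapper's implementation: `TWrapperSketch` is a hypothetical name, and it assumes a flat T_config dict keyed by parameter name and a wrapped module that returns a single tensor.

```python
from typing import Callable, Dict

import torch
from torch import nn
from torch.func import functional_call


class TWrapperSketch(nn.Module):
    """Hypothetical sketch: transform parameters at the start of each forward pass."""

    def __init__(self, mod: nn.Module, T_config: Dict[str, Callable]):
        super().__init__()
        self._mod = mod
        self._T_config = T_config

    def forward(self, *args, **kwargs):
        # Build transformed copies; the stored parameters stay un-transformed.
        params = dict(self._mod.named_parameters())
        for name, T_fn in self._T_config.items():
            params[name] = T_fn(params[name])
        # Run the wrapped module with the transformed parameters substituted in.
        return functional_call(self._mod, params, args, kwargs)


lin = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    lin.weight.copy_(torch.tensor([[0.4, 1.6]]))
T_lin = TWrapperSketch(lin, {"weight": torch.round})
print(T_lin(torch.tensor([[1.0, 1.0]])))  # uses round(weight) = [[0., 2.]], giving 2.
print(lin.weight)                          # still [[0.4, 1.6]]
```

Pairing the transformation with a straight-through wrapper, as t_decay() does with make_backward_passthrough(), would let gradients reach the stored parameters despite the rounding step.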
- as_graph() → GraphModuleBase
Convert this module to a computational graph
- Returns:
The computational graph corresponding to this module
- Return type:
GraphModuleBase
- Raises:
NotImplementedError – If as_graph() is not implemented for this subclass
- forward(*args, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class transform.torch_transform.class_calc_q_decay(dt)
Function used to calculate the bit-shift equivalent of the decay exp(-dt/tau)
- Parameters:
decay (torch.Tensor) – alpha and beta parameters from the decay parameter family of LIF neurons
dt (float) – Euler simulator time-step in seconds
- Returns:
q_decay, the quantised decay values
- Return type:
torch.Tensor