rockpool.transform.torch_transform

Defines the parameter and activation transformation-in-training pipeline for TorchModule objects

Examples

Construct a network, and patch it to round each weight parameter:

>>> from rockpool.nn.combinators import Sequential
>>> from rockpool.transform.torch_transform import (
...     stochastic_rounding, make_param_T_config, make_param_T_network,
...     apply_T, remove_T_net)
>>> num_bits = 4  # e.g. quantise weights to 4 bits
>>> net = Sequential(...)
>>> T_fn = lambda p: stochastic_rounding(p, num_levels = 2**num_bits)
>>> T_config = make_param_T_config(net, T_fn, 'weights')
>>> T_net = make_param_T_network(net, T_config)

Train the network here. To burn in and remove the transformations:

>>> burned_in_net = apply_T(T_net)
>>> unpatched_net = remove_T_net(burned_in_net)
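An analogous pipeline exists for transforming module outputs rather than parameters. A minimal sketch following the same pattern (the choice of rounding function here is illustrative):

>>> from rockpool.transform.torch_transform import (
...     make_act_T_config, make_act_T_network)
>>> act_fn = lambda x: stochastic_rounding(x, num_levels = 2**num_bits)
>>> act_T_config = make_act_T_config(net, act_fn)
>>> act_T_net = make_act_T_network(net, act_T_config)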

Functions

apply_T(T_net[, inplace])

"Burn in" a set of parameter transformations, applying each transformation and storing the resulting transformed parameters

deterministic_rounding(value[, input_range, ...])

Quantise values by shifting them to the closest quantisation level
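For example, a minimal sketch, assuming ``input_range`` accepts a ``[min, max]`` pair and other keyword defaults are left as-is:

>>> import torch
>>> from rockpool.transform.torch_transform import deterministic_rounding
>>> w = torch.rand(3, 3) * 2. - 1.
>>> w_q = deterministic_rounding(w, input_range = [-1., 1.])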

dropout(param[, dropout_prob])

Randomly set values of a tensor to 0., with a defined probability
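A minimal usage sketch, assuming ``dropout_prob`` is the per-element probability of zeroing:

>>> import torch
>>> from rockpool.transform.torch_transform import dropout
>>> p = torch.ones(4, 4)
>>> p_d = dropout(p, dropout_prob = 0.3)  # each entry zeroed with prob. 0.3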

int_quant(value[, maintain_zero, ...])

Transform a tensor to a quantised space with a range of integer values defined by n_bits
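A sketch of quantising a tensor to signed 8-bit integer levels; ``n_bits`` is taken from the summary above and its use as a keyword is an assumption, while ``maintain_zero`` presumably keeps 0. mapped to 0:

>>> import torch
>>> from rockpool.transform.torch_transform import int_quant
>>> w = torch.rand(3, 3) * 2. - 1.
>>> w_int = int_quant(w, maintain_zero = True, n_bits = 8)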

make_act_T_config(net[, T_fn, ModuleClass])

Create an activity transformation configuration for a network

make_act_T_network(net, act_T_config[, inplace])

Patch a Rockpool network with activity transformers

make_backward_passthrough(function)

Wrap a function to pass the gradient directly through in the backward pass
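This enables straight-through estimators for non-differentiable operations. A sketch wrapping torch.floor, assuming the wrapper returns a callable with the same signature as the wrapped function:

>>> import torch
>>> from rockpool.transform.torch_transform import make_backward_passthrough
>>> floor_st = make_backward_passthrough(torch.floor)
>>> x = torch.tensor([1.7], requires_grad = True)
>>> floor_st(x).backward()
>>> x.grad  # gradient passes through as 1., although floor is flat a.e.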

make_param_T_config(net, T_fn[, param_family])

Helper function to build parameter transformation configuration trees

make_param_T_network(net, T_config_tree[, ...])

Patch a Rockpool network to apply parameter transformations in the forward pass

remove_T_net(T_net[, inplace])

Un-patch a transformation-patched network, restoring the original modules

stochastic_channel_rounding(value, output_range)

Perform stochastic rounding of a matrix, but with the input range defined automatically for each column independently

stochastic_rounding(value[, input_range, ...])

Perform floating-point stochastic rounding on a tensor, with detailed control over quantisation levels
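A sketch rounding a tensor to 2**4 levels over an assumed [-1., 1.] input range; each value rounds up or down at random, weighted by its distance to the neighbouring levels:

>>> import torch
>>> from rockpool.transform.torch_transform import stochastic_rounding
>>> w = torch.rand(3, 3) * 2. - 1.
>>> w_q = stochastic_rounding(w, input_range = [-1., 1.], num_levels = 2**4)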

t_decay(decay[, dt])

Quantise the decay factors (exp(-dt/tau)) of LIF neurons (alpha and beta, for Vmem and Isyn respectively), by converting each decay to a bitshift subtraction and reconstructing the decay
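A sketch quantising the membrane decay alpha = exp(-dt/tau) for a 20 ms time constant; the ``dt`` keyword is taken from the signature above:

>>> import torch
>>> from rockpool.transform.torch_transform import t_decay
>>> dt = 1e-3
>>> alpha = torch.exp(torch.tensor(-dt / 20e-3))
>>> alpha_q = t_decay(alpha, dt = dt)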

Classes

ActWrapper(*args, **kwargs)

A wrapper module that applies an output activity transformation after evolution

TWrapper(*args, **kwargs)

A wrapper for a Rockpool TorchModule, implementing a parameter transformation in the forward pass

class_calc_q_decay(dt)

Calculate the bitshift equivalent of a decay factor (exp(-dt/tau))