Module training.torch_loss
Torch loss functions and regularizers useful for training networks using Torch Modules.
Functions overview
bounds_cost() | Impose a cost on parameters that violate bounds constraints
make_bounds() | Convenience function to build a bounds template for a problem
summed_exp_boundary_loss() | Compute the summed exponential error of boundary violations of an input.
Classes overview
ParameterBoundaryRegularizer | Class wrapper for the summed exponential error of boundary violations of an input.
Functions
- training.torch_loss.bounds_cost(params: dict, lower_bounds: dict, upper_bounds: dict) -> Tensor
Impose a cost on parameters that violate bounds constraints
This function works hand-in-hand with make_bounds() to enforce greater-than and less-than constraints on parameter values. It is designed to be used as a component of a loss function, to ensure parameter values fall in a reasonable range.
bounds_cost() imposes a value of 1.0 for each parameter element that exceeds a bound infinitesimally, increasing exponentially as the bound is exceeded further, and 0.0 for each parameter element within the bounds. You will most likely want to scale this by a penalty factor within your cost function.
Warning
bounds_cost() does not clip parameters to the bounds. It is possible for parameters to exceed the bounds during optimisation. If this must be prevented, you should clip the parameters explicitly.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for examples of using make_bounds() and bounds_cost().
- Parameters:
params (dict) – A dictionary of parameters over which to impose bounds
lower_bounds (dict) – A dictionary of lower bounds for parameters matching your model, modified from that returned by make_bounds()
upper_bounds (dict) – A dictionary of upper bounds for parameters matching your model, modified from that returned by make_bounds()
- Returns:
The cost to include in the cost function.
- Return type:
Tensor
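To make the cost described above concrete, here is a framework-free sketch in plain Python (no torch). It is not the Rockpool implementation: the helper name `bounds_cost_sketch`, the flat dictionaries of lists, the scalar per-parameter bounds and the penalty factor of 10.0 are all illustrative assumptions. It follows the stated rule: 0.0 inside the bounds, and exp(distance past the bound) outside, which is about 1.0 just past a bound.

```python
import math
from typing import Dict, List


def bounds_cost_sketch(
    params: Dict[str, List[float]],
    lower_bounds: Dict[str, float],
    upper_bounds: Dict[str, float],
) -> float:
    """Hypothetical plain-Python sketch of the cost described for bounds_cost().

    Assumes flat dicts mapping a parameter name to a list of values,
    with one scalar lower and upper bound per parameter name.
    """
    if params.keys() != lower_bounds.keys() or params.keys() != upper_bounds.keys():
        raise KeyError("Bounds dictionaries must have the same keys as params")

    cost = 0.0
    for name, values in params.items():
        lower, upper = lower_bounds[name], upper_bounds[name]
        for v in values:
            # - Elements inside [lower, upper] contribute nothing
            if v > upper:
                cost += math.exp(v - upper)
            elif v < lower:
                cost += math.exp(lower - v)
    return cost


params = {"tau_mem": [0.02, 0.5], "bias": [0.0, -0.3]}
lower = {"tau_mem": -math.inf, "bias": -0.1}  # -inf: no lower bound enforced
upper = {"tau_mem": 0.1, "bias": math.inf}    # +inf: no upper bound enforced

penalty = 10.0  # hypothetical penalty factor chosen by the user
total_cost = penalty * bounds_cost_sketch(params, lower, upper)
```

Here only `tau_mem = 0.5` (above 0.1) and `bias = -0.3` (below -0.1) contribute, giving exp(0.4) + exp(0.2) before scaling.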
- training.torch_loss.make_bounds(params: dict) -> Tuple[dict, dict]
Convenience function to build a bounds template for a problem
This function works hand-in-hand with bounds_cost() to enforce minimum and/or maximum parameter bounds. make_bounds() accepts a set of parameters (e.g. as returned from the Module.parameters() method) and returns ready-made dictionaries of lower and upper bounds (with no restrictions by default).
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for examples of using make_bounds() and bounds_cost().
make_bounds() returns two dictionaries, representing the lower and upper bounds respectively. Initially all entries are set to -np.inf and np.inf, indicating that no bounds are enforced. You must edit these dictionaries to set the bounds.
- Parameters:
params (dict) – Dictionary of parameters defining an optimisation problem. This can be provided as the parameter dictionary returned by Module.parameters().
- Returns:
lower_bounds, upper_bounds. Each dictionary mimics the structure of params, with initial bounds set to -np.inf and np.inf (i.e. no bounds enforced).
- Return type:
(dict, dict)
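The template-building behaviour described above can be sketched in plain Python as follows. This is an illustrative assumption, not Rockpool's code: the helper `make_bounds_sketch` and the nested key `"0_LIF"` are hypothetical, leaves are filled with scalar `-inf`/`+inf` (standard-library `math.inf` stands in for `np.inf`), and every non-dict value is treated as a leaf parameter.

```python
import math
from typing import Any, Tuple


def make_bounds_sketch(params: Any) -> Tuple[Any, Any]:
    """Hypothetical sketch of make_bounds(): mirror the (possibly nested) dict
    structure of ``params``, filling each leaf with -inf (lower) and +inf (upper)."""
    if isinstance(params, dict):
        # - Recurse into sub-dictionaries, then split the (lower, upper) pairs
        pairs = {key: make_bounds_sketch(value) for key, value in params.items()}
        lower = {key: pair[0] for key, pair in pairs.items()}
        upper = {key: pair[1] for key, pair in pairs.items()}
        return lower, upper

    # - Any non-dict value is treated as a leaf parameter: no bounds by default
    return -math.inf, math.inf


params = {"0_LIF": {"tau_mem": [0.02, 0.02], "threshold": [1.0, 1.0]}}
lower_bounds, upper_bounds = make_bounds_sketch(params)

# - Nothing is enforced until you edit the returned dictionaries
lower_bounds["0_LIF"]["tau_mem"] = 1e-3
upper_bounds["0_LIF"]["tau_mem"] = 1e-1
```

Because the lower and upper dictionaries are built as separate objects, editing one leaves the other untouched, and untouched entries (here `threshold`) keep their unbounded defaults.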
- training.torch_loss.summed_exp_boundary_loss(data, lower_bound=None, upper_bound=None)
Compute the summed exponential error of boundary violations of an input.
\[
\textrm{sebl}(y, y_{lower}, y_{upper}) = \sum_i \textrm{sebl}(y_i, y_{lower}, y_{upper})
\]
\[
\textrm{sebl}(y_i, y_{lower}, y_{upper}) =
\begin{cases}
\exp(y_i - y_{upper}), & \text{if } y_i > y_{upper} \\
\exp(y_{lower} - y_i), & \text{if } y_i < y_{lower} \\
0, & \text{otherwise}
\end{cases}
\]
This function allows for soft parameter constraints by creating a loss for boundary violations. This is achieved by adding summed_exp_boundary_loss(data, lower_bound, upper_bound) to your general loss, where data is an arbitrary tensor and both bounds are scalars. If either bound is given as None, that boundary is not penalized.
In the example below we introduce soft constraints on tau_mem of the first layer of the model, so that values tau_mem > 1e-1 and tau_mem < 1e-3 are penalized and taken into account in the optimization step.
will be punished and considered in the optimization step.# Calculate the training loss y_hat, _, _ = model(x) train_loss = F.mse_loss(y, y_hat) # Set soft constraints to the time constants of the first layer of the Parameter boundary_loss = summed_exp_boundary_loss(model[0].tau_mem, 1e-3, 1e-1) complete_loss = train_loss + boundary_loss # Do backpropagation over both losses and optimize the model parameters accordingly complete_loss.backward() optimizer.step()
To penalize only violations of a lower bound, omit upper_bound, which defaults to None. Penalizing only upper bounds works analogously.

    boundary_loss = summed_exp_boundary_loss(model[0].thr_up, lower_bound=1e-4)
    complete_loss = train_loss + boundary_loss

    # Do backpropagation over both losses and optimize the model parameters accordingly
    optimizer.zero_grad()
    complete_loss.backward()
    optimizer.step()
- Parameters:
data (torch.Tensor) – The data whose boundary violations will be penalized, with shape (N,).
lower_bound (float) – Lower bound for the data. If None, the lower boundary is not penalized.
upper_bound (float) – Upper bound for the data. If None, the upper boundary is not penalized.
- Returns:
Summed exponential error of boundary violations.
- Return type:
torch.Tensor
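As a check on the sebl formula above, the following plain-Python transcription (no torch) evaluates it on a short list of numbers. The helper name `summed_exp_boundary_loss_sketch` is hypothetical and this is a sketch of the math, not the library function; passing None for a bound skips that boundary, as the text describes.

```python
import math
from typing import Iterable, Optional


def summed_exp_boundary_loss_sketch(
    data: Iterable[float],
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
) -> float:
    """Hypothetical plain-Python transcription of sebl(y, y_lower, y_upper)."""
    loss = 0.0
    for y in data:
        # - exp(y_i - y_upper) above the upper bound
        if upper_bound is not None and y > upper_bound:
            loss += math.exp(y - upper_bound)
        # - exp(y_lower - y_i) below the lower bound
        elif lower_bound is not None and y < lower_bound:
            loss += math.exp(lower_bound - y)
        # - Otherwise the element contributes 0
    return loss


# 0.5 lies inside [0, 1]; 2.0 and -1.0 each violate a bound by 1.0
both = summed_exp_boundary_loss_sketch([0.5, 2.0, -1.0], 0.0, 1.0)              # 2 * e
lower_only = summed_exp_boundary_loss_sketch([0.5, 2.0, -1.0], lower_bound=0.0)  # e
```

With only the lower bound given, the element 2.0 is no longer penalized, so the loss drops from 2e to e.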
Classes
- class training.torch_loss.ParameterBoundaryRegularizer(*args, **kwargs)
Class wrapper for the summed exponential error of boundary violations of an input. See summed_exp_boundary_loss() for more information. Allows the boundaries of a value to be defined just once, in an object.
- __init__(lower_bound=None, upper_bound=None)
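Initialise this module
- Parameters:
lower_bound (float) – Lower bound for the data. If None, the lower boundary is not penalized.
upper_bound (float) – Upper bound for the data. If None, the upper boundary is not penalized.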
- forward(input)
Compute the summed exponential boundary loss of input, using the bounds given at construction. See summed_exp_boundary_loss().
Note
Although the recipe for the forward pass is defined within this function, you should call the Module instance rather than forward() directly, since the former takes care of running the registered hooks while the latter silently ignores them.
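The "define the bounds once" pattern behind this class can be illustrated with a plain-Python analogue. This is only a sketch under stated assumptions: the real class is a torch Module called on tensors, whereas the hypothetical `BoundaryRegularizerSketch` below stores the bounds in `__init__` and applies the same per-element sebl rule in `__call__` on a list of floats.

```python
import math
from typing import Iterable, Optional


class BoundaryRegularizerSketch:
    """Hypothetical plain-Python analogue of ParameterBoundaryRegularizer:
    the bounds are stored once and reused on every call."""

    def __init__(self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __call__(self, data: Iterable[float]) -> float:
        # - Same per-element rule as summed_exp_boundary_loss()
        loss = 0.0
        for y in data:
            if self.upper_bound is not None and y > self.upper_bound:
                loss += math.exp(y - self.upper_bound)
            elif self.lower_bound is not None and y < self.lower_bound:
                loss += math.exp(self.lower_bound - y)
        return loss


# - Bounds for a time constant are defined once...
tau_regularizer = BoundaryRegularizerSketch(lower_bound=1e-3, upper_bound=1e-1)

# - ...and reused wherever the penalty is needed
penalty_a = tau_regularizer([0.05, 0.2])  # only 0.2 violates: exp(0.2 - 0.1)
penalty_b = tau_regularizer([0.0005])     # below the lower bound: exp(1e-3 - 0.0005)
```

Keeping the bounds on the object avoids repeating the same numbers at every place the loss is computed, which is the convenience the class description refers to.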