# Source code for training.torch_loss

"""
Torch loss functions and regularizers useful for training networks using Torch Modules.
"""

from rockpool.utilities.backend_management import backend_available

if not backend_available("torch"):
    raise ModuleNotFoundError(
        "Torch backend not found. Modules that rely on Torch will not be available."
    )

from rockpool.nn.modules import TorchModule
import torch

from copy import deepcopy

from typing import Tuple

import rockpool.utilities.tree_utils as tu

__all__ = [
"summed_exp_boundary_loss",
"ParameterBoundaryRegularizer",
"make_bounds",
"bounds_cost",
]

def summed_exp_boundary_loss(data, lower_bound=None, upper_bound=None):
"""
Compute the summed exponential error of boundary violations of an input.

.. math::

\\textrm{sebl}(y, y_{lower}, y_{upper}) = \\sum_i \\textrm{sebl}(y_i, y_{lower}, y_{upper})

\\textrm{sebl}(y_i, y_{lower}, y_{upper}) =
\\begin{cases}
\\exp(y_i - y_{upper}),  & \\text{if $y_i > y_{upper}$} \\\\
\\exp(y_{lower} - y_i),  & \\text{if $y_i < y_{lower}$} \\\\
0,  & \\text{otherwise} \\\\
\\end{cases}

    This function implements soft parameter constraints by imposing a loss on boundary violations. Add ``summed_exp_boundary_loss(data, lower_bound, upper_bound)`` to your training loss, where ``data`` is an arbitrary tensor and both bounds are scalars. If either bound is given as ``None``, that boundary is not penalized.

    In the example below we introduce soft constraints on ``tau_mem`` of the first layer of the model, so that values ``tau_mem > 1e-1`` or ``tau_mem < 1e-3`` are penalized and accounted for in the optimization step.

    .. code-block:: python

        # Calculate the training loss
        y_hat, _, _ = model(x)
        train_loss = F.mse_loss(y, y_hat)

        # Set soft constraints on the time constants of the first layer
        boundary_loss = summed_exp_boundary_loss(model.tau_mem, 1e-3, 1e-1)
        complete_loss = train_loss + boundary_loss

        # Backpropagate through both losses and optimize the model parameters accordingly
        complete_loss.backward()
        optimizer.step()

    If we would like to penalize only a lower bound for a parameter, we can simply omit ``upper_bound``. Penalizing only an upper bound works analogously.

    .. code-block:: python

        boundary_loss = summed_exp_boundary_loss(model.thr_up, lower_bound=1e-4)
        complete_loss = train_loss + boundary_loss

        # Backpropagate through both losses and optimize the model parameters accordingly
        complete_loss.backward()
        optimizer.step()

    Args:
        data (torch.Tensor): The data whose boundary violations will be penalized, with shape ``(N,)``.
        lower_bound (float): Lower bound for the data. If ``None``, no lower bound is enforced. Default: ``None``.
        upper_bound (float): Upper bound for the data. If ``None``, no upper bound is enforced. Default: ``None``.

    Returns:
        torch.Tensor: Summed exponential error of boundary violations.
    """
    # - If upper_bound is given, calculate the loss, otherwise skip it.
    #   Check against None explicitly, so that a bound of 0.0 is still applied
    if upper_bound is not None:
        upper_loss = torch.exp(data - upper_bound)

        # - Only count the loss when a violation occurred, in which case exp(y_i - y_upper) > 1
        upper_loss = torch.sum(upper_loss[upper_loss > 1])
    else:
        upper_loss = 0.0

    # - If lower_bound is given, calculate the loss, otherwise skip it
    if lower_bound is not None:
        lower_loss = torch.exp(lower_bound - data)

        # - Only count the loss when a violation occurred, in which case exp(y_lower - y_i) > 1
        lower_loss = torch.sum(lower_loss[lower_loss > 1])
    else:
        lower_loss = 0.0

    return lower_loss + upper_loss

class ParameterBoundaryRegularizer(TorchModule):
"""
Class wrapper for the summed exponential error of boundary violations of an input. See :py:func:.summed_exp_boundary_loss for more information.
Allows to define the boundaries of a value just once in an object.
"""

    def __init__(self, lower_bound=None, upper_bound=None):
        super().__init__()
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def forward(self, input):
        return summed_exp_boundary_loss(input, self.lower_bound, self.upper_bound)

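# - Illustrative usage sketch (not part of the original module): define the
#   bounds for a value once, then reuse the regularizer wherever the penalty is
#   needed. The toy tensor below stands in for a real parameter such as a
#   hypothetical `model.tau_mem`; as a Rockpool module, `reg_tau` could also be
#   called as `out, _, _ = reg_tau(tau)`, following the `(output, state, record)`
#   convention shown in the docstring examples above.
if __name__ == "__main__":
    reg_tau = ParameterBoundaryRegularizer(lower_bound=1e-3, upper_bound=1e-1)
    tau = torch.tensor([5e-4, 5e-3, 5e-1])

    # - Only the 5e-4 and 5e-1 entries violate the bounds and contribute loss
    penalty = reg_tau.forward(tau)
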
def make_bounds(params: dict) -> Tuple[dict, dict]:
"""
Convenience function to build a bounds template for a problem

This function works hand-in-hand with :py:func:.bounds_cost, to enforce minimum and/or maximum parameter bounds. :py:func:.make_bounds accepts a set of parameters (e.g. as returned from the :py:meth:Module.parameters method), and returns a ready-made dictionary of bounds (with no restrictions by default).

See :ref:/in-depth/jax-training.ipynb for examples for using :py:func:.make_bounds and :py:func:.bounds_cost.

:py:func:.make_bounds returns two dictionaries, representing the lower and upper bounds respectively. Initially all entries will be set to -np.inf and np.inf, indicating that no bounds should be enforced. You must edit these dictionaries to set the bounds.

Args:
params (dict): Dictionary of parameters defining an optimisation problem. This can be provided as the parameter dictionary returned by :py:meth:Module.parameters.

Returns:
(dict, dict): lower_bounds, upper_bounds. Each dictionary mimics the structure of params, with initial bounds set to -np.inf and np.inf (i.e. no bounds enforced).
"""
    # - Make copies
    lower_bounds = deepcopy(params)
    upper_bounds = deepcopy(params)

    # - Reset to -inf and inf
    lower_bounds = tu.tree_map(lower_bounds, lambda _: -float("inf"))
    upper_bounds = tu.tree_map(upper_bounds, lambda _: float("inf"))

    return lower_bounds, upper_bounds

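# - Illustrative sketch (hypothetical parameter names): build a bounds template
#   with `make_bounds`, then tighten only the entries you care about. Entries
#   left at -inf / inf are never penalized by `bounds_cost` below.
if __name__ == "__main__":
    example_params = {
        "tau_mem": torch.tensor([2e-2, 5e-1]),
        "threshold": torch.tensor([1.0]),
    }
    example_lower, example_upper = make_bounds(example_params)

    # - Constrain only "tau_mem"; "threshold" keeps its (-inf, inf) bounds
    example_lower["tau_mem"] = 1e-3
    example_upper["tau_mem"] = 1e-1
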
def bounds_cost(params: dict, lower_bounds: dict, upper_bounds: dict) -> torch.Tensor:
"""
Impose a cost on parameters that violate bounds constraints

This function works hand-in-hand with :py:func:.make_bounds to enforce greater-than and less-than constraints on parameter values. This is designed to be used as a component of a loss function, to ensure parameter values fall in a reasonable range.

:py:func:.bounds_cost imposes a value of 1.0 for each parameter element that exceeds a bound infinitesimally, increasing exponentially as the bound is exceeded, or 0.0 for each parameter within the bounds. You will most likely want to scale this by a penalty factor within your cost function.

Warnings:
:py:func:.bounds_cost does **not** clip parameters to the bounds. It is possible for parameters to exceed the bounds during optimisation. If this must be prevented, you should clip the parameters explicitly.

See :ref:/in-depth/jax-training.ipynb for examples for using :py:func:.make_bounds and :py:func:.bounds_cost.

Args:
params (dict): A dictionary of parameters over which to impose bounds
lower_bounds (dict): A dictionary of lower bounds for parameters matching your model, modified from that returned by :py:func:.make_bounds
upper_bounds (dict): A dictionary of upper bounds for parameters matching your model, modified from that returned by :py:func:.make_bounds

Returns:
float: The cost to include in the cost function.
"""
    # - Flatten all parameter dicts
    params, tree_def_params = tu.tree_flatten(params)
    lower_bounds, tree_def_minparams = tu.tree_flatten(lower_bounds)
    upper_bounds, tree_def_maxparams = tu.tree_flatten(upper_bounds)

    # - Check all three structures flattened to the same number of leaves
    if not (len(params) == len(lower_bounds) == len(upper_bounds)):
        raise KeyError(
            "lower_bounds and upper_bounds must have the same keys as params."
        )

    # - Define a bounds function
    def bound(p, lower, upper):
        lb_cost_all = torch.exp(-(p - lower))
        ub_cost_all = torch.exp(-(upper - p))

        lb_cost = torch.sum(lb_cost_all[p < lower])
        ub_cost = torch.sum(ub_cost_all[p > upper])

        return lb_cost + ub_cost

    # - Map the bounds function over parameters and return the summed cost
    return sum(map(bound, params, lower_bounds, upper_bounds))
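
# - Continuing the sketch above: the cost is zero while all parameters respect
#   their bounds, and grows exponentially with the size of any violation. You
#   would typically scale it by a penalty factor inside your own loss function.
if __name__ == "__main__":
    cost = bounds_cost(example_params, example_lower, example_upper)

    # - Only the 5e-1 entry of "tau_mem" exceeds its upper bound of 1e-1,
    #   so cost == exp(-(1e-1 - 5e-1)) == exp(0.4)
    print(cost)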