Module training.jax_loss

Jax functions useful for training networks using Jax Modules.

See also

See 🏃🏽‍♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.

Functions overview

bounds_clip(params, lower_bounds, upper_bounds)

Clip parameters to bounds constraints

bounds_cost(params, lower_bounds, upper_bounds)

Impose a cost on parameters that violate bounds constraints

l0_norm_approx(params[, sigma])

Compute a smooth differentiable approximation to the L0-norm

l2sqr_norm(params)

Compute the mean L2-squared-norm of the set of parameters

logsoftmax(x[, temperature])

Efficient implementation of the log softmax function

make_bounds(params)

Convenience function to build a bounds template for a problem

mse(output, target)

Compute the mean-squared error between output and target

softmax(x[, temperature])

Implements the softmax function

sse(output, target)

Compute the sum-squared error between output and target

Functions

training.jax_loss.bounds_clip(params: dict, lower_bounds: dict, upper_bounds: dict) → dict [source]

Clip parameters to bounds constraints

training.jax_loss.bounds_cost(params: dict, lower_bounds: dict, upper_bounds: dict) → float [source]

Impose a cost on parameters that violate bounds constraints

This function works hand-in-hand with make_bounds() to enforce greater-than and less-than constraints on parameter values. This is designed to be used as a component of a loss function, to ensure parameter values fall in a reasonable range.

bounds_cost() imposes a cost of 1.0 for each parameter element that exceeds a bound infinitesimally, increasing exponentially as the bound is exceeded further, and 0.0 for each parameter element within the bounds. You will most likely want to scale this cost by a penalty factor within your loss function.

Warning

bounds_cost() does not clip parameters to the bounds; it is possible for parameters to exceed the bounds during optimisation. If this must be prevented, you should clip the parameters explicitly, for example with bounds_clip().

Parameters:
  • params (dict) – A dictionary of parameters over which to impose bounds

  • lower_bounds (dict) – A dictionary of lower bounds for parameters matching your model, modified from that returned by make_bounds()

  • upper_bounds (dict) – A dictionary of upper bounds for parameters matching your model, modified from that returned by make_bounds()

Returns:

The cost to include in the cost function.

Return type:

float
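
As a minimal sketch of how this could be used inside a loss function (assuming a rockpool.training.jax_loss import path, a plain parameter dictionary standing in for Module.parameters(), and arbitrary illustrative bound values and penalty factor):

    import jax.numpy as jnp
    from rockpool.training import jax_loss as jl

    # Toy parameter dictionary standing in for Module.parameters()
    params = {"tau": jnp.array([0.01, 0.05, 0.2]), "bias": jnp.array([0.0, -0.1, 0.3])}

    # Build an unconstrained bounds template, then constrain "tau" only
    lower_bounds, upper_bounds = jl.make_bounds(params)
    lower_bounds["tau"] = 1e-3
    upper_bounds["tau"] = 0.1

    # Scale the bounds cost by a penalty factor when including it in a loss
    penalty = 10.0
    loss_term = penalty * jl.bounds_cost(params, lower_bounds, upper_bounds)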

training.jax_loss.l0_norm_approx(params: dict, sigma: float = 0.0001) → float [source]

Compute a smooth differentiable approximation to the L0-norm

The \(L_0\) norm counts the number of non-zero elements in a vector, and is therefore a measure of its sparsity. This function computes a smooth approximation to the \(L_0\) norm, for use as a component in cost functions. Including this cost will encourage parameter sparsity, by penalising non-zero parameters.

The approximation is given by

\[L_0(x) = \frac{x^4}{x^4 + \sigma}\]

where \(\sigma\) is a small regularisation value (by default 1e-4).

References

Wei et al. 2018. “Gradient Projection with Approximate L0 Norm Minimization for Sparse Reconstruction in Compressed Sensing”, Sensors 18 (3373). doi: 10.3390/s18103373

Parameters:
  • params (dict) – A parameter dictionary over which to compute the L_0 norm

  • sigma (float) – A small value to use as a regularisation parameter. Default: 1e-4.

Returns:

The estimated L_0 norm cost

Return type:

float
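
For intuition, the per-element approximation above can be written directly in jax.numpy. This is an illustration of the formula only; the exact reduction over the parameter dictionary used by jax_loss may differ (here the element-wise terms are simply summed):

    import jax.numpy as jnp

    def l0_norm_approx_sketch(params: dict, sigma: float = 1e-4) -> float:
        # Each element contributes x**4 / (x**4 + sigma): ~0 for x near zero, ~1 otherwise
        total = 0.0
        for value in params.values():
            x4 = jnp.asarray(value) ** 4
            total = total + jnp.sum(x4 / (x4 + sigma))
        return total

    params = {"w": jnp.array([0.0, 0.5, -2.0]), "b": jnp.array([1e-3, 0.0])}
    print(l0_norm_approx_sketch(params))  # ~2.0: two elements are clearly non-zero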

training.jax_loss.l2sqr_norm(params: dict) → float [source]

Compute the mean L2-squared-norm of the set of parameters

This function computes the mean \(L_2^2\) norm of each parameter. The gradient of \(L_2^2(x)\) is defined everywhere, whereas the gradient of \(L_2(x)\) is not defined at \(x = 0\).

The function is given by

\[L_2^2(x) = E[x^2]\]

where \(E[\cdot]\) is the expectation of the expression within the brackets.

Parameters:

params (dict) – A Rockpool parameter dictionary

Returns:

The mean L2-sqr-norm of all parameters, computed individually for each parameter

Return type:

float
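
A sketch of the definition above, computing \(E[x^2]\) per parameter and then averaging across parameters (the exact reduction across parameters used by jax_loss may differ):

    import jax.numpy as jnp

    def l2sqr_norm_sketch(params: dict) -> float:
        # Mean of the squared values, computed per parameter; differentiable at x == 0
        per_param = [jnp.mean(jnp.asarray(p) ** 2) for p in params.values()]
        return jnp.mean(jnp.stack(per_param))

    params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array([0.0, 2.0])}
    print(l2sqr_norm_sketch(params))  # (1.0 + 2.0) / 2 = 1.5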

training.jax_loss.logsoftmax(x: Array, temperature: float = 1.0) → Array [source]

Efficient implementation of the log softmax function

\[\log S(x, \tau) = (l / \tau) - \log \Sigma \exp(l / \tau), \qquad l = x - \max(x)\]

Parameters:
  • x (np.ndarray) – Input vector of scores

  • temperature (float) – Temperature \(\tau\) of the softmax. As \(\tau \rightarrow 0\), the function becomes a hard \(\max\) operation. Default: 1.0.

Returns:

The output of the logsoftmax.

Return type:

np.ndarray
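
The following sketch reproduces the formula above in plain jax.numpy, to illustrate the max-subtraction that keeps the computation numerically stable; the library implementation may differ in detail:

    import jax.numpy as jnp

    def logsoftmax_sketch(x, temperature: float = 1.0):
        # Shift by max(x) before exponentiating, for numerical stability
        l = x - jnp.max(x)
        return l / temperature - jnp.log(jnp.sum(jnp.exp(l / temperature)))

    scores = jnp.array([1.0, 2.0, 3.0])
    log_p = logsoftmax_sketch(scores)
    print(jnp.sum(jnp.exp(log_p)))  # 1.0: exponentiating recovers a probability distribution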

training.jax_loss.make_bounds(params: dict) → Tuple[dict, dict] [source]

Convenience function to build a bounds template for a problem

This function works hand-in-hand with bounds_cost() to enforce minimum and/or maximum parameter bounds. make_bounds() accepts a set of parameters (e.g. as returned from the Module.parameters() method), and returns ready-made dictionaries of lower and upper bounds (with no restrictions imposed by default).

make_bounds() returns two dictionaries, representing the lower and upper bounds respectively. Initially all entries will be set to -np.inf and np.inf, indicating that no bounds should be enforced. You must edit these dictionaries to set the bounds.

Parameters:

params (dict) – Dictionary of parameters defining an optimisation problem. This can be provided as the parameter dictionary returned by Module.parameters().

Returns:

lower_bounds, upper_bounds. Each dictionary mimics the structure of params, with initial bounds set to -np.inf and np.inf (i.e. no bounds enforced).

Return type:

(dict, dict)
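
A short sketch of the template-editing workflow (assuming a rockpool.training.jax_loss import path; the parameter names are placeholders for whatever your model's parameter dictionary contains):

    import jax.numpy as jnp
    from rockpool.training import jax_loss as jl

    params = {"weights": jnp.zeros((2, 3)), "threshold": jnp.array([1.0, 1.0])}

    # Both returned dictionaries mimic the structure of `params`,
    # initially filled with -inf / +inf (i.e. no bounds enforced)
    lower_bounds, upper_bounds = jl.make_bounds(params)

    # Impose a one-sided constraint: thresholds must stay non-negative.
    # "weights" keeps its -inf / +inf entries and remains unconstrained.
    lower_bounds["threshold"] = 0.0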

training.jax_loss.mse(output: array, target: array) → float [source]

Compute the mean-squared error between output and target

This function is designed to be used as a component in a loss function. It computes the mean-squared error

\[\textrm{mse}(y, \hat{y}) = { E[{(y - \hat{y})^2}] }\]

where \(E[\cdot]\) is the expectation of the expression within the brackets.

Parameters:
  • output (np.ndarray) – The network output to test, with shape (T, N)

  • target (np.ndarray) – The target output, with shape (T, N)

Returns:

The mean-squared-error cost

Return type:

float
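
As a sketch of how mse() might be combined with a regulariser into a single scalar training loss (the import path and regularisation weight are illustration assumptions):

    import jax.numpy as jnp
    from rockpool.training import jax_loss as jl

    def loss_sketch(output, target, params, reg_l2: float = 1e-3) -> float:
        # Task loss (MSE over time bins and channels) plus an L2-squared parameter penalty
        return jl.mse(output, target) + reg_l2 * jl.l2sqr_norm(params)

    T, N = 100, 4
    output = jnp.zeros((T, N))
    target = jnp.ones((T, N))
    params = {"w": jnp.array([0.5, -0.5])}
    print(loss_sketch(output, target, params))  # ~1.0: dominated by the MSE term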

training.jax_loss.softmax(x: Array, temperature: float = 1.0) → Array [source]

Implements the softmax function

\[S(x, \tau) = \frac{\exp(l / \tau)}{\Sigma \exp(l / \tau)}, \qquad l = x - \max(x)\]

Parameters:
  • x (np.ndarray) – Input vector of scores

  • temperature (float) – Temperature \(\tau\) of the softmax. As \(\tau \rightarrow 0\), the function becomes a hard \(\max\) operation. Default: 1.0.

Returns:

The output of the softmax.

Return type:

np.ndarray
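
A brief usage sketch showing the effect of the temperature argument (assuming a rockpool.training.jax_loss import path; comments describe the qualitative behaviour):

    import jax.numpy as jnp
    from rockpool.training import jax_loss as jl

    scores = jnp.array([1.0, 2.0, 4.0])

    print(jl.softmax(scores))                     # smooth distribution that sums to 1
    print(jl.softmax(scores, temperature=0.1))    # low temperature: close to one-hot at the max
    print(jl.softmax(scores, temperature=100.0))  # high temperature: close to uniform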

training.jax_loss.sse(output: array, target: array) → float [source]

Compute the sum-squared error between output and target

This function is designed to be used as a component in a loss function. It computes the sum-squared error

\[\textrm{sse}(y, \hat{y}) = \Sigma (y - \hat{y})^2\]

Parameters:
  • output (np.ndarray) – The network output to test, with shape (T, N)

  • target (np.ndarray) – The target output, with shape (T, N)

Returns:

The sum-squared-error cost

Return type:

float
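
Following the definitions above, the sum-squared error equals the mean-squared error scaled by the number of elements \(T \times N\). A small sketch (assuming a rockpool.training.jax_loss import path):

    import jax.numpy as jnp
    from rockpool.training import jax_loss as jl

    T, N = 50, 3
    output = jnp.zeros((T, N))
    target = 2.0 * jnp.ones((T, N))

    print(jl.sse(output, target))          # sum of squared errors: 50 * 3 * 4 = 600
    print(T * N * jl.mse(output, target))  # identical by the definitions above: 600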