# Module training.jax_loss

Jax functions useful for training networks using Jax Modules.

See 🏃🏽‍♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.

## Functions overview

| Function | Description |
| --- | --- |
| `bounds_cost(params, lower_bounds, upper_bounds)` | Impose a cost on parameters that violate bounds constraints |
| `l0_norm_approx(params[, sigma])` | Compute a smooth differentiable approximation to the L0-norm |
| `l2sqr_norm(params)` | Compute the mean L2-squared-norm of the set of parameters |
| `logsoftmax(x[, temperature])` | Efficient implementation of the log softmax function |
| `make_bounds(params)` | Convenience function to build a bounds template for a problem |
| `mse(output, target)` | Compute the mean-squared error between output and target |
| `softmax(x[, temperature])` | Implements the softmax function |
| `sse(output, target)` | Compute the sum-squared error between output and target |

## Functions

training.jax_loss.bounds_cost(params: dict, lower_bounds: dict, upper_bounds: dict) → float

Impose a cost on parameters that violate bounds constraints

This function works hand-in-hand with make_bounds() to enforce greater-than and less-than constraints on parameter values. This is designed to be used as a component of a loss function, to ensure parameter values fall in a reasonable range.

bounds_cost() imposes a value of 1.0 for each parameter element that exceeds a bound infinitesimally, increasing exponentially as the bound is exceeded, or 0.0 for each parameter within the bounds. You will most likely want to scale this by a penalty factor within your cost function.

Warning

bounds_cost() does not clip parameters to the bounds. It is possible for parameters to exceed the bounds during optimisation. If this must be prevented, you should clip the parameters explicitly.

Parameters
• params (dict) – A dictionary of parameters over which to impose bounds

• lower_bounds (dict) – A dictionary of lower bounds for parameters matching your model, modified from that returned by make_bounds()

• upper_bounds (dict) – A dictionary of upper bounds for parameters matching your model, modified from that returned by make_bounds()

Returns

The cost to include in the cost function.

Return type

float
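The documented behaviour can be sketched in plain NumPy (a toy re-implementation for illustration, not the Rockpool code; in particular, the exact exponential growth law beyond the bound is an assumption here):

```python
import numpy as np

def bounds_cost_sketch(params, lower_bounds, upper_bounds):
    """Toy sketch of the documented behaviour: 0.0 for elements inside the
    bounds, 1.0 at an infinitesimal violation, growing exponentially with
    the size of the violation (exp(excess) is assumed as the growth law)."""
    cost = 0.0
    for key, value in params.items():
        v = np.asarray(value, dtype=float)
        below = np.clip(lower_bounds[key] - v, 0.0, None)  # distance below the lower bound
        above = np.clip(v - upper_bounds[key], 0.0, None)  # distance above the upper bound
        cost += float(np.sum(np.where(below > 0.0, np.exp(below), 0.0)))
        cost += float(np.sum(np.where(above > 0.0, np.exp(above), 0.0)))
    return cost
```

In a loss function you would typically scale this term by a penalty factor before adding it to the task loss.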

training.jax_loss.l0_norm_approx(params: dict, sigma: float = 0.0001) → float

Compute a smooth differentiable approximation to the L0-norm

The $$L_0$$ norm estimates the sparsity of a vector – i.e. the number of non-zero elements. This function computes a smooth approximation to the $$L_0$$ norm, for use as a component in cost functions. Including this cost will encourage parameter sparsity, by penalising non-zero parameters.

The approximation is given by

$L_0(x) = \frac{x^4}{x^4 + \sigma}$

where $$\sigma$$ is a small regularisation value (by default 1e-4).

References

Wei et al. 2018. "Gradient Projection with Approximate L0 Norm Minimization for Sparse Reconstruction in Compressed Sensing", Sensors 18 (3373). doi: 10.3390/s18103373

Parameters
• params (dict) – A parameter dictionary over which to compute the L_0 norm

• sigma (float) – A small value to use as a regularisation parameter. Default: 1e-4.

Returns

The estimated L_0 norm cost

Return type

float
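The approximation above can be checked with a short NumPy sketch. Each element contributes a term near 1 when it is far from zero and near 0 when it is zero; whether the library sums or averages these terms is an assumption here (the sketch sums, so the result estimates the count of non-zero elements):

```python
import numpy as np

def l0_norm_approx_sketch(params, sigma=1e-4):
    """Smooth L0 approximation: sum of x^4 / (x^4 + sigma) over all elements."""
    total = 0.0
    for value in params.values():
        x = np.asarray(value, dtype=float)
        total += float(np.sum(x**4 / (x**4 + sigma)))
    return total
```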

training.jax_loss.l2sqr_norm(params: dict) → float

Compute the mean L2-squared-norm of the set of parameters

This function computes the mean $$L_2^2$$ norm of each parameter. The gradient of $$L_2^2(x)$$ is defined everywhere, where the gradient of $$L_2(x)$$ is not defined at $$x = 0$$.

The function is given by

$L_2^2(x) = E[x^2]$

where $$E[\cdot]$$ is the expectation of the expression within the brackets.

Parameters

params (dict) – A Rockpool parameter dictionary

Returns

The mean L2-squared norm, computed individually for each parameter and then averaged over all parameters

Return type

float
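A minimal NumPy sketch of this computation, assuming $$E[x^2]$$ is taken per parameter and the results are then averaged over parameters:

```python
import numpy as np

def l2sqr_norm_sketch(params):
    """Mean of E[x^2] over the parameters in the dictionary."""
    per_param = [float(np.mean(np.asarray(v, dtype=float) ** 2)) for v in params.values()]
    return float(np.mean(per_param))
```

Using the squared norm keeps the gradient defined at zero, so sparse parameters do not produce undefined gradients during optimisation.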

training.jax_loss.logsoftmax(x: jax._src.numpy.lax_numpy.ndarray, temperature: float = 1.0) → jax._src.numpy.lax_numpy.ndarray

Efficient implementation of the log softmax function

$\log S(x, \tau) = (l / \tau) - \log \Sigma \exp (l / \tau)$

$l = x - \max (x)$

Parameters
• x (np.ndarray) – Input vector of scores

• temperature (float) – Temperature $$\tau$$ of the softmax. As $$\tau \rightarrow 0$$, the function becomes a hard $$\max$$ operation. Default: 1.0.

Returns

The output of the logsoftmax.

Return type

np.ndarray
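The two equations above translate directly into a NumPy sketch (subtracting $$\max(x)$$ first keeps the exponentials numerically stable; this is an illustration, not the library implementation):

```python
import numpy as np

def logsoftmax_sketch(x, temperature=1.0):
    """Numerically stable log softmax: shift by max(x), scale by temperature."""
    l = np.asarray(x, dtype=float) - np.max(x)
    z = l / temperature
    return z - np.log(np.sum(np.exp(z)))
```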

training.jax_loss.make_bounds(params: dict) → Tuple[dict, dict]

Convenience function to build a bounds template for a problem

This function works hand-in-hand with bounds_cost(), to enforce minimum and/or maximum parameter bounds. make_bounds() accepts a set of parameters (e.g. as returned from the Module.parameters() method), and returns a ready-made dictionary of bounds (with no restrictions by default).

make_bounds() returns two dictionaries, representing the lower and upper bounds respectively. Initially all entries will be set to -np.inf and np.inf, indicating that no bounds should be enforced. You must edit these dictionaries to set the bounds.

Parameters

params (dict) – Dictionary of parameters defining an optimisation problem. This can be provided as the parameter dictionary returned by Module.parameters().

Returns

lower_bounds, upper_bounds. Each dictionary mimics the structure of params, with initial bounds set to -np.inf and np.inf (i.e. no bounds enforced).

Return type

(dict, dict)
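A minimal sketch of the returned template, assuming a flat parameter dictionary (the real function mirrors the full structure of params; the `"tau"` parameter and its constraint below are hypothetical):

```python
import numpy as np

def make_bounds_sketch(params):
    """One (-inf, +inf) pair per parameter: no bounds enforced by default."""
    lower_bounds = {key: -np.inf for key in params}
    upper_bounds = {key: np.inf for key in params}
    return lower_bounds, upper_bounds

# Edit the returned dictionaries to impose constraints, e.g.:
lower, upper = make_bounds_sketch({"tau": np.array([0.02, 0.05])})
lower["tau"] = 0.001  # hypothetical constraint: time constants must stay positive
```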

training.jax_loss.mse(output: jax._src.numpy.lax_numpy.array, target: jax._src.numpy.lax_numpy.array) → float

Compute the mean-squared error between output and target

This function is designed to be used as a component in a loss function. It computes the mean-squared error

$\textrm{mse}(y, \hat{y}) = { E[{(y - \hat{y})^2}] }$

where $$E[\cdot]$$ is the expectation of the expression within the brackets.

Parameters
• output (np.ndarray) – The network output to test, with shape (T, N)

• target (np.ndarray) – The target output, with shape (T, N)

Returns

The mean-squared-error cost

Return type

float
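The formula reduces to a one-liner; a NumPy sketch for illustration:

```python
import numpy as np

def mse_sketch(output, target):
    """Mean of the element-wise squared differences."""
    return float(np.mean((np.asarray(output, dtype=float) - np.asarray(target, dtype=float)) ** 2))
```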

training.jax_loss.softmax(x: jax._src.numpy.lax_numpy.ndarray, temperature: float = 1.0) → jax._src.numpy.lax_numpy.ndarray

Implements the softmax function

$S(x, \tau) = \exp(l / \tau) \, / \, \Sigma \exp(l / \tau)$

$l = x - \max(x)$

Parameters
• x (np.ndarray) – Input vector of scores

• temperature (float) – Temperature $$\tau$$ of the softmax. As $$\tau \rightarrow 0$$, the function becomes a hard $$\max$$ operation. Default: 1.0.

Returns

The output of the softmax.

Return type

np.ndarray
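The equations above can be sketched in NumPy (again, an illustration of the formula, not the library code); note how a small temperature concentrates almost all of the mass on the largest score:

```python
import numpy as np

def softmax_sketch(x, temperature=1.0):
    """Numerically stable softmax: shift by max(x), scale by temperature."""
    l = np.asarray(x, dtype=float) - np.max(x)
    e = np.exp(l / temperature)
    return e / np.sum(e)
```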

training.jax_loss.sse(output: jax._src.numpy.lax_numpy.array, target: jax._src.numpy.lax_numpy.array) → float

Compute the sum-squared error between output and target

This function is designed to be used as a component in a loss function. It computes the sum-squared error

$\textrm{sse}(y, \hat{y}) = \Sigma {(y - \hat{y})^2}$
Parameters
• output (np.ndarray) – The network output to test, with shape (T, N)

• target (np.ndarray) – The target output, with shape (T, N)

Returns

The sum-squared-error cost

Return type

float
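As with mse(), the formula is a one-liner; a NumPy sketch for illustration (the only difference is summing rather than averaging the squared errors):

```python
import numpy as np

def sse_sketch(output, target):
    """Sum of the element-wise squared differences."""
    return float(np.sum((np.asarray(output, dtype=float) - np.asarray(target, dtype=float)) ** 2))
```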