Module training.jax_loss
Jax functions useful for training networks using Jax Modules.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.
Functions overview

bounds_cost – Impose a cost on parameters that violate bounds constraints
l0_norm_approx – Compute a smooth differentiable approximation to the L0-norm
l2sqr_norm – Compute the mean L2-squared-norm of the set of parameters
logsoftmax – Efficient implementation of the log softmax function
make_bounds – Convenience function to build a bounds template for a problem
mse – Compute the mean-squared error between output and target
softmax – Implements the softmax function
sse – Compute the sum-squared error between output and target
Functions
 training.jax_loss.bounds_cost(params: dict, lower_bounds: dict, upper_bounds: dict) → float [source]
Impose a cost on parameters that violate bounds constraints
This function works hand-in-hand with make_bounds() to enforce greater-than and less-than constraints on parameter values. It is designed to be used as a component of a loss function, to ensure parameter values fall in a reasonable range.
bounds_cost() imposes a cost of 1.0 for each parameter element that exceeds a bound infinitesimally, increasing exponentially as the bound is exceeded further, and 0.0 for each parameter element within its bounds. You will most likely want to scale this cost by a penalty factor within your loss function.
Warning
bounds_cost() does not clip parameters to the bounds. It is possible for parameters to exceed the bounds during optimisation. If this must be prevented, you should clip the parameters explicitly.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for examples of using make_bounds() and bounds_cost().
 Parameters
params (dict) – A dictionary of parameters over which to impose bounds
lower_bounds (dict) – A dictionary of lower bounds for parameters matching your model, modified from that returned by make_bounds()
upper_bounds (dict) – A dictionary of upper bounds for parameters matching your model, modified from that returned by make_bounds()
 Returns
The cost to include in the cost function.
 Return type
float
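The following sketch shows one way bounds_cost() might be combined with make_bounds() inside a loss function. The import path is assumed from this module's name, and the parameter names and the penalty factor of 1e4 are hypothetical choices for illustration, not values taken from the library.

```python
import jax.numpy as jnp
from rockpool.training import jax_loss as jl

# Hypothetical parameter dictionary; normally this would be net.parameters()
params = {"tau_mem": jnp.array([0.02, 0.05]), "w_rec": jnp.zeros((2, 2))}

# Build an unconstrained bounds template, then edit it
lower_bounds, upper_bounds = jl.make_bounds(params)
lower_bounds["tau_mem"] = 1e-3    # tau_mem must stay above 1 ms
upper_bounds["tau_mem"] = 100e-3  # ... and below 100 ms

def loss_with_bounds(params: dict, output, target) -> float:
    # Task loss plus a scaled penalty for any parameter outside its bounds
    task_loss = jl.mse(output, target)
    penalty = 1e4 * jl.bounds_cost(params, lower_bounds, upper_bounds)
    return task_loss + penalty
```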
 training.jax_loss.l0_norm_approx(params: dict, sigma: float = 0.0001) → float [source]
Compute a smooth differentiable approximation to the L0-norm
The \(L_0\) norm estimates the sparsity of a vector – i.e. the number of nonzero elements. This function computes a smooth approximation to the \(L_0\) norm, for use as a component in cost functions. Including this cost will encourage parameter sparsity, by penalising nonzero parameters.
The approximation is given by
\[L_0(x) = \frac{x^4}{x^4 + \sigma}\]
where \(\sigma\) is a small regularisation value (by default 1e-4).
References
Wei et al. 2018. “Gradient Projection with Approximate L0 Norm Minimization for Sparse Reconstruction in Compressed Sensing”, Sensors 18 (3373). doi: 10.3390/s18103373
 Parameters
params (dict) – A parameter dictionary over which to compute the L_0 norm
sigma (float) – A small value to use as a regularisation parameter. Default: 1e-4.
 Returns
The estimated L_0 norm cost
 Return type
float
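As an illustration only (not the library implementation), the documented approximation can be written in plain jax.numpy as below. Whether the per-element values are summed or averaged over the parameter dictionary is an assumption made here.

```python
import jax
import jax.numpy as jnp

def l0_norm_approx_sketch(params: dict, sigma: float = 1e-4) -> float:
    # x**4 / (x**4 + sigma) is ~0 for x == 0 and ~1 for |x| well above sigma**0.25
    per_param = jax.tree_util.tree_map(
        lambda x: jnp.sum(x**4 / (x**4 + sigma)), params
    )
    return sum(jax.tree_util.tree_leaves(per_param))

params = {"w": jnp.array([0.0, 0.5, -2.0])}
print(l0_norm_approx_sketch(params))  # ~2.0: two effectively non-zero elements
```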
 training.jax_loss.l2sqr_norm(params: dict) → float [source]
Compute the mean L2-squared-norm of the set of parameters
This function computes the mean \(L_2^2\) norm of each parameter. The gradient of \(L_2^2(x)\) is defined everywhere, whereas the gradient of \(L_2(x)\) is not defined at \(x = 0\).
The function is given by
\[L_2^2(x) = E[x^2]\]
where \(E[\cdot]\) is the expectation of the expression within the brackets.
 Parameters
params (dict) – A Rockpool parameter dictionary
 Returns
The mean L2-squared-norm of all parameters, computed individually for each parameter
 Return type
float
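A plain jax.numpy sketch of the mean \(L_2^2\) norm follows. The exact aggregation across parameters (here a simple mean of the per-parameter values) is an assumption for illustration, not the library implementation.

```python
import jax
import jax.numpy as jnp

def l2sqr_norm_sketch(params: dict) -> float:
    # E[x^2] for each parameter tensor, then averaged over parameters
    per_param = jax.tree_util.tree_map(lambda x: jnp.mean(x**2), params)
    leaves = jax.tree_util.tree_leaves(per_param)
    return sum(leaves) / len(leaves)

params = {"w_rec": jnp.array([[1.0, -1.0], [0.0, 2.0]]), "bias": jnp.zeros(2)}
print(l2sqr_norm_sketch(params))  # 0.75 for this example
```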
 training.jax_loss.logsoftmax(x: jax._src.numpy.lax_numpy.ndarray, temperature: float = 1.0) → jax._src.numpy.lax_numpy.ndarray [source]
Efficient implementation of the log softmax function
\[ \begin{align}\begin{aligned}\log S(x, \tau) = (l / \tau) - \log \Sigma { \exp (l / \tau) }\\l = x - \max (x)\end{aligned}\end{align} \]
 Parameters
x (np.ndarray) – Input vector of scores
temperature (float) – Temperature \(\tau\) of the softmax. As \(\tau \rightarrow 0\), the function becomes a hard \(\max\) operation. Default: 1.0.
 Returns
The output of the logsoftmax.
 Return type
np.ndarray
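For reference, the formula above corresponds to the following plain jax.numpy sketch (not the library implementation); subtracting \(\max(x)\) keeps the exponentials numerically stable.

```python
import jax.numpy as jnp

def logsoftmax_sketch(x: jnp.ndarray, temperature: float = 1.0) -> jnp.ndarray:
    l = x - jnp.max(x)                       # shift for numerical stability
    logits = l / temperature
    return logits - jnp.log(jnp.sum(jnp.exp(logits)))

scores = jnp.array([2.0, 1.0, 0.1])
print(logsoftmax_sketch(scores))        # log-probabilities; exp of these sum to 1
print(logsoftmax_sketch(scores, 0.1))   # low temperature approaches a hard max
```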
 training.jax_loss.make_bounds(params: dict) → Tuple[dict, dict] [source]
Convenience function to build a bounds template for a problem
This function works hand-in-hand with bounds_cost() to enforce minimum and/or maximum parameter bounds. make_bounds() accepts a set of parameters (e.g. as returned from the Module.parameters() method), and returns a ready-made dictionary of bounds (with no restrictions by default).
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for examples of using make_bounds() and bounds_cost().
make_bounds() returns two dictionaries, representing the lower and upper bounds respectively. Initially all entries will be set to -np.inf and np.inf, indicating that no bounds should be enforced. You must edit these dictionaries to set the bounds.
 Parameters
params (dict) – Dictionary of parameters defining an optimisation problem. This can be provided as the parameter dictionary returned by Module.parameters().
 Returns
lower_bounds, upper_bounds. Each dictionary mimics the structure of params, with initial bounds set to -np.inf and np.inf (i.e. no bounds enforced).
 Return type
(dict, dict)
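A minimal sketch of the intended workflow follows, assuming the module is importable as rockpool.training.jax_loss and that `net` is a Rockpool module; the submodule and parameter names used here are hypothetical.

```python
from rockpool.training import jax_loss as jl

# Both dictionaries mimic the structure of net.parameters();
# every entry starts unconstrained at -inf / +inf
lower_bounds, upper_bounds = jl.make_bounds(net.parameters())

# Constrain a hypothetical parameter "tau_syn" of a submodule named "lif"
lower_bounds["lif"]["tau_syn"] = 1e-3     # at least 1 ms
upper_bounds["lif"]["tau_syn"] = 100e-3   # at most 100 ms
```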
 training.jax_loss.mse(output: jax._src.numpy.lax_numpy.array, target: jax._src.numpy.lax_numpy.array) → float [source]
Compute the mean-squared error between output and target
This function is designed to be used as a component in a loss function. It computes the mean-squared error
\[\textrm{mse}(y, \hat{y}) = E[(y - \hat{y})^2]\]
where \(E[\cdot]\) is the expectation of the expression within the brackets.
 Parameters
output (np.ndarray) – The network output to test, with shape (T, N)
target (np.ndarray) – The target output, with shape (T, N)
 Returns
The mean-squared-error cost
 Return type
float
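As a reference point, the formula above is equivalent to the following one-line jax.numpy sketch; in practice the library function should be used directly.

```python
import jax.numpy as jnp

def mse_sketch(output: jnp.ndarray, target: jnp.ndarray) -> float:
    return jnp.mean((output - target) ** 2)

output = jnp.array([[0.0, 1.0], [1.0, 0.0]])   # shape (T, N)
target = jnp.array([[0.0, 1.0], [0.0, 1.0]])
print(mse_sketch(output, target))  # 0.5
```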
 training.jax_loss.softmax(x: jax._src.numpy.lax_numpy.ndarray, temperature: float = 1.0) → jax._src.numpy.lax_numpy.ndarray [source]
Implements the softmax function
\[ \begin{align}\begin{aligned}S(x, \tau) = \exp(l / \tau) / { \Sigma { \exp(l / \tau)} }\\l = x - \max(x)\end{aligned}\end{align} \]
 Parameters
x (np.ndarray) – Input vector of scores
temperature (float) – Temperature \(\tau\) of the softmax. As \(\tau \rightarrow 0\), the function becomes a hard \(\max\) operation. Default: 1.0.
 Returns
The output of the softmax.
 Return type
np.ndarray
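The formula above corresponds to this plain jax.numpy sketch (not the library implementation), which shifts by \(\max(x)\) for numerical stability.

```python
import jax.numpy as jnp

def softmax_sketch(x: jnp.ndarray, temperature: float = 1.0) -> jnp.ndarray:
    l = x - jnp.max(x)               # shift for numerical stability
    e = jnp.exp(l / temperature)
    return e / jnp.sum(e)

scores = jnp.array([2.0, 1.0, 0.1])
print(softmax_sketch(scores))        # probabilities summing to 1
print(softmax_sketch(scores, 0.1))   # low temperature gives a nearly one-hot result
```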
 training.jax_loss.sse(output: jax._src.numpy.lax_numpy.array, target: jax._src.numpy.lax_numpy.array) → float [source]
Compute the sum-squared error between output and target
This function is designed to be used as a component in a loss function. It computes the sum-squared error
\[\textrm{sse}(y, \hat{y}) = \Sigma {(y - \hat{y})^2}\]
 Parameters
output (np.ndarray) – The network output to test, with shape (T, N)
target (np.ndarray) – The target output, with shape (T, N)
 Returns
The sum-squared-error cost
 Return type
float
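For comparison with mse(), the sum-squared error omits the normalisation by the number of elements; a plain jax.numpy sketch of the formula above:

```python
import jax.numpy as jnp

def sse_sketch(output: jnp.ndarray, target: jnp.ndarray) -> float:
    return jnp.sum((output - target) ** 2)

output = jnp.array([[0.0, 1.0], [1.0, 0.0]])   # shape (T, N)
target = jnp.array([[0.0, 1.0], [0.0, 1.0]])
print(sse_sketch(output, target))  # 2.0
```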