Module training.jax_loss
Jax functions useful for training networks using Jax Modules.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.
Functions overview
| Function | Description |
|---|---|
| bounds_cost() | Impose a cost on parameters that violate bounds constraints |
| l0_norm_approx() | Compute a smooth differentiable approximation to the L0-norm |
| l2sqr_norm() | Compute the mean L2-squared-norm of the set of parameters |
| logsoftmax() | Efficient implementation of the log softmax function |
| make_bounds() | Convenience function to build a bounds template for a problem |
| mse() | Compute the mean-squared error between output and target |
| softmax() | Implements the softmax function |
| sse() | Compute the sum-squared error between output and target |
Functions
- training.jax_loss.bounds_cost(params: dict, lower_bounds: dict, upper_bounds: dict) float [source]
Impose a cost on parameters that violate bounds constraints
This function works hand-in-hand with make_bounds() to enforce greater-than and less-than constraints on parameter values. It is designed to be used as a component of a loss function, to ensure parameter values fall in a reasonable range.
bounds_cost() imposes a cost of 1.0 for each parameter element that exceeds a bound infinitesimally, increasing exponentially as the bound is exceeded further, and 0.0 for each parameter element within the bounds. You will most likely want to scale this cost by a penalty factor within your loss function (see the sketch after this entry).
Warning
bounds_cost() does not clip parameters to the bounds. It is possible for parameters to exceed the bounds during optimisation. If this must be prevented, you should clip the parameters explicitly.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for examples of using make_bounds() and bounds_cost().
- Parameters:
params (dict) – A dictionary of parameters over which to impose bounds
lower_bounds (dict) – A dictionary of lower bounds for parameters matching your model, modified from that returned by make_bounds()
upper_bounds (dict) – A dictionary of upper bounds for parameters matching your model, modified from that returned by make_bounds()
- Returns:
The cost to include in the cost function.
- Return type:
float
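As a sketch of how these pieces fit together (not taken from the Rockpool documentation; the import path, parameter names and penalty factor are assumptions), the following builds a bounds template, constrains one parameter, and scales the bounds cost inside a loss function:

```python
import jax.numpy as jnp
from rockpool.training import jax_loss as jl

# In practice this dictionary would come from a Jax-backed module's .parameters()
params = {"tau": jnp.array([0.02, -0.01, 0.05]), "w": jnp.array([0.3, 1.2])}

# Template bounds: every entry starts unconstrained at -inf / +inf
lower, upper = jl.make_bounds(params)
lower["tau"] = 1e-3  # e.g. require all elements of "tau" to exceed 1 ms

def loss(params, output, target, penalty: float = 10.0) -> float:
    # Task loss plus a scaled cost for any parameter elements outside their bounds
    return jl.mse(output, target) + penalty * jl.bounds_cost(params, lower, upper)
```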
- training.jax_loss.l0_norm_approx(params: dict, sigma: float = 0.0001) float [source]
Compute a smooth differentiable approximation to the L0-norm
The \(L_0\) norm estimates the sparsity of a vector – i.e. the number of non-zero elements. This function computes a smooth approximation to the \(L_0\) norm, for use as a component in cost functions. Including this cost will encourage parameter sparsity, by penalising non-zero parameters.
The approximation is given by
\[L_0(x) = \frac{x^4}{x^4 + \sigma}\]
where \(\sigma\) is a small regularisation value (by default 1e-4).
References
Wei et al. 2018. “Gradient Projection with Approximate L0 Norm Minimization for Sparse Reconstruction in Compressed Sensing”, Sensors 18 (3373). doi: 10.3390/s18103373
- Parameters:
params (dict) – A parameter dictionary over which to compute the L_0 norm
sigma (float) – A small value to use as a regularisation parameter. Default: 1e-4.
- Returns:
The estimated L_0 norm cost
- Return type:
float
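As a hedged illustration of the approximation itself (a stand-alone sketch operating on a single array, rather than the library function, which takes a parameter dictionary), the expression is smooth and therefore differentiable everywhere:

```python
import jax
import jax.numpy as jnp

def l0_smooth(x, sigma: float = 1e-4):
    # x**4 / (x**4 + sigma): close to 0 for near-zero elements, close to 1 otherwise
    return jnp.sum(x ** 4 / (x ** 4 + sigma))

x = jnp.array([0.0, 0.01, 0.5, -2.0])
print(l0_smooth(x))            # roughly counts the elements that are far from zero
print(jax.grad(l0_smooth)(x))  # the gradient is well-defined, even at exactly zero
```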
- training.jax_loss.l2sqr_norm(params: dict) float [source]
Compute the mean L2-squared-norm of the set of parameters
This function computes the mean \(L_2^2\) norm of each parameter. The gradient of \(L_2^2(x)\) is defined everywhere, whereas the gradient of \(L_2(x)\) is not defined at \(x = 0\).
The function is given by
\[L_2^2(x) = E[x^2]\]
where \(E[\cdot]\) is the expectation of the expression within the brackets.
- Parameters:
params (dict) – A Rockpool parameter dictionary
- Returns:
The mean L2-sqr-norm of all parameters, computed individually for each parameter
- Return type:
float
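A small stand-alone sketch (not the library function, which operates on a parameter dictionary) shows why the squared norm is preferred for gradient-based optimisation:

```python
import jax
import jax.numpy as jnp

l2sqr = lambda x: jnp.mean(x ** 2)         # E[x^2]
l2 = lambda x: jnp.sqrt(jnp.mean(x ** 2))  # sqrt(E[x^2])

zeros = jnp.zeros(4)
print(jax.grad(l2sqr)(zeros))  # [0. 0. 0. 0.] -- gradient defined at zero
print(jax.grad(l2)(zeros))     # [nan nan nan nan] -- gradient undefined at zero
```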
- training.jax_loss.logsoftmax(x: Array, temperature: float = 1.0) Array [source]
Efficient implementation of the log softmax function
\[\log S(x, \tau) = l / \tau - \log \sum \exp(l / \tau), \qquad l = x - \max(x)\]
- Parameters:
x (np.ndarray) – Input vector of scores
temperature (float) – Temperature \(\tau\) of the softmax. As \(\tau \rightarrow 0\), the function becomes a hard \(\max\) operation. Default: 1.0.
- Returns:
The output of the logsoftmax.
- Return type:
np.ndarray
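As a sketch of the expression above (a re-implementation under stated assumptions, not the library code), subtracting \(\max(x)\) before exponentiating keeps the computation numerically stable, and the normaliser is subtracted in log space rather than divided:

```python
import jax.numpy as jnp

def logsoftmax_sketch(x, temperature: float = 1.0):
    l = x - jnp.max(x)  # shift scores for numerical stability
    return l / temperature - jnp.log(jnp.sum(jnp.exp(l / temperature)))

scores = jnp.array([2.0, 1.0, 0.1])
log_p = logsoftmax_sketch(scores)
print(jnp.exp(log_p).sum())  # ~1.0: exponentiating recovers a normalised distribution
```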
- training.jax_loss.make_bounds(params: dict) Tuple[dict, dict] [source]
Convenience function to build a bounds template for a problem
This function works hand-in-hand with bounds_cost() to enforce minimum and/or maximum parameter bounds. make_bounds() accepts a set of parameters (e.g. as returned from the Module.parameters() method), and returns a ready-made dictionary of bounds (with no restrictions by default).
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for examples of using make_bounds() and bounds_cost().
make_bounds() returns two dictionaries, representing the lower and upper bounds respectively. Initially all entries are set to -np.inf and np.inf, indicating that no bounds should be enforced. You must edit these dictionaries to set the bounds (see the sketch after this entry).
- Parameters:
params (dict) – Dictionary of parameters defining an optimisation problem. This can be provided as the parameter dictionary returned by Module.parameters().
- Returns:
lower_bounds, upper_bounds. Each dictionary mimics the structure of params, with initial bounds set to -np.inf and np.inf (i.e. no bounds enforced).
- Return type:
(dict, dict)
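A short sketch (import path assumed) of the template structure; nothing is constrained until the returned dictionaries are edited:

```python
import jax.numpy as jnp
from rockpool.training import jax_loss as jl

params = {"tau_mem": jnp.array([0.02, 0.03]), "bias": jnp.zeros(2)}
lower, upper = jl.make_bounds(params)

print(lower)  # mirrors `params`, with every entry at -inf (no lower bounds yet)
print(upper)  # mirrors `params`, with every entry at +inf (no upper bounds yet)

lower["tau_mem"] = 1e-3  # only after editing is a constraint actually imposed
```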
- training.jax_loss.mse(output: array, target: array) float [source]
Compute the mean-squared error between output and target
This function is designed to be used as a component in a loss function. It computes the mean-squared error
\[\textrm{mse}(y, \hat{y}) = E[(y - \hat{y})^2]\]
where \(E[\cdot]\) is the expectation of the expression within the brackets.
- Parameters:
output (np.ndarray) – The network output to test, with shape (T, N)
target (np.ndarray) – The target output, with shape (T, N)
- Returns:
The mean-squared-error cost
- Return type:
float
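A hedged sketch of mse as one term in a differentiable loss (the import path and the regularisation weight are assumptions; in practice output would be produced by evolving a module with params):

```python
import jax
import jax.numpy as jnp
from rockpool.training import jax_loss as jl

def loss_fn(params, output, target):
    # Task term plus a small L2^2 regulariser over the parameters
    return jl.mse(output, target) + 1e-4 * jl.l2sqr_norm(params)

params = {"w": jnp.ones((2, 2))}
output = jnp.zeros((10, 2))  # (T, N) network output (here a placeholder array)
target = jnp.ones((10, 2))   # (T, N) target signal

# Scalar cost and gradients with respect to the parameter dictionary
cost, grads = jax.value_and_grad(loss_fn)(params, output, target)
```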
- training.jax_loss.softmax(x: Array, temperature: float = 1.0) Array [source]
Implements the softmax function
\[S(x, \tau) = \frac{\exp(l / \tau)}{\sum \exp(l / \tau)}, \qquad l = x - \max(x)\]
- Parameters:
x (np.ndarray) – Input vector of scores
temperature (float) – Temperature \(\tau\) of the softmax. As \(\tau \rightarrow 0\), the function becomes a hard \(\max\) operation. Default: 1.0.
- Returns:
The output of the softmax.
- Return type:
np.ndarray
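A brief sketch of the temperature behaviour (import path assumed): as \(\tau\) shrinks, the output approaches a one-hot \(\max\):

```python
import jax.numpy as jnp
from rockpool.training import jax_loss as jl

scores = jnp.array([2.0, 1.0, 0.1])
print(jl.softmax(scores))                    # a soft distribution over the scores
print(jl.softmax(scores, temperature=0.05))  # nearly all mass on the largest score
```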
- training.jax_loss.sse(output: array, target: array) float [source]
Compute the sum-squared error between output and target
This function is designed to be used as a component in a loss function. It computes the sum-squared error
\[\textrm{sse}(y, \hat{y}) = \sum (y - \hat{y})^2\]
- Parameters:
output (np.ndarray) – The network output to test, with shape (T, N)
target (np.ndarray) – The target output, with shape (T, N)
- Returns:
The sum-squared-error cost
- Return type:
float