Jax functions useful for training networks built from Jax-backed modules.

See also

See 🏃🏽‍♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.


bounds_cost(params, lower_bounds, upper_bounds)

Impose a cost on parameters that violate bounds constraints

l0_norm_approx(params[, sigma])

Compute a smooth differentiable approximation to the L0-norm


Compute the mean L2-squared-norm of the set of parameters

logsoftmax(x[, temperature])

Efficient implementation of the log softmax function


Convenience function to build a bounds template for a problem

mse(output, target)

Compute the mean-squared error between output and target

softmax(x[, temperature])

Compute the softmax function

sse(output, target)

Compute the sum-squared error between output and target
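The two bounds entries above, bounds_cost and the bounds-template helper, are meant to work together. A template starts every parameter as unconstrained, you tighten only the bounds you care about, and a penalty is added to the loss whenever a parameter leaves its allowed range. The sketch below is written directly in jax.numpy and is not Rockpool's implementation. The helper name make_bounds_template is hypothetical. The squared-hinge penalty, max(0, lower - p)^2 + max(0, p - upper)^2, is an assumed choice of cost, and Rockpool's own bounds_cost may use a different functional form.

```python
import jax
import jax.numpy as jnp


def make_bounds_template(params):
    """Hypothetical helper: unconstrained (-inf, +inf) bounds shaped like `params`."""
    lower = jax.tree_util.tree_map(lambda p: jnp.full_like(p, -jnp.inf), params)
    upper = jax.tree_util.tree_map(lambda p: jnp.full_like(p, jnp.inf), params)
    return lower, upper


def bounds_cost(params, lower_bounds, upper_bounds):
    """Sketch only: squared-hinge penalty on parameters outside [lower, upper]."""

    def leaf_cost(p, lo, hi):
        # Amount by which each element falls below its lower bound or above its
        # upper bound; both are zero inside the bounds (including infinite bounds).
        below = jnp.maximum(0.0, lo - p)
        above = jnp.maximum(0.0, p - hi)
        return jnp.sum(below**2 + above**2)

    # Evaluate the penalty leaf by leaf, then add up the per-leaf costs
    costs = jax.tree_util.tree_map(leaf_cost, params, lower_bounds, upper_bounds)
    return sum(jax.tree_util.tree_leaves(costs))


params = {"tau": jnp.array([0.02, -0.01, 0.05]), "w": jnp.array([0.5, 2.0])}
lower, upper = make_bounds_template(params)

# Tighten only the bounds we care about: positive time constants, weights <= 1
lower["tau"] = jnp.full_like(params["tau"], 1e-3)
upper["w"] = jnp.full_like(params["w"], 1.0)

cost = bounds_cost(params, lower, upper)  # 0.011**2 + 1.0**2 = 1.000121
```

During training, a term like this would simply be added to the task loss, so gradient descent pushes out-of-range parameters back toward their bounds. Because the penalty is zero inside the bounds, it leaves in-range parameters untouched.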
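The L0-norm counts non-zero entries, which is useful as a sparsity penalty. However, a hard count has zero gradient almost everywhere. The sketch below shows one common smooth surrogate, x^4 / (x^4 + sigma). It is exactly 0 at x = 0 and approaches 1 once |x| is much larger than sigma^(1/4), so its sum approximates the count. This is an assumed surrogate written in jax.numpy. The default sigma = 1e-4 is also an assumption, and Rockpool's l0_norm_approx may use a different formula, default, or normalisation.

```python
import jax
import jax.numpy as jnp


def l0_norm_approx(params, sigma=1e-4):
    """Sketch only: smooth stand-in for the number of non-zero entries in a pytree."""
    total = 0.0
    for p in jax.tree_util.tree_leaves(params):
        # x**4 / (x**4 + sigma) is 0 at x = 0 and close to 1 for |x| >> sigma**0.25,
        # so summing it over every element approximates the count of non-zeros.
        p4 = p**4
        total = total + jnp.sum(p4 / (p4 + sigma))
    return total


params = {"w": jnp.array([0.0, 0.0, 1.0, -2.0]), "b": jnp.array([0.0, 3.0])}
approx = l0_norm_approx(params)  # close to 3, the true L0 norm
grads = jax.grad(l0_norm_approx)(params)  # finite everywhere, unlike a hard count
```

A smaller sigma makes the surrogate a closer match to the true count, but it also makes the transition around zero steeper, which can make optimisation less stable.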
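The entry above that computes the mean L2-squared-norm of the set of parameters is typically used as a weight-decay style regulariser. The sketch below shows one plausible reading of "mean" over a parameter pytree: square each value, average within each parameter array, then average across arrays. The name mean_l2sqr_norm is hypothetical. The two-level averaging is an assumption of this sketch and may not match the library's exact normalisation.

```python
import jax
import jax.numpy as jnp


def mean_l2sqr_norm(params):
    """Hypothetical helper: mean squared value per leaf, averaged over all leaves."""
    leaves = jax.tree_util.tree_leaves(params)
    # Averaging within each leaf first stops large arrays from dominating the total
    per_leaf = jnp.stack([jnp.mean(p**2) for p in leaves])
    return jnp.mean(per_leaf)


params = {"w": jnp.array([1.0, 3.0]), "b": jnp.array([2.0])}
reg = mean_l2sqr_norm(params)  # mean of 5.0 (for w) and 4.0 (for b) = 4.5
grads = jax.grad(mean_l2sqr_norm)(params)
```

In a training loop this would usually be scaled by a small coefficient and added to the task loss, so that large parameter values are discouraged.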
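softmax and logsoftmax both take an optional temperature. Dividing the inputs by a larger temperature flattens the resulting distribution. An efficient, numerically stable implementation subtracts the maximum input before exponentiating, and computes the log softmax as z - log(sum(exp(z))) instead of taking the log of a softmax value that may have underflowed to zero. The sketch below illustrates that technique in jax.numpy and is not Rockpool's source. The default temperature of 1.0 and normalising over the last axis are assumptions of this sketch.

```python
import jax
import jax.numpy as jnp


def softmax(x, temperature=1.0):
    """Sketch only: softmax over the last axis, with inputs divided by `temperature`."""
    z = x / temperature
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing
    z = z - jnp.max(z, axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=-1, keepdims=True)


def logsoftmax(x, temperature=1.0):
    """Sketch only: log(softmax(x)) computed as z - log(sum(exp(z)))."""
    z = x / temperature
    z = z - jnp.max(z, axis=-1, keepdims=True)
    return z - jnp.log(jnp.sum(jnp.exp(z), axis=-1, keepdims=True))


logits = jnp.array([1000.0, 1001.0, 1002.0])  # naive exp() would overflow here
p = softmax(logits)
logp = logsoftmax(logits)
p_hot = softmax(logits, temperature=10.0)  # higher temperature: flatter distribution
```

The results can be cross-checked against jax.nn.softmax and jax.nn.log_softmax, which also normalise over the last axis by default.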
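mse and sse differ only in how they reduce the squared error. mse averages over all elements, so its scale does not depend on how many time steps or channels the output has. sse sums, so it grows with the size of the output. The sketch below writes both directly in jax.numpy. It assumes the reduction runs over every element of output and target, and it is an illustration rather than Rockpool's own code.

```python
import jax
import jax.numpy as jnp


def mse(output, target):
    """Sketch only: mean of the squared differences over every element."""
    return jnp.mean((output - target) ** 2)


def sse(output, target):
    """Sketch only: sum of the squared differences over every element."""
    return jnp.sum((output - target) ** 2)


output = jnp.array([[1.0, 2.0], [3.0, 4.0]])
target = jnp.array([[1.0, 0.0], [0.0, 4.0]])
mse_val = mse(output, target)  # (0 + 4 + 9 + 0) / 4 = 3.25
sse_val = sse(output, target)  # 0 + 4 + 9 + 0 = 13.0
grad_out = jax.grad(mse)(output, target)  # 2 * (output - target) / output.size
```

Because the two differ by a constant factor (the number of elements), they lead to the same optimum. The choice mainly affects the effective learning rate.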