rockpool.training.jax_loss
Jax functions useful for training networks using Jax Modules.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.
Functions
Impose a cost on parameters that violate bounds constraints
Compute a smooth, differentiable approximation to the L0 norm
Compute the mean squared L2 norm of a set of parameters
Efficient implementation of the log softmax function
Convenience function to build a bounds template for a problem
Compute the mean-squared error between output and target
Implement the softmax function
Compute the sum-squared error between output and target
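The mean-squared and sum-squared error differ only in whether the squared differences are averaged or summed. The sketch below assumes plain jax.numpy arrays of matching shape and uses jax.grad to differentiate both. The names mean_squared_error and sum_squared_error are hypothetical stand-ins, not this module's API.

```python
import jax
import jax.numpy as jnp

def mean_squared_error(output, target):
    # Average of squared element-wise differences over all elements
    return jnp.mean((jnp.asarray(output) - jnp.asarray(target)) ** 2)

def sum_squared_error(output, target):
    # Sum of squared element-wise differences over all elements
    return jnp.sum((jnp.asarray(output) - jnp.asarray(target)) ** 2)

output = jnp.array([1.0, 2.0, 3.0])
target = jnp.array([1.0, 0.0, 3.0])

mse_value = mean_squared_error(output, target)  # (0 + 4 + 0) / 3
sse_value = sum_squared_error(output, target)   # 0 + 4 + 0

# Gradients with respect to the output (first argument)
mse_grad = jax.grad(mean_squared_error)(output, target)  # 2 * (output - target) / 3
sse_grad = jax.grad(sum_squared_error)(output, target)   # 2 * (output - target)
```

Because the mean divides by the number of elements, its gradient is smaller than the sum-squared error's gradient by that same factor. This can matter when choosing a learning rate.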
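A naive log(softmax(x)) can overflow for large inputs and return log(0) for very unlikely classes. One common way to make softmax and log softmax numerically stable subtracts the per-row maximum before exponentiating, because softmax is unchanged when the same constant is added to every input. The sketch below assumes that approach; it is not necessarily how jax_loss implements it, and the names stable_softmax and stable_log_softmax are hypothetical. JAX's own jax.nn.softmax and jax.nn.log_softmax serve as a cross-check.

```python
import jax
import jax.numpy as jnp

def stable_softmax(x):
    # Shift by the row maximum: softmax(x) == softmax(x - c) for any constant c,
    # and the shift keeps every exponent <= 0 so exp() cannot overflow.
    z = x - jnp.max(x, axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=-1, keepdims=True)

def stable_log_softmax(x):
    # log softmax(x) = z - log(sum(exp(z))) with z = x - max(x);
    # never forms softmax(x) first, so tiny probabilities do not round to log(0).
    z = x - jnp.max(x, axis=-1, keepdims=True)
    return z - jnp.log(jnp.sum(jnp.exp(z), axis=-1, keepdims=True))

logits = jnp.array([[1000.0, 1001.0, 1002.0],
                    [0.0, 0.0, 0.0]])
probs = stable_softmax(logits)          # each row sums to 1
log_probs = stable_log_softmax(logits)  # finite, even for very large logits

# Cross-check against JAX's built-in implementations
print(jnp.allclose(log_probs, jax.nn.log_softmax(logits), atol=1e-5))
```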
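The two parameter penalties act on a whole set of parameters, typically a nested dictionary (a JAX pytree) of arrays. The sketch below assumes one plausible reading of each description. The "mean squared L2 norm" is taken as the mean squared value of each array, averaged across arrays. The smooth L0 approximation uses the surrogate x**2 / (x**2 + sigma), summed over all elements. Both formulas, the default sigma, and the names mean_l2sqr and l0_smooth are assumptions and are not taken from jax_loss.

```python
import jax
import jax.numpy as jnp

def mean_l2sqr(params):
    # Assumed definition: mean squared value of each array, then averaged across
    # arrays, so small and large arrays contribute equally to the penalty.
    leaves = jax.tree_util.tree_leaves(params)
    return jnp.mean(jnp.stack([jnp.mean(jnp.square(p)) for p in leaves]))

def l0_smooth(params, sigma=1e-4):
    # Assumed surrogate: x**2 / (x**2 + sigma) is 0 at x == 0 and close to 1 once
    # |x| >> sqrt(sigma), so the sum approximates the count of non-zero elements.
    leaves = jax.tree_util.tree_leaves(params)
    return sum(jnp.sum(p**2 / (p**2 + sigma)) for p in leaves)

params = {
    "w": jnp.array([[0.0, 1.0], [2.0, 0.0]]),
    "b": jnp.array([3.0]),
}

l2_value = mean_l2sqr(params)  # (mean([9]) + mean([0, 1, 4, 0])) / 2
l0_value = l0_smooth(params)   # roughly 3: three non-zero entries

# Both penalties are differentiable, so they can be added to a task loss
l2_grads = jax.grad(mean_l2sqr)(params)  # dict with the same keys as params
```

Under this definition, the gradient for each array is inversely proportional to that array's size. In the example, the four-element w receives w / 4, while the single-element b receives b.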
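A bounds template can be built as a pair of parameter trees with the same structure as the parameters. Initialising them to -inf and +inf leaves every parameter unbounded, and the user then tightens only the entries that need constraining. The sketch below assumes that design together with a squared-hinge penalty, which is zero inside the bounds and grows quadratically outside them. The penalty shape, the helper names make_bounds_template and bounds_penalty, and the parameter name tau are all hypothetical.

```python
import jax
import jax.numpy as jnp

def make_bounds_template(params):
    # Hypothetical helper: same tree structure as params; -inf / +inf means "unbounded"
    lower = jax.tree_util.tree_map(lambda p: jnp.full_like(p, -jnp.inf), params)
    upper = jax.tree_util.tree_map(lambda p: jnp.full_like(p, jnp.inf), params)
    return lower, upper

def bounds_penalty(params, lower, upper):
    # Assumed squared-hinge penalty: zero inside [lower, upper], quadratic outside
    def leaf_cost(p, lo, hi):
        below = jnp.maximum(lo - p, 0.0)  # distance below the lower bound
        above = jnp.maximum(p - hi, 0.0)  # distance above the upper bound
        return jnp.sum(below**2 + above**2)

    costs = jax.tree_util.tree_map(leaf_cost, params, lower, upper)
    return sum(jax.tree_util.tree_leaves(costs))

params = {"tau": jnp.array([0.5, -0.2, 2.0])}

lower, upper = make_bounds_template(params)
lower["tau"] = jnp.zeros(3)  # constrain tau to be non-negative...
upper["tau"] = jnp.ones(3)   # ...and at most 1.0

cost = bounds_penalty(params, lower, upper)  # 0.2**2 + 1.0**2
```

Because the penalty is differentiable, it can be added to a task loss so that gradient descent pushes out-of-range parameters back inside their bounds.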