rockpool.training.jax_loss
Jax functions useful for training networks using Jax Modules.
See also
See 🏃🏽♀️ Training a Rockpool network with Jax for an introduction to training networks using Jax-backed modules in Rockpool, including the functions in jax_loss.
Functions

Impose a cost on parameters that violate bounds constraints 

Compute a smooth differentiable approximation to the L0-norm 

Compute the mean L2-squared-norm of the set of parameters 

Efficient implementation of the log softmax function 

Convenience function to build a bounds template for a problem 

Compute the mean-squared error between output and target 

Implements the softmax function 

Compute the sum-squared error between output and target 
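The two bounds-related entries (building a bounds template, and imposing a cost on parameters that violate bounds) work together: the template mirrors the structure of the parameter tree, and the cost is zero inside the bounds and grows outside them. The sketch below illustrates that idea in plain JAX. The names `make_bounds_template` and `bounds_penalty` are hypothetical, and the squared-hinge penalty is an assumption chosen for illustration, not necessarily the penalty Rockpool uses.

```python
import jax
import jax.numpy as jnp

def make_bounds_template(params):
    """Hypothetical helper: unbounded lower/upper trees shaped like `params`."""
    lower = jax.tree_util.tree_map(lambda p: jnp.full_like(p, -jnp.inf), params)
    upper = jax.tree_util.tree_map(lambda p: jnp.full_like(p, jnp.inf), params)
    return lower, upper

def bounds_penalty(params, lower, upper):
    """Hypothetical helper: squared-hinge cost, zero when lower <= p <= upper."""
    def leaf_cost(p, lo, hi):
        below = jnp.maximum(lo - p, 0.0)  # amount by which p falls under lo
        above = jnp.maximum(p - hi, 0.0)  # amount by which p exceeds hi
        return jnp.sum(below ** 2 + above ** 2)

    # Apply per leaf across the three matching trees, then add up the leaf costs
    costs = jax.tree_util.tree_map(leaf_cost, params, lower, upper)
    return sum(jax.tree_util.tree_leaves(costs))

params = {"tau": jnp.array([0.01, -0.005, 0.02])}
lower, upper = make_bounds_template(params)
lower["tau"] = jnp.zeros_like(params["tau"])  # e.g. time constants must be non-negative
cost = bounds_penalty(params, lower, upper)
print(cost)
```

Because unbounded entries are filled with plus or minus infinity, the hinge terms for those entries evaluate to zero, so only the constraints you actually set contribute to the cost.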
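The two regularisation entries (the mean L2-squared-norm of the parameters, and a smooth approximation to the L0-norm) both reduce a whole parameter tree to a single scalar that can be added to a loss. The sketch below shows one way to do this with `jax.tree_util`. The names `mean_l2sqr` and `smooth_l0` are hypothetical. Averaging the per-leaf means is one reading of "mean L2-squared-norm", and the `x**2 / (x**2 + sigma)` form is one common smooth L0 surrogate; neither is claimed to match Rockpool's exact formula.

```python
import jax
import jax.numpy as jnp

def mean_l2sqr(params):
    """Hypothetical helper: mean over leaves of each leaf's mean squared value."""
    leaves = jax.tree_util.tree_leaves(params)
    return jnp.mean(jnp.array([jnp.mean(p ** 2) for p in leaves]))

def smooth_l0(params, sigma=1e-4):
    """Hypothetical helper: differentiable count of non-zero parameters.

    Each term x**2 / (x**2 + sigma) is close to 0 for |x| << sqrt(sigma)
    and close to 1 for |x| >> sqrt(sigma).
    """
    leaves = jax.tree_util.tree_leaves(params)
    return sum(jnp.sum(p ** 2 / (p ** 2 + sigma)) for p in leaves)

params = {"w": jnp.array([0.0, 1.0, -2.0]), "b": jnp.array([0.0])}
print(mean_l2sqr(params))  # mean of (5/3 for "w", 0 for "b")
print(smooth_l0(params))   # close to 2: two non-zero entries

# Unlike the true L0-norm, the surrogate has a usable gradient
grads = jax.grad(smooth_l0)(params)
print(grads)
```

The smoothness is the point of the approximation: the true L0-norm is piecewise constant, so its gradient is zero almost everywhere and gives an optimiser nothing to follow.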
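The softmax and log softmax entries matter for classification losses. A naive `exp(x) / sum(exp(x))` overflows for large inputs, and `log(softmax(x))` loses precision when probabilities underflow. The usual remedy, sketched below, is to subtract the maximum before exponentiating and to compute log softmax directly as `x - logsumexp(x)`. The function names here are hypothetical; JAX itself also provides `jax.nn.softmax` and `jax.nn.log_softmax`, which the check compares against.

```python
import jax
import jax.numpy as jnp

def softmax_sketch(x, axis=-1):
    """Hypothetical helper: numerically stable softmax."""
    # Shifting by the max leaves the result unchanged but keeps exp() finite
    z = x - jnp.max(x, axis=axis, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=axis, keepdims=True)

def log_softmax_sketch(x, axis=-1):
    """Hypothetical helper: log softmax computed as x - logsumexp(x)."""
    z = x - jnp.max(x, axis=axis, keepdims=True)
    return z - jnp.log(jnp.sum(jnp.exp(z), axis=axis, keepdims=True))

x = jnp.array([1000.0, 1001.0, 1002.0])  # naive exp(x) would overflow to inf
p = softmax_sketch(x)
logp = log_softmax_sketch(x)
print(p, logp)
```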
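The mean-squared and sum-squared error entries differ only in how the element-wise squared differences are reduced: averaging makes the loss independent of the number of elements, whereas summing scales with it. The sketch below shows both with `jax.numpy`, plus the gradient that `jax.grad` produces. The names `mean_squared_error` and `sum_squared_error` are hypothetical stand-ins, not Rockpool's functions.

```python
import jax
import jax.numpy as jnp

def mean_squared_error(output, target):
    """Hypothetical helper: mean of the element-wise squared error."""
    return jnp.mean((output - target) ** 2)

def sum_squared_error(output, target):
    """Hypothetical helper: sum of the element-wise squared error."""
    return jnp.sum((output - target) ** 2)

out = jnp.array([[1.0, 2.0], [3.0, 4.0]])
tgt = jnp.array([[1.0, 0.0], [3.0, 2.0]])

print(mean_squared_error(out, tgt))  # 2.0  (errors 0, 4, 0, 4 averaged)
print(sum_squared_error(out, tgt))   # 8.0  (errors 0, 4, 0, 4 summed)

# Gradient w.r.t. the output: 2 * (out - tgt) / N for the mean version
g = jax.grad(mean_squared_error)(out, tgt)
print(g)
```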