rockpool.training.jax_debug

Utilities for debugging JAX training loops

Functions

debug_evolution(jmod, state, parameters, input)

Debug a network evolution and report the presence of NaNs in the network state and output
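As an illustration of the kind of check this performs, here is a minimal sketch. It is not Rockpool's implementation. The helper `nan_report` is hypothetical, and it assumes the state is a flat dictionary of arrays and the output is a single array. JAX arrays can be passed in because `np.asarray` copies them to the host.

```python
import numpy as np

def nan_report(state: dict, output) -> dict:
    """Hypothetical helper: list state entries containing NaNs, and flag NaNs in the output.

    Assumes `state` is a flat dict mapping names to arrays (NumPy or JAX).
    """
    # Collect the names of state variables with at least one NaN element
    nan_states = [
        name for name, value in state.items()
        if np.isnan(np.asarray(value, dtype=float)).any()
    ]
    # Check the evolution output as a whole
    output_has_nan = bool(np.isnan(np.asarray(output, dtype=float)).any())
    return {"state": nan_states, "output": output_has_nan}
```

For example, `nan_report({"vmem": np.array([0.1, np.nan]), "isyn": np.zeros(2)}, np.ones(3))` flags `"vmem"` and reports a clean output. Nested module states would need to be flattened first.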

debug_optimisation(jmod, parameters, input, ...)

Debug an optimisation step, reporting the presence of NaNs in loss and gradients
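A minimal sketch of the same idea in plain JAX, assuming a loss function of the form `loss_fn(params, *args)` and a flat dictionary of parameters. The function `check_optimisation_step` is hypothetical and does not reproduce the signature of `debug_optimisation`. It computes the loss and gradients with `jax.value_and_grad`, then reports whether the loss is NaN and which parameter gradients contain NaNs.

```python
import jax
import jax.numpy as jnp

def _has_nan(tree) -> bool:
    # True if any leaf of the pytree contains a NaN
    return any(bool(jnp.isnan(leaf).any()) for leaf in jax.tree_util.tree_leaves(tree))

def check_optimisation_step(loss_fn, params: dict, *args):
    """Hypothetical helper: evaluate loss and gradients and report NaNs.

    Assumes `params` is a flat dict of float arrays.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, *args)
    loss_is_nan = bool(jnp.isnan(loss))
    # Names of parameters whose gradients contain NaNs
    nan_grads = [name for name, g in grads.items() if _has_nan(g)]
    return loss_is_nan, nan_grads
```

With `params = {"w": jnp.array([1.0, -1.0])}` and a loss of `jnp.sum(jnp.sqrt(p["w"] * x))`, the square root of a negative value makes both the loss and the gradient of `"w"` NaN, which this check reports.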

flatten(generic_collection[, sep])

Flattens a generic collection of collections into an ordered dictionary.
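To show what flattening into an ordered dictionary means, here is a minimal recursive sketch. `flatten_sketch` is hypothetical and is not Rockpool's `flatten`. In particular, the default separator `"/"` and the use of list indices as keys are assumptions. Nested dicts, lists, and tuples are walked in order, and each leaf is stored under the path of keys leading to it, joined by `sep`.

```python
from collections import OrderedDict
from typing import Any

def flatten_sketch(collection: Any, sep: str = "/", prefix: str = "") -> "OrderedDict[str, Any]":
    """Hypothetical helper: flatten nested dicts/lists/tuples into an OrderedDict.

    Keys are paths joined by `sep` (assumed default "/"); list items use their index.
    """
    flat: "OrderedDict[str, Any]" = OrderedDict()
    if isinstance(collection, dict):
        items = collection.items()
    elif isinstance(collection, (list, tuple)):
        items = enumerate(collection)
    else:
        # A leaf: store it under the path accumulated so far
        flat[prefix] = collection
        return flat
    for key, value in items:
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        flat.update(flatten_sketch(value, sep, path))
    return flat
```

For example, `flatten_sketch({"a": {"b": 1, "c": [2, 3]}, "d": 4})` yields the keys `"a/b"`, `"a/c/0"`, `"a/c/1"` and `"d"`, in that order.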