rockpool.utilities.jax_tree_utils

Utility functions for working with JAX pytrees.

Functions

tree_map_reduce_select(tree, protoype_tree, ...)

Perform a map-reduce operation over a tree, applying the map only to the nodes selected in protoype_tree.
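A minimal sketch of the selective map-reduce idea in plain JAX. It is not Rockpool's implementation, and it omits the further arguments hidden behind "..." in the signature above. The function name map_reduce_select_sketch is hypothetical. The sketch assumes that the prototype tree has the same structure as the data tree, with one boolean leaf per node marking whether that node is included, and that at least one node is selected.

```python
import functools

import jax
import jax.numpy as jnp


def map_reduce_select_sketch(tree, prototype_tree, map_fun, reduce_fun):
    """Hypothetical sketch: map over selected leaves, then fold the results.

    Assumes prototype_tree mirrors the structure of tree, with a boolean
    leaf marking each leaf of tree that should be included.
    """
    leaves, _ = jax.tree_util.tree_flatten(tree)
    selected, _ = jax.tree_util.tree_flatten(prototype_tree)
    # Map only the selected leaves; unselected leaves are skipped entirely
    mapped = [map_fun(leaf) for leaf, sel in zip(leaves, selected) if sel]
    # Fold the mapped values pairwise; assumes at least one leaf is selected
    return functools.reduce(reduce_fun, mapped)


# Example: sum of squared weights, ignoring the bias term
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
select = {"w": True, "b": False}
total = map_reduce_select_sketch(
    params, select, lambda x: jnp.sum(x**2), lambda a, b: a + b
)
```

A typical use of this pattern is a regularization term computed only over weight matrices, with biases and time constants excluded by the prototype tree.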

tree_map_select(tree, prototype_tree, map_fun)

Map a scalar function over a tree, applying it only to the nodes selected in prototype_tree.
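The following sketch shows one way to perform a selective map with jax.tree_util. The name map_select_sketch is hypothetical, and the example is not Rockpool's code. It assumes the prototype tree holds one boolean per leaf, and that unselected leaves are passed through unchanged.

```python
import jax
import jax.numpy as jnp


def map_select_sketch(tree, prototype_tree, map_fun):
    """Hypothetical sketch: apply map_fun only where prototype_tree is True."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    selected, _ = jax.tree_util.tree_flatten(prototype_tree)
    # Selected leaves are mapped; the others are returned as-is (assumption)
    mapped = [map_fun(leaf) if sel else leaf for leaf, sel in zip(leaves, selected)]
    # Rebuild a tree with the original structure
    return jax.tree_util.tree_unflatten(treedef, mapped)


# Example: clip the weights to [-1, 1] but leave the bias untouched
params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array([3.0])}
select = {"w": True, "b": False}
clipped = map_select_sketch(params, select, lambda x: jnp.clip(x, -1.0, 1.0))
```

Flattening both trees with the same treedef ordering keeps each leaf aligned with its selection flag.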

tree_map_select_with_rng(tree, ...)

Map a scalar function over a tree, applying it only to selected nodes and passing each mapped node its own RNG key.
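A sketch that combines the selective map with per-leaf RNG key splitting. The name map_select_with_rng_sketch and its argument order are hypothetical, because the real signature above is elided. The sketch assumes map_fun takes (leaf, key), that the prototype tree holds one boolean per leaf, and that unselected leaves are returned unchanged.

```python
import jax
import jax.numpy as jnp


def map_select_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Hypothetical sketch: map_fun(leaf, key) on selected leaves only."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    selected, _ = jax.tree_util.tree_flatten(prototype_tree)
    # One independent subkey per leaf, so no two leaves share randomness
    keys = jax.random.split(rng_key, len(leaves))
    mapped = [
        map_fun(leaf, key) if sel else leaf
        for leaf, sel, key in zip(leaves, selected, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, mapped)


# Example: add Gaussian noise to the weights but not to the bias
params = {"w": jnp.zeros(3), "b": jnp.ones(2)}
select = {"w": True, "b": False}
noisy = map_select_with_rng_sketch(
    params,
    select,
    lambda x, k: x + 0.1 * jax.random.normal(k, x.shape),
    jax.random.PRNGKey(0),
)
```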

tree_map_with_rng(tree, map_fun, rng_key, *rest)

Map a function over a tree (and over any additional trees in rest), splitting rng_key and passing a distinct key to each leaf.
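A minimal sketch of the per-leaf RNG pattern, written against jax.tree_util and jax.random. The name map_with_rng_sketch is hypothetical, and the sketch is not Rockpool's implementation. It assumes map_fun is called as map_fun(leaf, key, *rest_leaves), where each tree in rest has the same structure as tree.

```python
import jax
import jax.numpy as jnp


def map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Hypothetical sketch: map_fun(leaf, key, *rest_leaves) over every leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Split the key once per leaf and arrange the subkeys into a matching tree
    keys = list(jax.random.split(rng_key, len(leaves)))
    keys_tree = jax.tree_util.tree_unflatten(treedef, keys)
    # tree_map walks tree, keys_tree and any extra trees in lockstep
    return jax.tree_util.tree_map(map_fun, tree, keys_tree, *rest)


# Example: random initialization with a per-parameter scale from a second tree
params = {"w": jnp.zeros(2), "b": jnp.zeros(3)}
scales = {"w": 1.0, "b": 0.5}
init = map_with_rng_sketch(
    params,
    lambda x, k, s: s * jax.random.normal(k, x.shape),
    jax.random.PRNGKey(42),
    scales,
)
```

Building a keys tree with the same treedef lets a single tree_map call replace the older jax.tree_multimap style of mapping over several trees at once.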