Module utilities.jax_tree_utils
Utility functions for working with trees.
Functions overview

tree_map_reduce_select    Perform a map-reduce operation over a tree, but only matching selected nodes
tree_map_select           Map a scalar function over a tree, but only matching selected nodes
tree_map_select_with_rng  Map a scalar function over a tree, but only matching selected nodes.
tree_map_with_rng         Perform a multimap over a tree, splitting and inserting an RNG key for each leaf
Functions
 utilities.jax_tree_utils.tree_map_reduce_select(tree: Iterable | MutableMapping | Mapping, protoype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any], reduce_fun: Callable[[Any, Any], Any], null: Any = nan) -> Any
Perform a map-reduce operation over a tree, but only matching selected nodes
map_fun() is a function that will be mapped over the leaves of a tree. It will only be applied to leaf nodes in tree when the corresponding leaf node in prototype_tree is True. map_fun() can return whatever you like; however, you must ensure that the null value has the same shape as the value returned by map_fun() applied to a leaf node.

reduce_fun() is a reduction function. It is called using functools.reduce over the mapped outputs of map_fun(). reduce_fun(intermediate_value, map_value) accepts two mapped leaf values and should combine them into a single result. The first value is retained as the intermediate computation and passed to the next call to reduce_fun().

Examples
>>> import numpy as np
>>>
>>> map_fun = np.nanmax
>>>
>>> def reduce_fun(x, y):
...     return np.nanmax(np.array([x, y]))
>>>
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun)
1

>>> def map_fun(leaf):
...     return {'mean': np.mean(leaf),
...             'std': np.std(leaf),
...             }
>>>
>>> def reduce_fun(x, y):
...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
...             'std': x['std'] + y['std'],
...             }
>>>
>>> null = map_fun(np.nan)
>>>
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun, null)
{'mean': 1, 'std': NaN}
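The way functools.reduce passes the intermediate value along can be seen in isolation with a small sketch. The list mapped_values and the recording list calls are hypothetical stand-ins for the outputs of map_fun(); no tree is involved here.

```python
from functools import reduce

# Hypothetical stand-in for the mapped leaf values produced by map_fun()
mapped_values = [1, 2, 3]

# Record every pair of arguments that the reduction function receives
calls = []


def reduce_fun(intermediate_value, map_value):
    calls.append((intermediate_value, map_value))
    return intermediate_value + map_value


total = reduce(reduce_fun, mapped_values)
```

Each call receives the running result as its first argument and the next mapped value as its second, so reduce_fun() should accept its own return value as a valid first argument.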
 Parameters:
tree (Tree) – A tree over which to operate
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf], Value]) – A function to perform on selected nodes in tree. map_fun has the signature map_fun(leaf_node) -> value.
reduce_fun (Callable[[Value, Value], Value]) – A function that collects two values and returns the combination of the two values, to reduce over the mapped function. reduce_fun has the signature reduce_fun(value, value) -> value.
null (Any) – The "null" value to return from the map operation, if the leaf node is not selected in prototype_tree. Default: jax.numpy.nan
 Returns:
The result of the map-reduce operation over the tree
 Return type:
Value
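One way such a selective map-reduce could be written with jax.tree_util is sketched below. This is an illustration under stated assumptions, not the library's implementation: the name select_map_reduce_sketch is hypothetical, unselected leaves are simply replaced by null before reducing, and reduce is called without an initial value.

```python
from functools import reduce

import jax
import jax.numpy as jnp


def select_map_reduce_sketch(tree, prototype_tree, map_fun, reduce_fun, null=jnp.nan):
    """Hypothetical sketch: map selected leaves, substitute `null` elsewhere, then reduce."""
    # Apply map_fun only where the prototype leaf is True; otherwise use `null`
    mapped = jax.tree_util.tree_map(
        lambda leaf, selected: map_fun(leaf) if selected else null,
        tree,
        prototype_tree,
    )
    # Fold the mapped leaf values into a single result
    return reduce(reduce_fun, jax.tree_util.tree_leaves(mapped))


def nan_aware_max(x, y):
    # Combine two mapped values while ignoring NaN placeholders
    return jnp.nanmax(jnp.stack([jnp.asarray(x), jnp.asarray(y)]))


tree = {"a": jnp.array([1.0, 5.0]), "b": jnp.array([9.0, 2.0])}
prototype_tree = {"a": True, "b": False}

result = select_map_reduce_sketch(tree, prototype_tree, jnp.nanmax, nan_aware_max)
```

Because leaf "b" is not selected, its value 9.0 never reaches the reduction; the NaN placeholder is ignored by the NaN-aware maximum. This only works because null has the same shape as the scalar returned by map_fun().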
 utilities.jax_tree_utils.tree_map_select(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any]) -> Iterable | MutableMapping | Mapping
Map a scalar function over a tree, but only matching selected nodes
Notes
map_fun must be a scalar function. This means that if the input is of shape (N, M), the output must also be of shape (N, M). Otherwise you will get an error.
 Parameters:
tree (Tree) – A tree over which to operate
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf], Value]) – A scalar function to perform on selected nodes in tree
 Returns:
A tree with the same structure as tree, with each selected leaf node replaced by the output of map_fun() for that leaf.
 Return type:
Tree
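A minimal sketch of selective mapping, assuming jax.tree_util.tree_map is used to walk tree and prototype_tree together. The name select_map_sketch and the example trees are hypothetical, and leaving unselected leaves untouched is this sketch's reading of "will be modified", not a statement about the actual implementation.

```python
import jax
import jax.numpy as jnp


def select_map_sketch(tree, prototype_tree, map_fun):
    """Hypothetical sketch: apply map_fun where the prototype leaf is True."""
    return jax.tree_util.tree_map(
        lambda leaf, selected: map_fun(leaf) if selected else leaf,
        tree,
        prototype_tree,
    )


params = {"bias": jnp.ones(3), "weights": jnp.ones((2, 3))}
select = {"bias": False, "weights": True}

# A shape-preserving ("scalar") function: the output shape equals the input shape
doubled = select_map_sketch(params, select, lambda x: 2.0 * x)
```

Here only "weights" is doubled, and both leaves keep their original shapes.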
 utilities.jax_tree_utils.tree_map_select_with_rng(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any], Any], rng_key: Any) -> Iterable | MutableMapping | Mapping
Map a scalar function over a tree, but only matching selected nodes. Includes jax-compatible random state
Notes
map_fun() must be a scalar function. This means that if the input is of shape (N, M), the output must also be of shape (N, M). Otherwise you will get an error. The signature of map_fun() is map_fun(leaf, rng_key) -> value.
 Parameters:
tree (PyTree) – A tree over which to operate
prototype_tree (PyTree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf, JaxRNGKey], Value]) – A scalar function to perform on selected nodes in tree. The second argument is a jax pRNG key to use when generating random state.
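The sketch below shows one way a selective map with random state might look, for example to add noise to selected parameters only. It assumes one sub-key per leaf obtained from jax.random.split; the name select_map_with_rng_sketch, the example trees and the noise scale are hypothetical, and the real function may distribute keys differently.

```python
import jax
import jax.numpy as jnp


def select_map_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Hypothetical sketch: selective map where each leaf gets its own RNG key."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # One independent sub-key per leaf, arranged in the same structure as `tree`
    keys = jax.random.split(rng_key, len(leaves))
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    return jax.tree_util.tree_map(
        lambda leaf, selected, key: map_fun(leaf, key) if selected else leaf,
        tree,
        prototype_tree,
        key_tree,
    )


def add_noise(leaf, rng_key):
    # Shape-preserving: the noise has the same shape as the leaf
    return leaf + 0.1 * jax.random.normal(rng_key, leaf.shape)


params = {"bias": jnp.ones(3), "weights": jnp.ones((2, 3))}
select = {"bias": False, "weights": True}

noisy = select_map_with_rng_sketch(params, select, add_noise, jax.random.PRNGKey(0))
```

Only "weights" is perturbed; "bias" is returned unchanged.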
 utilities.jax_tree_utils.tree_map_with_rng(tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any, Any], Any], rng_key: Any, *rest: Any) -> Iterable | MutableMapping | Mapping
Perform a multimap over a tree, splitting and inserting an RNG key for each leaf
This utility maps a function over the leaves of a tree, when the function requires an RNG key to operate. The utility will automatically split the RNG key to generate a new key for each leaf. Then map_fun will be called for each leaf, with the signature map_fun(leaf_value, rng_key, *rest).

rest is an optional further series of arguments to map over the tree, such that each additional argument must have the same tree structure as tree. See the documentation for jax.tree_util.tree_map for more information.
 Parameters:
tree (Tree) – A tree to work over
map_fun (Callable[[Value, JaxRNGKey, Any], Value]) – A function to map over the tree. The function must have the signature map_fun(leaf_value, rng_key, *rest)
rng_key (JaxRNGKey) – An initial RNG key to split
*rest – A tuple of additional tree-shaped arguments that will be collectively mapped over tree when calling map_fun.
 Returns:
The tree-shaped result of mapping map_fun over tree.
 Return type:
Tree
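The key-splitting multimap described for tree_map_with_rng can be sketched as follows. The name map_with_rng_sketch, the example trees and the scaled_noise function are hypothetical; the sketch assumes one sub-key per leaf from jax.random.split and forwards *rest to jax.tree_util.tree_map unchanged.

```python
import jax
import jax.numpy as jnp


def map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Hypothetical sketch: multimap over a tree with a fresh RNG key per leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Split the initial key into one sub-key per leaf
    keys = jax.random.split(rng_key, len(leaves))
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    # map_fun is called as map_fun(leaf_value, rng_key, *rest_leaves)
    return jax.tree_util.tree_map(map_fun, tree, key_tree, *rest)


def scaled_noise(leaf_value, rng_key, scale):
    return leaf_value + scale * jax.random.normal(rng_key, leaf_value.shape)


tree = {"a": jnp.zeros(2), "b": jnp.zeros(3)}
scales = {"a": 1.0, "b": 10.0}  # extra argument with the same structure as `tree`

key = jax.random.PRNGKey(42)
out = map_with_rng_sketch(tree, scaled_noise, key, scales)
out_again = map_with_rng_sketch(tree, scaled_noise, key, scales)
```

Passing the same initial key reproduces the same result, while each leaf still draws from its own sub-key.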