Module utilities.jax_tree_utils
Utility functions for working with trees.
Functions overview
branches – Generate all branches (paths from the root node to any leaf) in a tree.
get_nested – Get a value from a tree, specifying a branch.
make_prototype_tree – Construct a tree with boolean leaves, marking the nodes that match a target tree.
set_matching – Set the values in a full tree in place, for branches that match a target tree.
set_nested – Set a value in a tree, specifying a branch.
tree_find – Return the branches of a tree whose leaf nodes evaluate to True.
tree_map_reduce_select – Perform a map-reduce operation over a tree, but only on selected nodes.
tree_map_select – Map a scalar function over a tree, but only on selected nodes.
tree_map_select_with_rng – Map a scalar function over a tree, but only on selected nodes, with JAX-compatible random state.
tree_map_with_rng – Perform a multimap over a tree, splitting and inserting an RNG key for each leaf.
Functions
utilities.jax_tree_utils.branches(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], prefix: Optional[list] = None) -> Generator[Tuple, None, None]
Generate all branches (paths from the root node to any leaf) in a tree.
Parameters
tree (Tree) – A nested dictionary tree structure.
prefix (list) – The current branch prefix. Default: None (this is the root).
Yields
Tuple[str] – A branch, given as a tuple of keys from the root to a leaf.
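Here is a minimal sketch of a branch generator that makes the idea of a branch concrete. It assumes the tree is a nested dict and that every non-dict value is a leaf. The name branches_sketch is hypothetical, and this is not the module's own implementation. In particular, it does not handle the non-mapping iterables that the signature above also accepts.

```python
from collections.abc import Mapping
from typing import Generator, Optional, Tuple


def branches_sketch(tree: Mapping, prefix: Optional[list] = None) -> Generator[Tuple, None, None]:
    """Yield every root-to-leaf path of a nested dict as a tuple of keys."""
    prefix = [] if prefix is None else prefix  # None means "start at the root"
    for key, node in tree.items():
        path = prefix + [key]
        if isinstance(node, Mapping):
            # Interior node: recurse, extending the current prefix
            yield from branches_sketch(node, path)
        else:
            # Leaf node: emit the complete branch
            yield tuple(path)


tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}}
print(list(branches_sketch(tree)))  # [('a',), ('b', 'b1'), ('b', 'b2')]
```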
utilities.jax_tree_utils.get_nested(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], branch: Tuple) -> None
Get a value from a tree, specifying a branch.
Parameters
tree (Tree) – A nested dictionary tree structure.
branch (Tuple[str]) – A branch: a tuple of indices to walk through the tree.
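Walking a branch comes down to indexing the tree once per key. The minimal sketch below assumes that every level of the tree supports [] indexing, as nested dicts and lists do. The name get_nested_sketch is hypothetical, and this is not the module's implementation.

```python
import functools
import operator
from typing import Any, Tuple


def get_nested_sketch(tree: Any, branch: Tuple) -> Any:
    """Follow `branch` from the root of `tree` and return the node it reaches."""
    # Equivalent to tree[branch[0]][branch[1]]...[branch[-1]]
    return functools.reduce(operator.getitem, branch, tree)


tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}}
print(get_nested_sketch(tree, ('b', 'b2')))  # 3
```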
utilities.jax_tree_utils.make_prototype_tree(full_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], target_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]) -> Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]
Construct a tree with boolean leaves, marking the nodes that match a target tree.
Make a prototype tree, indicating which nodes in a large tree should be selected for analysis or processing. This is done on the basis of a smaller “target” tree, which contains only the leaf nodes of interest.
Examples
>>> target_tree = {'a': 0, 'b': {'b2': 0}}
>>> full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4, 'd': 5}
>>> make_prototype_tree(full_tree, target_tree)
{'a': True, 'b': {'b1': False, 'b2': True}, 'c': False, 'd': False}
Parameters
full_tree (Tree) – A large tree to search through.
target_tree (Tree) – A tree with only a few leaf nodes. These nodes will be identified within the full tree.
Returns
A nested tree with the same tree structure as full_tree, but with bool leaf nodes. Leaf nodes will be True for branches matching those specified in target_tree, and False otherwise.
Return type
Tree
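The minimal sketch below shows how such a prototype tree could be built for nested dicts, and it reproduces the example above. The name make_prototype_tree_sketch is hypothetical. One case is an assumption: when a target leaf sits where full_tree has a whole subtree, this sketch treats it as no match.

```python
from collections.abc import Mapping


def make_prototype_tree_sketch(full_tree: Mapping, target_tree: Mapping) -> dict:
    """Mirror full_tree's structure, with True at leaves whose branch appears in target_tree."""
    proto = {}
    for key, node in full_tree.items():
        if isinstance(node, Mapping):
            # Assumption: a subtree can only match where target_tree also has a subtree
            sub_target = target_tree.get(key)
            if not isinstance(sub_target, Mapping):
                sub_target = {}
            proto[key] = make_prototype_tree_sketch(node, sub_target)
        else:
            # A leaf is selected if the same key is present in target_tree
            proto[key] = key in target_tree
    return proto


target_tree = {'a': 0, 'b': {'b2': 0}}
full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4, 'd': 5}
print(make_prototype_tree_sketch(full_tree, target_tree))
# {'a': True, 'b': {'b1': False, 'b2': True}, 'c': False, 'd': False}
```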
utilities.jax_tree_utils.set_matching(full_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], target_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], value: Any) -> None
Set the values in a full tree in place, for branches that match a target tree.
Parameters
full_tree (Tree) – A tree to search over. Values at matching branches in this tree will be replaced with value.
target_tree (Tree) – A tree that defines the target branches to set in full_tree. Matching branches in full_tree will have their values replaced with value.
value (Any) – The value to set in full_tree.
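Below is a minimal in-place sketch for nested dicts. It assumes that every subtree named in target_tree also exists in full_tree. Only the structure of target_tree matters here: its leaf values are ignored. The name set_matching_sketch is hypothetical, and this is not the module's implementation.

```python
from collections.abc import Mapping, MutableMapping
from typing import Any


def set_matching_sketch(full_tree: MutableMapping, target_tree: Mapping, value: Any) -> None:
    """In place: overwrite every leaf of full_tree whose branch appears in target_tree."""
    for key, target_node in target_tree.items():
        if isinstance(target_node, Mapping):
            # Descend into the matching subtree (KeyError if full_tree lacks it)
            set_matching_sketch(full_tree[key], target_node, value)
        else:
            # Target leaf: replace the corresponding value in full_tree
            full_tree[key] = value


full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4}
set_matching_sketch(full_tree, {'a': 0, 'b': {'b2': 0}}, None)
print(full_tree)  # {'a': None, 'b': {'b1': 2, 'b2': None}, 'c': 4}
```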
utilities.jax_tree_utils.set_nested(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], branch: Tuple, value: Any) -> None
Set a value in a tree, specifying a branch.
The leaf node must already exist in the tree.
Parameters
tree (Tree) – A nested dictionary tree structure.
branch (Tuple[str]) – A branch: a tuple of indices to walk through the tree.
value (Any) – The value to set at the tree leaf.
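One way to enforce the requirement that the leaf already exists is to read it before writing it. The minimal sketch below assumes every level of the tree supports [] indexing. The name set_nested_sketch is hypothetical, and this is not the module's implementation.

```python
import functools
import operator
from typing import Any, Tuple


def set_nested_sketch(tree: Any, branch: Tuple, value: Any) -> None:
    """In place: replace the existing leaf at `branch` with `value`."""
    *parents, last = branch
    # Walk down to the container that holds the leaf
    parent = functools.reduce(operator.getitem, parents, tree)
    # Reading first raises KeyError/IndexError if the leaf does not already exist
    _ = parent[last]
    parent[last] = value


tree = {'a': 1, 'b': {'b1': 2}}
set_nested_sketch(tree, ('b', 'b1'), 10)
print(tree)  # {'a': 1, 'b': {'b1': 10}}
```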
utilities.jax_tree_utils.tree_find(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]) -> List
Return the branches of a tree whose leaf nodes evaluate to True.
Parameters
tree (Tree) – A tree to examine.
Returns
A list of all tree branches for which the corresponding tree leaf evaluates to True.
Return type
list
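tree_find works as the inverse of a prototype tree: it turns True leaves back into branches. The minimal sketch below handles nested dicts and uses Python truthiness for "evaluates to True". The name tree_find_sketch and its extra prefix argument are hypothetical.

```python
from collections.abc import Mapping
from typing import List, Tuple


def tree_find_sketch(tree: Mapping, prefix: Tuple = ()) -> List[Tuple]:
    """Return the branches (tuples of keys) whose leaves are truthy."""
    found = []
    for key, node in tree.items():
        path = prefix + (key,)
        if isinstance(node, Mapping):
            # Collect matches from the subtree, keeping the full path
            found.extend(tree_find_sketch(node, path))
        elif node:
            # Leaf evaluates to True: record its branch
            found.append(path)
    return found


prototype = {'a': True, 'b': {'b1': False, 'b2': True}, 'c': False}
print(tree_find_sketch(prototype))  # [('a',), ('b', 'b2')]
```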
utilities.jax_tree_utils.tree_map_reduce_select(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], protoype_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any], Any], reduce_fun: Callable[[Any, Any], Any], null: Any = nan) -> Any
Perform a map-reduce operation over a tree, but only on selected nodes.
map_fun() is a function that will be mapped over the leaves of a tree. It is applied to a leaf node in tree only when the corresponding leaf node in prototype_tree is True. map_fun() can return whatever you like; however, you must ensure that the null value has the same shape as the value returned by map_fun() when applied to a leaf node.
reduce_fun() is a reduction function. It is called using functools.reduce over the outputs of map_fun(). reduce_fun(intermediate_value, map_value) accepts two mapped leaf values and should combine them into a single result. The first argument holds the intermediate (accumulated) result: the value returned by reduce_fun() is passed as the first argument of the next call to reduce_fun().
Examples
>>> import numpy as np
>>> from utilities.jax_tree_utils import tree_map_reduce_select
>>> map_fun = np.nanmax
>>>
>>> def reduce_fun(x, y):
...     return np.nanmax(np.array([x, y]))
...
>>> tree_map_reduce_select({'a': 1}, {'a': 1}, map_fun, reduce_fun)
1
>>> def map_fun(leaf):
...     return {'mean': np.mean(leaf),
...             'std': np.std(leaf),
...             }
...
>>> def reduce_fun(x, y):
...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
...             'std': x['std'] + y['std'],
...             }
...
>>> null = map_fun(np.nan)
>>>
>>> tree_map_reduce_select({'a': 1}, {'a': 1}, map_fun, reduce_fun, null)
{'mean': 1, 'std': NaN}
Parameters
tree (Tree) – A tree over which to operate.
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be passed to map_fun.
map_fun (Callable[[Leaf], Value]) – A function to perform on selected nodes in tree. map_fun has the signature map_fun(leaf_node) -> value.
reduce_fun (Callable[[Value, Value], Value]) – A function that takes two values and returns their combination; it is used to reduce over the mapped values. reduce_fun has the signature reduce_fun(value, value) -> value.
null (Any) – The “null” value to return from the map operation if the leaf node is not selected in prototype_tree. Default: jax.numpy.nan.
Returns
The result of the map-reduce operation over the tree.
Return type
Value
utilities.jax_tree_utils.tree_map_select(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], prototype_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any], Any]) -> Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]
Map a scalar function over a tree, but only on selected nodes.
Notes
map_fun must be a scalar function. This means that if the input has shape (N, M), the output must also have shape (N, M). Otherwise you will get an error.
Parameters
tree (Tree) – A tree over which to operate.
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf], Value]) – A scalar function to perform on selected nodes in tree.
Returns
A tree with the same structure as tree, with each selected leaf node replaced by the output of map_fun().
Return type
Tree
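Below is a minimal sketch of selective mapping over nested dicts. It assumes that unselected leaves are returned unchanged. The name tree_map_select_sketch is hypothetical, and this sketch does not enforce the shape-preserving requirement described in the Notes.

```python
from collections.abc import Mapping
from typing import Any, Callable


def tree_map_select_sketch(tree: Any, prototype_tree: Any, map_fun: Callable[[Any], Any]) -> Any:
    """Return a new tree where selected leaves are replaced by map_fun(leaf)."""
    if isinstance(tree, Mapping):
        # Rebuild the same structure, recursing into the matching prototype subtree
        return {key: tree_map_select_sketch(node, prototype_tree[key], map_fun)
                for key, node in tree.items()}
    # Leaf: transform only if the prototype marks it as selected
    return map_fun(tree) if prototype_tree else tree


tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}}
prototype = {'a': True, 'b': {'b1': False, 'b2': True}}
print(tree_map_select_sketch(tree, prototype, lambda x: 10 * x))
# {'a': 10, 'b': {'b1': 2, 'b2': 30}}
```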
utilities.jax_tree_utils.tree_map_select_with_rng(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], prototype_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any, Any], Any], rng_key: Any) -> Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]
Map a scalar function over a tree, but only on selected nodes. Includes JAX-compatible random state.
Notes
map_fun() must be a scalar function. This means that if the input has shape (N, M), the output must also have shape (N, M). Otherwise you will get an error. The signature of map_fun() is map_fun(leaf, rng_key) -> value.
Parameters
tree (PyTree) – A tree over which to operate.
prototype_tree (PyTree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf, JaxRNGKey], Value]) – A scalar function to perform on selected nodes in tree. The second argument is a JAX PRNG key to use when generating random state.
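One way to give each selected leaf its own random state is to flatten the tree, split the key once per leaf with jax.random.split, and then rebuild the tree. This sketch rests on two assumptions: tree and prototype_tree share the same structure, and unselected leaves pass through unchanged. The names tree_map_select_with_rng_sketch and add_noise are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_map_select_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Apply map_fun(leaf, key) to selected leaves, each with its own subkey."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Both trees share one structure, so their leaves line up in flattening order
    selected = jax.tree_util.tree_leaves(prototype_tree)
    # Split once, giving every leaf an independent key
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [map_fun(leaf, key) if sel else leaf
                  for leaf, sel, key in zip(leaves, selected, keys)]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


def add_noise(leaf, key):
    # Shape-preserving map_fun: add standard normal noise to every element
    return leaf + jax.random.normal(key, leaf.shape)


tree = {'w': jnp.zeros(3), 'b': jnp.zeros(3)}
prototype = {'w': True, 'b': False}
noisy = tree_map_select_with_rng_sketch(tree, prototype, add_noise, jax.random.PRNGKey(0))
print(noisy['b'])  # unselected leaf stays all zeros
```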
utilities.jax_tree_utils.tree_map_with_rng(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any, Any, Any], Any], rng_key: Any, *rest: Any) -> Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]
Perform a multimap over a tree, splitting and inserting an RNG key for each leaf.
This utility maps a function over the leaves of a tree when the function requires an RNG key to operate. The utility automatically splits the RNG key to generate a new key for each leaf. map_fun is then called for each leaf with the signature map_fun(leaf_value, rng_key, *rest). rest is an optional further series of arguments to map over the tree; each additional argument must have the same tree structure as tree. See the documentation for jax.tree_util.tree_map for more information.
Parameters
tree (Tree) – A tree to work over.
map_fun (Callable[[Value, JaxRNGKey, Any], Value]) – A function to map over the tree. The function must have the signature map_fun(leaf_value, rng_key, *rest).
rng_key (JaxRNGKey) – An initial RNG key to split.
*rest – A tuple of additional tree-shaped arguments that will be collectively mapped over tree when calling map_fun.
Returns
The tree-shaped result of mapping map_fun over tree.
Return type
Tree
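Below is a minimal sketch of the key-splitting scheme, built on jax.tree_util.tree_map. The subkeys are arranged into a tree with the same structure as tree, so that tree_map can zip them together with tree and with each tree in rest. The names tree_map_with_rng_sketch and scaled_noise are hypothetical, and this is not the module's implementation.

```python
import jax
import jax.numpy as jnp


def tree_map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Like jax.tree_util.tree_map, but also pass each leaf its own RNG subkey."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Split the initial key into one subkey per leaf
    keys = jax.random.split(rng_key, len(leaves))
    # Arrange the subkeys in a tree with exactly the same structure as `tree`
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    # tree_map walks tree, key_tree and every tree in `rest` leaf by leaf
    return jax.tree_util.tree_map(map_fun, tree, key_tree, *rest)


def scaled_noise(leaf, key, scale):
    # map_fun(leaf_value, rng_key, *rest): here `rest` carries one scale per leaf
    return leaf + scale * jax.random.normal(key, leaf.shape)


tree = {'a': jnp.zeros(2), 'b': jnp.zeros(2)}
scales = {'a': 1.0, 'b': 10.0}
out = tree_map_with_rng_sketch(tree, scaled_noise, jax.random.PRNGKey(42), scales)
```

The same key is always split the same way, so calling the sketch twice with jax.random.PRNGKey(42) returns identical trees.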