Module utilities.jax_tree_utils

Utility functions for working with trees.

Functions overview

branches(tree[, prefix])

Generate all branches (paths from root node to any leaf) in a tree

get_nested(tree, branch)

Get the value stored at a given branch of a tree

make_prototype_tree(full_tree, target_tree)

Construct a tree with boolean leaves, for nodes that match a target tree

set_matching(full_tree, target_tree, value)

Set the values in a full tree in-place, for branches that match a target tree

set_nested(tree, branch, value)

Set the value stored at a given branch of a tree, in-place

tree_find(tree)

Return the branches of all tree leaves that evaluate to True

tree_map_reduce_select(tree, protoype_tree, ...)

Perform a map-reduce operation over a tree, but only matching selected nodes

tree_map_select(tree, prototype_tree, map_fun)

Map a scalar function over a tree, but only matching selected nodes

tree_map_select_with_rng(tree, ...)

Map a scalar function over a tree, but only matching selected nodes, passing a JAX PRNG key to the function

tree_map_with_rng(tree, map_fun, rng_key, *rest)

Perform a multimap over a tree, splitting and inserting an RNG key for each leaf

Functions

utilities.jax_tree_utils.branches(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], prefix: Optional[list] = None) → Generator[Tuple, None, None]

Generate all branches (paths from root node to any leaf) in a tree

Parameters
  • tree (Tree) – A nested dictionary tree structure

  • prefix (list) – The current branch prefix. Default: None (this is the root)

Yields

Tuple[str] – Each branch, as a tuple of keys leading from the root node to a leaf
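To illustrate what a branch is, here is a minimal sketch of the traversal. It assumes the tree is built only from nested dicts (the documented signature also admits other iterables), and `branches_sketch` is a hypothetical name, not the module's actual source.

```python
from typing import Generator, Mapping, Optional, Tuple


def branches_sketch(tree: Mapping, prefix: Optional[list] = None) -> Generator[Tuple, None, None]:
    """Hypothetical re-implementation: yield every root-to-leaf key path of a nested dict."""
    prefix = [] if prefix is None else prefix
    for key, node in tree.items():
        path = prefix + [key]
        if isinstance(node, Mapping):
            # Internal node: recurse one level deeper, extending the current path.
            yield from branches_sketch(node, path)
        else:
            # Leaf node: emit the complete path as a tuple.
            yield tuple(path)


tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}}
print(list(branches_sketch(tree)))  # [('a',), ('b', 'b1'), ('b', 'b2')]
```

Each yielded tuple has the form expected by the branch argument of get_nested() and set_nested().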

utilities.jax_tree_utils.get_nested(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], branch: Tuple) → None

Get the value stored at a given branch of a tree

Parameters
  • tree (Tree) – A nested dictionary tree structure

  • branch (Tuple[str]) – A branch: a tuple of indices to walk through the tree

Returns

The value stored at branch in tree.
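Walking a branch amounts to indexing the tree once per branch element. A minimal sketch, assuming every node is a dict or a list; `get_nested_sketch` is a hypothetical name, not the module's implementation.

```python
import operator
from functools import reduce
from typing import Any, Tuple


def get_nested_sketch(tree: Any, branch: Tuple) -> Any:
    """Hypothetical re-implementation: return tree[branch[0]][branch[1]]...[branch[-1]]."""
    # reduce applies operator.getitem left to right, descending one level per element.
    return reduce(operator.getitem, branch, tree)


tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}}
print(get_nested_sketch(tree, ('b', 'b2')))  # 3
```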

utilities.jax_tree_utils.make_prototype_tree(full_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], target_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]) → Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]

Construct a tree with boolean leaves, for nodes that match a target tree

Make a prototype tree, indicating which nodes in a large tree should be selected for analysis or processing. This is done on the basis of a smaller “target” tree, which contains only the leaf nodes of interest.

Examples

>>> target_tree = {'a': 0, 'b': {'b2': 0}}
>>> full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4, 'd': 5}
>>> make_prototype_tree(full_tree, target_tree)
{'a': True, 'b': {'b1': False, 'b2': True}, 'c': False, 'd': False}

Parameters
  • full_tree (Tree) – A large tree to search through.

  • target_tree (Tree) – A tree containing only a few leaf nodes. These nodes will be identified within the full tree.

Returns

A nested tree with the same tree structure as full_tree, but with bool leaf nodes. Leaf nodes will be True for branches matching those specified in target_tree, and False otherwise.

Return type

Tree
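The matching rule in the example above can be sketched as a recursive walk over full_tree. This is a hypothetical re-implementation (`make_prototype_tree_sketch`) for nested dicts only. How the real function treats a target leaf that sits where full_tree has a subtree is not documented here; the sketch simply marks such subtrees as unselected.

```python
from typing import Mapping


def make_prototype_tree_sketch(full_tree: Mapping, target_tree: Mapping) -> dict:
    """Hypothetical: mirror full_tree, with True wherever the same branch exists in target_tree."""
    proto = {}
    for key, node in full_tree.items():
        if isinstance(node, Mapping):
            # Descend into the matching target subtree, or an empty one if there is none.
            sub_target = target_tree.get(key, {})
            if not isinstance(sub_target, Mapping):
                sub_target = {}
            proto[key] = make_prototype_tree_sketch(node, sub_target)
        else:
            # A leaf is selected iff the same key is present at this level of target_tree.
            proto[key] = key in target_tree
    return proto


target_tree = {'a': 0, 'b': {'b2': 0}}
full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4, 'd': 5}
print(make_prototype_tree_sketch(full_tree, target_tree))
# {'a': True, 'b': {'b1': False, 'b2': True}, 'c': False, 'd': False}
```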

utilities.jax_tree_utils.set_matching(full_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], target_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], value: Any) → None

Set the values in a full tree in-place, for branches that match a target tree

Parameters
  • full_tree (Tree) – A tree to search over. Values at matching branches in this tree are replaced in-place with value.

  • target_tree (Tree) – A tree that defines the target branches to set in full_tree. Matching branches in full_tree will have their values replaced with value.

  • value (Any) – The value to set in full_tree.
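A minimal in-place sketch, assuming nested dicts and assuming that target branches absent from full_tree are silently skipped (the documentation does not say what happens in that case). `set_matching_sketch` is a hypothetical name.

```python
from typing import Any, Mapping, MutableMapping


def set_matching_sketch(full_tree: MutableMapping, target_tree: Mapping, value: Any) -> None:
    """Hypothetical: overwrite, in place, every leaf of full_tree whose branch appears in target_tree."""
    for key, target_node in target_tree.items():
        if key not in full_tree:
            continue  # assumption: branches missing from full_tree are ignored
        if isinstance(target_node, Mapping) and isinstance(full_tree[key], MutableMapping):
            # Both trees branch here: recurse into the subtree.
            set_matching_sketch(full_tree[key], target_node, value)
        else:
            # Target leaf reached: replace the value in full_tree.
            full_tree[key] = value


full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4}
set_matching_sketch(full_tree, {'a': 0, 'b': {'b2': 0}}, -1)
print(full_tree)  # {'a': -1, 'b': {'b1': 2, 'b2': -1}, 'c': 4}
```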

utilities.jax_tree_utils.set_nested(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], branch: Tuple, value: Any) → None

Set the value stored at a given branch of a tree, in-place

The leaf node must already exist in the tree.

Parameters
  • tree (Tree) – A nested dictionary tree structure

  • branch (Tuple[str]) – A branch: a tuple of indices to walk through the tree

  • value (Any) – The value to set at the tree leaf
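Because the leaf must already exist, a sketch can descend to the parent node and read the final key before assigning. `set_nested_sketch` is a hypothetical name, and reporting a missing leaf by letting `KeyError` or `IndexError` propagate is an assumption, not documented behaviour.

```python
from typing import Any, Tuple


def set_nested_sketch(tree: Any, branch: Tuple, value: Any) -> None:
    """Hypothetical: assign value at an existing leaf, in place; returns None."""
    *parents, leaf_key = branch
    node = tree
    for key in parents:
        node = node[key]  # descend one level per branch element
    _ = node[leaf_key]  # raises KeyError/IndexError if the leaf does not exist yet
    node[leaf_key] = value  # in-place assignment


tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}}
set_nested_sketch(tree, ('b', 'b1'), 10)
print(tree)  # {'a': 1, 'b': {'b1': 10, 'b2': 3}}
```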

utilities.jax_tree_utils.tree_find(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]) → List

Return the branches of all tree leaves that evaluate to True

Parameters

tree (Tree) – A tree to examine

Returns

A list of all tree branches for which the corresponding tree leaf evaluates to True

Return type

list
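tree_find() pairs naturally with make_prototype_tree(): it turns a boolean prototype tree back into a list of selected branches. A hypothetical sketch (`tree_find_sketch`) for nested dicts, using Python truthiness to test each leaf:

```python
from typing import List, Mapping, Tuple


def tree_find_sketch(tree: Mapping, prefix: Tuple = ()) -> List[Tuple]:
    """Hypothetical: collect the branches of every leaf for which bool(leaf) is True."""
    found = []
    for key, node in tree.items():
        path = prefix + (key,)
        if isinstance(node, Mapping):
            found.extend(tree_find_sketch(node, path))
        elif node:  # truthiness test on the leaf value
            found.append(path)
    return found


proto = {'a': True, 'b': {'b1': False, 'b2': True}, 'c': False}
print(tree_find_sketch(proto))  # [('a',), ('b', 'b2')]
```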

utilities.jax_tree_utils.tree_map_reduce_select(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], protoype_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any], Any], reduce_fun: Callable[[Any, Any], Any], null: Any = nan) → Any

Perform a map-reduce operation over a tree, but only matching selected nodes

map_fun() is a function that will be mapped over the leaves of a tree. It will only be applied to leaf nodes in tree when the corresponding leaf node in prototype_tree is True.

map_fun() can return whatever you like; however, you must ensure that the null value has the same shape as the value returned by map_fun() applied to a leaf node.

reduce_fun() is a reduction function. This is called using functools.reduce over the mapped outputs of map_fun().

reduce_fun(intermediate_value, map_value) accepts two mapped leaf values and should combine them into a single result. That result is retained as the intermediate computation and passed to the next call to reduce_fun().

Examples

>>> import numpy as np
>>> map_fun = np.nanmax
>>> def reduce_fun(x, y):
...     return np.nanmax(np.array([x, y]))
...
>>> tree_map_reduce_select({'a': 1}, {'a': 1}, map_fun, reduce_fun)
1
>>> def map_fun(leaf):
...     return {'mean': np.mean(leaf),
...             'std': np.std(leaf),
...             }
...
>>> def reduce_fun(x, y):
...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
...             'std': x['std'] + y['std'],
...             }
...
>>> null = map_fun(np.nan)
>>> tree_map_reduce_select({'a': 1}, {'a': 1}, map_fun, reduce_fun, null)
{'mean': 1, 'std': NaN}

Parameters
  • tree (Tree) – A tree over which to operate

  • prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.

  • map_fun (Callable[[Leaf], Value]) – A function to perform on selected nodes in tree. map_fun has the signature map_fun(leaf_node) -> value

  • reduce_fun (Callable[[Value, Value], Value]) – A function that combines two mapped values into a single value; it is used to reduce over the outputs of map_fun. reduce_fun has the signature reduce_fun(value, value) -> value.

  • null (Any) – The “null” value to return from the map operation, if the leaf node is not selected in prototype_tree. Default: jax.numpy.nan

Returns

The result of the map-reduce operation over the tree

Return type

Value
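A minimal sketch of the map-reduce, assuming nested-dict trees and assuming (as the description of null suggests) that every unselected leaf contributes null to the reduction. `tree_map_reduce_select_sketch` and `_selected_leaves` are hypothetical names. Here null is chosen as the identity of the reduction (0.0 for addition), so unselected leaves do not affect the result:

```python
import functools
import operator
from typing import Any, Callable, Iterator, Mapping, Tuple


def _selected_leaves(tree: Mapping, proto: Mapping) -> Iterator[Tuple[Any, bool]]:
    """Yield (leaf, is_selected) pairs, walking tree and proto in lockstep by key."""
    for key, node in tree.items():
        if isinstance(node, Mapping):
            yield from _selected_leaves(node, proto[key])
        else:
            yield node, bool(proto[key])


def tree_map_reduce_select_sketch(tree: Mapping, prototype_tree: Mapping,
                                  map_fun: Callable[[Any], Any],
                                  reduce_fun: Callable[[Any, Any], Any],
                                  null: Any = float('nan')) -> Any:
    """Hypothetical: map selected leaves, substitute null for the rest, then reduce."""
    mapped = [map_fun(leaf) if selected else null
              for leaf, selected in _selected_leaves(tree, prototype_tree)]
    # functools.reduce folds left to right: reduce_fun(reduce_fun(v0, v1), v2), ...
    return functools.reduce(reduce_fun, mapped)


tree = {'a': 1.0, 'b': {'b1': 5.0, 'b2': 3.0}}
proto = {'a': True, 'b': {'b1': False, 'b2': True}}
# Sum of squares over the selected leaves only: 1.0**2 + 3.0**2
total = tree_map_reduce_select_sketch(tree, proto, lambda x: x ** 2, operator.add, null=0.0)
print(total)  # 10.0
```

Picking null as the reduction's identity element is a design choice of this sketch; with a NaN null (the documented default), reduce_fun must instead be NaN-aware, as the np.nanmax example above is.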

utilities.jax_tree_utils.tree_map_select(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], prototype_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any], Any]) → Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]

Map a scalar function over a tree, but only matching selected nodes

Notes

map_fun must be a scalar function, meaning it preserves the shape of its input: if a leaf has shape (N, M), map_fun must return a value of shape (N, M). Otherwise an error will be raised.

Parameters
  • tree (Tree) – A tree over which to operate

  • prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.

  • map_fun (Callable[[Leaf], Value]) – A scalar function to perform on selected nodes in tree

Returns

A tree with the same structure as tree, in which each selected leaf node is replaced with the output of map_fun() for that leaf; unselected leaf nodes are left unchanged.

Return type

Tree
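A minimal sketch of the selective map for nested dicts. It builds a new tree rather than mutating tree, an assumption consistent with the documented return value. `tree_map_select_sketch` is a hypothetical name.

```python
from typing import Any, Callable, Mapping


def tree_map_select_sketch(tree: Mapping, prototype_tree: Mapping,
                           map_fun: Callable[[Any], Any]) -> dict:
    """Hypothetical: return a new tree, applying map_fun only where prototype_tree is True."""
    out = {}
    for key, node in tree.items():
        proto_node = prototype_tree[key]
        if isinstance(node, Mapping):
            out[key] = tree_map_select_sketch(node, proto_node, map_fun)
        else:
            # Selected leaves are transformed; unselected leaves pass through unchanged.
            out[key] = map_fun(node) if proto_node else node
    return out


tree = {'a': 1.0, 'b': {'b1': 2.0, 'b2': 3.0}}
proto = {'a': True, 'b': {'b1': False, 'b2': True}}
print(tree_map_select_sketch(tree, proto, lambda x: -x))
# {'a': -1.0, 'b': {'b1': 2.0, 'b2': -3.0}}
```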

utilities.jax_tree_utils.tree_map_select_with_rng(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], prototype_tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any, Any], Any], rng_key: Any) → Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]

Map a scalar function over a tree, but only matching selected nodes. map_fun() also receives a JAX-compatible pseudo-random number generator (PRNG) key.

Notes

map_fun() must be a scalar function, meaning it preserves the shape of its input: if a leaf has shape (N, M), map_fun() must return a value of shape (N, M). Otherwise an error will be raised. The signature of map_fun() is map_fun(leaf, rng_key) -> value.

Parameters
  • tree (PyTree) – A tree over which to operate

  • prototype_tree (PyTree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.

  • map_fun (Callable[[Leaf, JaxRNGKey], Value]) – A scalar function to perform on selected nodes in tree. The second argument is a JAX PRNG key to use when generating random values.

  • rng_key (JaxRNGKey) – The JAX PRNG key used to supply random state to map_fun
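A sketch using `jax.tree_util` and `jax.random`. It assumes rng_key is split once into one subkey per leaf (the documentation above only says that map_fun receives a key); `tree_map_select_with_rng_sketch` and `add_noise` are hypothetical names, not the module's source.

```python
import jax
import jax.numpy as jnp


def tree_map_select_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Hypothetical: give each selected leaf its own subkey; leave other leaves unchanged."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Both trees share a structure, so their flattened leaves line up one-to-one.
    proto_leaves = jax.tree_util.tree_leaves(prototype_tree)
    # Assumption: split once, one subkey per leaf (unselected leaves ignore theirs).
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [map_fun(leaf, key) if selected else leaf
                  for leaf, selected, key in zip(leaves, proto_leaves, keys)]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


def add_noise(leaf, key):
    # Shape-preserving ("scalar") function: the output has the same shape as the leaf.
    return leaf + 0.1 * jax.random.normal(key, jnp.shape(leaf))


tree = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
proto = {'w': True, 'b': False}
out = tree_map_select_with_rng_sketch(tree, proto, add_noise, jax.random.PRNGKey(0))
print(out['w'].shape)  # (2, 3)
```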

utilities.jax_tree_utils.tree_map_with_rng(tree: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], map_fun: Callable[[Any, Any, Any], Any], rng_key: Any, *rest: Any) → Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping]

Perform a multimap over a tree, splitting and inserting an RNG key for each leaf

This utility maps a function over the leaves of a tree, when the function requires an RNG key to operate. The utility will automatically split the RNG key to generate a new key for each leaf. Then map_fun will be called for each leaf, with the signature map_fun(leaf_value, rng_key, *rest).

rest is an optional further series of arguments to map over the tree, such that each additional argument must have the same tree structure as tree. See the documentation for jax.tree_util.tree_map for more information.

Parameters
  • tree (Tree) – A tree to work over

  • map_fun (Callable[[Value, JaxRNGKey, Any], Value]) – A function to map over the tree. The function must have the signature map_fun(leaf_value, rng_key, *rest)

  • rng_key (JaxRNGKey) – An initial RNG key to split

  • *rest – A tuple of additional tree-shaped arguments, mapped over alongside tree, whose leaves are passed to map_fun after the RNG key.

Returns

The tree-shaped result of mapping map_fun over tree.

Return type

Tree
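The split-then-multimap pattern described above can be sketched with `jax.tree_util.tree_map`, which accepts extra trees with the same structure as its first tree. `tree_map_with_rng_sketch`, `jitter`, and the `scales` tree are hypothetical; the actual module may split keys differently.

```python
import jax
import jax.numpy as jnp


def tree_map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Hypothetical: split rng_key into one subkey per leaf, then multimap over the trees."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Arrange the subkeys in a tree with exactly the same structure as `tree`.
    subkeys = list(jax.random.split(rng_key, len(leaves)))
    key_tree = jax.tree_util.tree_unflatten(treedef, subkeys)
    # map_fun is called as map_fun(leaf, subkey, *rest_leaves) for every leaf.
    return jax.tree_util.tree_map(map_fun, tree, key_tree, *rest)


def jitter(leaf, key, scale):
    # `scale` is the matching leaf of the extra tree passed through *rest.
    return leaf + scale * jax.random.normal(key, jnp.shape(leaf))


params = {'w': jnp.zeros((2, 2)), 'b': jnp.zeros(2)}
scales = {'w': 1.0, 'b': 0.0}
noisy = tree_map_with_rng_sketch(params, jitter, jax.random.PRNGKey(42), scales)
print(noisy['w'].shape)  # (2, 2)
```

Because each leaf receives its own subkey, the leaves draw independent noise; the scale of 0.0 on 'b' shows how a *rest tree can switch the effect off for individual leaves.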