Module utilities.jax_tree_utils

Utility functions for working with trees.

Functions overview

tree_map_reduce_select(tree, protoype_tree, ...)

Perform a map-reduce operation over a tree, mapping only over selected nodes.

tree_map_select(tree, prototype_tree, map_fun)

Map a scalar function over a tree, applying it only to selected nodes.

tree_map_select_with_rng(tree, ...)

Map a scalar function over a tree, applying it only to selected nodes, with a JAX-compatible RNG key.

tree_map_with_rng(tree, map_fun, rng_key, *rest)

Perform a multimap over a tree, splitting and inserting an RNG key for each leaf

Functions

utilities.jax_tree_utils.tree_map_reduce_select(tree: Iterable | MutableMapping | Mapping, protoype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any], reduce_fun: Callable[[Any, Any], Any], null: Any = nan) → Any

Perform a map-reduce operation over a tree, mapping only over selected nodes.

map_fun() is a function that will be mapped over the leaves of a tree. It will only be applied to leaf nodes in tree when the corresponding leaf node in prototype_tree is True.

map_fun() may return any value; however, the null value must have the same shape as the value returned by map_fun() for a leaf node, because null stands in for the output of map_fun() at every unselected leaf.

reduce_fun() is a reduction function. It is applied with functools.reduce over the outputs of map_fun().

reduce_fun(intermediate_value, map_value) accepts the running result and the next mapped leaf value, and combines them into a single result. That result is kept as the new intermediate value and passed as the first argument to the next call to reduce_fun().
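To make the map and reduce steps concrete, here is a minimal sketch of the semantics described above, written with jax.tree_util. It is not the module's implementation: the names map_reduce_select_sketch and nanmax_pair are hypothetical, and the sketch assumes that unselected leaves contribute null to the reduction and that functools.reduce is called without an initializer.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp


def map_reduce_select_sketch(
    tree: Any,
    prototype_tree: Any,
    map_fun: Callable[[Any], Any],
    reduce_fun: Callable[[Any, Any], Any],
    null: Any = jnp.nan,
) -> Any:
    """Hypothetical sketch of map-reduce over the selected leaves of a tree."""
    # Leaves of both trees come out in the same order when the structures match.
    leaves = jax.tree_util.tree_leaves(tree)
    flags = jax.tree_util.tree_leaves(prototype_tree)
    assert len(leaves) == len(flags), "prototype_tree must match tree's structure"
    # Map step: selected leaves go through map_fun, unselected ones become null.
    mapped = [map_fun(leaf) if flag else null for leaf, flag in zip(leaves, flags)]
    # Reduce step: functools.reduce threads the running result through reduce_fun.
    return functools.reduce(reduce_fun, mapped)


def nanmax_pair(x, y):
    # Combine two mapped values; NaN (the default null) is ignored by nanmax.
    return jnp.nanmax(jnp.array([x, y]))


tree = {'a': jnp.array([1.0, 5.0]), 'b': jnp.array([9.0]), 'c': jnp.array([2.0, 3.0])}
mask = {'a': True, 'b': False, 'c': True}
result = map_reduce_select_sketch(tree, mask, jnp.nanmax, nanmax_pair)
# 'b' is not selected, so result is max(5.0, 3.0) == 5.0
```

Pairing the default NaN null with a NaN-ignoring reduction such as jnp.nanmax is what lets unselected leaves drop out of the result without any special casing in reduce_fun.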

Examples

>>> import numpy as np
>>> map_fun = np.nanmax
>>> def reduce_fun(x, y):
...     return np.nanmax(np.array([x, y]))
...
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun)
1

>>> def map_fun(leaf):
...     return {'mean': np.mean(leaf),
...             'std': np.std(leaf),
...             }
...
>>> def reduce_fun(x, y):
...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
...             'std': x['std'] + y['std'],
...             }
...
>>> null = map_fun(np.nan)
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun, null)
{'mean': 1, 'std': NaN}

Parameters:
  • tree (Tree) – A tree over which to operate

  • prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node are passed to map_fun.

  • map_fun (Callable[[Leaf], Value]) – A function to apply to selected nodes in tree, with the signature map_fun(leaf_node) -> value.

  • reduce_fun (Callable[[Value, Value], Value]) – A function that combines two mapped values into one, used to reduce over the outputs of map_fun. reduce_fun has the signature reduce_fun(value, value) -> value.

  • null (Any) – The “null” value to return from the map operation, if the leaf node is not selected in prototype_tree. Default: jax.numpy.nan

Returns:

The result of the map-reduce operation over the tree

Return type:

Value

utilities.jax_tree_utils.tree_map_select(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any]) → Iterable | MutableMapping | Mapping

Map a scalar function over a tree, applying it only to selected nodes.

Notes

map_fun must be a scalar (shape-preserving) function: if the input has shape (N, M), the output must also have shape (N, M). Otherwise you will get an error.
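One way to see why the shape must be preserved is a selection sketch built on jnp.where. The version below is hypothetical (map_select_sketch is not part of the module) and assumes selection is done elementwise with jnp.where, which keeps the mask usable under jax.jit; the real implementation may select differently.

```python
import jax
import jax.numpy as jnp


def map_select_sketch(tree, prototype_tree, map_fun):
    """Hypothetical sketch: apply map_fun only where prototype_tree is True."""

    def select(leaf, flag):
        # jnp.where evaluates map_fun on every leaf and picks elementwise, so the
        # result takes the broadcast shape of map_fun(leaf) and leaf. A map_fun
        # that changes shape either fails to broadcast or silently alters leaf.
        return jnp.where(flag, map_fun(leaf), leaf)

    return jax.tree_util.tree_map(select, tree, prototype_tree)


params = {'w': jnp.array([[1.0, -2.0], [3.0, -4.0]]), 'b': jnp.array([-1.0, 1.0])}
mask = {'w': True, 'b': False}
out = map_select_sketch(params, mask, jnp.abs)  # 'w' -> |w|, 'b' unchanged

# Because selection uses jnp.where, the mask leaves may also be traced arrays.
jitted = jax.jit(lambda p, m: map_select_sketch(p, m, jnp.abs))
out_jit = jitted(params, {'w': jnp.array(True), 'b': jnp.array(False)})
```

Under this assumption, a reducing map_fun such as jnp.sum would not raise on a leaf of shape (N, M); its scalar output would broadcast and overwrite every element, which is why a shape-preserving map_fun is the safe contract.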

Parameters:
  • tree (Tree) – A tree over which to operate

  • prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.

  • map_fun (Callable[[Leaf], Value]) – A scalar function to perform on selected nodes in tree

Returns:

A tree with the same structure as tree, in which each selected leaf is replaced with the output of map_fun() and every other leaf is left unchanged.

Return type:

Tree

utilities.jax_tree_utils.tree_map_select_with_rng(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any], Any], rng_key: Any) → Iterable | MutableMapping | Mapping

Map a scalar function over a tree, applying it only to selected nodes. map_fun() also receives a JAX-compatible RNG key for generating random state.

Notes

map_fun() must be a scalar (shape-preserving) function: if the input has shape (N, M), the output must also have shape (N, M). Otherwise you will get an error. map_fun() has the signature map_fun(leaf, rng_key) -> value.

Parameters:
  • tree (PyTree) – A tree over which to operate

  • prototype_tree (PyTree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.

  • map_fun (Callable[[Leaf, JaxRNGKey], Value]) – A scalar function to perform on selected nodes in tree. The second argument is a JAX PRNG key to use when generating random state.
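A minimal sketch of how the RNG variant might behave, assuming the key is split into one subkey per leaf (as tree_map_with_rng documents for its own key) and that selection works elementwise with jnp.where. The names map_select_with_rng_sketch and add_noise are hypothetical, and the module may derive or assign keys differently.

```python
import jax
import jax.numpy as jnp


def map_select_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Hypothetical sketch: give each leaf its own key, map only selected leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    flags = jax.tree_util.tree_leaves(prototype_tree)
    # One independent subkey per leaf, so no two leaves share random state.
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [
        jnp.where(flag, map_fun(leaf, key), leaf)
        for leaf, flag, key in zip(leaves, flags, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


def add_noise(leaf, key):
    # Shape-preserving ("scalar") map_fun: Gaussian noise with the leaf's shape.
    return leaf + 0.1 * jax.random.normal(key, leaf.shape)


params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
mask = {'w': True, 'b': False}
noisy = map_select_with_rng_sketch(params, mask, add_noise, jax.random.PRNGKey(0))
# noisy['w'] is perturbed; noisy['b'] is still all zeros.
```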

utilities.jax_tree_utils.tree_map_with_rng(tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any, Any], Any], rng_key: Any, *rest: Any) → Iterable | MutableMapping | Mapping

Perform a multimap over a tree, splitting and inserting an RNG key for each leaf

This utility maps a function over the leaves of a tree, when the function requires an RNG key to operate. The utility will automatically split the RNG key to generate a new key for each leaf. Then map_fun will be called for each leaf, with the signature map_fun(leaf_value, rng_key, *rest).

rest is an optional series of additional trees to map over alongside tree; each must have the same tree structure as tree. See the documentation for jax.tree_util.tree_map for more information.
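The key splitting and multimap behavior described above can be sketched with jax.tree_util.tree_map. map_with_rng_sketch and dropout are hypothetical names used for illustration, and the sketch assumes a single split into one subkey per leaf, which may not match how the module orders or derives its keys.

```python
import jax
import jax.numpy as jnp


def map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Hypothetical sketch of the documented behavior, not the module's code."""
    treedef = jax.tree_util.tree_structure(tree)
    # Split once into one subkey per leaf and arrange the keys in tree's shape.
    keys = jax.random.split(rng_key, treedef.num_leaves)
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    # tree_map walks tree, key_tree and every tree in rest in lockstep, calling
    # map_fun(leaf_value, rng_key, *rest_leaves) at each leaf position.
    return jax.tree_util.tree_map(map_fun, tree, key_tree, *rest)


def dropout(x, key, rate):
    # Keep each element with probability 1 - rate and rescale the survivors.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


acts = {'h1': jnp.ones(4), 'h2': jnp.ones(8)}
rates = {'h1': 0.5, 'h2': 0.25}  # passed through *rest, one rate per leaf
dropped = map_with_rng_sketch(acts, dropout, jax.random.PRNGKey(42), rates)
```

Splitting the key once and laying the subkeys out in the shape of tree keeps the call to jax.tree_util.tree_map a plain multimap, so extra trees in rest line up with the leaves exactly as they would in tree_map itself.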

Parameters:
  • tree (Tree) – A tree to work over

  • map_fun (Callable[[Value, JaxRNGKey, Any], Value]) – A function to map over the tree. The function must have the signature map_fun(leaf_value, rng_key, *rest)

  • rng_key (JaxRNGKey) – An initial RNG key to split

  • *rest – Additional tree-shaped arguments, each with the same structure as tree, whose corresponding leaves are passed to map_fun after the RNG key.

Returns:

The tree-shaped result of mapping map_fun over tree.

Return type:

Tree