# Source code for utilities.jax_tree_utils

```
"""
Utility functions for working with trees.
"""
from typing import Any, Callable
import jax
import jax.tree_util as tu
import functools
# - Set up some useful types
from rockpool.typehints import Tree, Leaf, Value, JaxRNGKey
from .tree_utils import *
# - Public API of this module, as exported by `from ... import *`.
#   Several names below (e.g. `branches`, `get_nested`, `tree_map`) are not
#   defined among the functions visible in this section; presumably they are
#   re-exported from `.tree_utils` via the star import above — verify against
#   that module before removing either the import or these entries.
__all__ = [
"tree_map_reduce_select",
"tree_map_select",
"tree_map_select_with_rng",
"tree_map_with_rng",
"branches",
"get_nested",
"set_nested",
"set_matching",
"set_matching_select",
"make_prototype_tree",
"tree_map",
"tree_flatten",
"tree_unflatten",
"tree_update",
"tree_find",
]
def tree_map_reduce_select(
    tree: Tree,
    protoype_tree: Tree,
    map_fun: Callable[[Leaf], Value],
    reduce_fun: Callable[[Value, Value], Value],
    null: Any = jax.numpy.nan,
) -> Value:
    """
    Perform a map-reduce operation over a tree, but only matching selected nodes

    `map_fun()` is a function that will be mapped over the leaves of a tree. It will only be applied to leaf nodes in `tree` when the corresponding leaf node in `protoype_tree` is ``True``; unselected leaves contribute `null` instead.

    `map_fun()` can return whatever you like; however, `null` must have the same structure, shape and dtype as the value returned by `map_fun()` applied to a leaf node, because the two are selected between with :py:func:`jax.lax.cond`.

    `reduce_fun()` is a reduction function. It is called using :py:func:`functools.reduce` over the mapped outputs of `map_fun()`, in the leaf order produced by :py:func:`jax.tree_util.tree_flatten`.

    `reduce_fun(intermediate_value, map_value)` accepts two mapped leaf values, and should combine them into a single result. That result is retained as the intermediate computation, and passed as the first argument to the next call to `reduce_fun()`.

    If `tree` has exactly one leaf, `reduce_fun()` is never called and the single mapped (or `null`) value is returned. If `tree` has no leaves, `null` is returned.

    Examples:
        >>> map_fun = np.nanmax
        >>>
        >>> def reduce_fun(x, y):
        ...     return np.nanmax(np.array([x, y]))
        >>>
        >>> tree_map_reduce_select({'a': 1.}, {'a': True}, map_fun, reduce_fun)  # -> 1.0

        >>> def map_fun(leaf):
        ...     return {'mean': np.mean(leaf),
        ...             'std': np.std(leaf),
        ...             }
        >>>
        >>> def reduce_fun(x, y):
        ...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
        ...             'std': x['std'] + y['std'],
        ...             }
        >>>
        >>> null = map_fun(np.nan)
        >>>
        >>> tree_map_reduce_select({'a': 1.}, {'a': True}, map_fun, reduce_fun, null)  # -> {'mean': 1.0, 'std': 0.0}

    Args:
        tree (Tree): A tree over which to operate
        protoype_tree (Tree): A prototype tree, with structure matching that of `tree`, but with ``bool`` leaves. Only leaf nodes in `tree` with ``True`` in the corresponding prototype tree node will be mapped. (The argument name is spelled ``protoype_tree``; it is kept as-is for backward compatibility with keyword callers.)
        map_fun (Callable[[Leaf], Value]): A function to perform on selected nodes in `tree`. `map_fun` has the signature `map_fun(leaf_node) -> value`
        reduce_fun (Callable[[Value, Value], Value]): A function that collects two values and returns the combination of the two values, to reduce over the mapped function. `reduce_fun` has the signature `reduce_fun(value, value) -> value`.
        null (Any): The "null" value to return from the map operation, if the leaf node is not selected in `protoype_tree`. Default: ``jax.numpy.nan``

    Returns:
        Value: The result of the map-reduce operation over the tree

    Raises:
        ValueError: If `tree` and `protoype_tree` do not have the same number of leaves.
    """
    tree_flat, _ = tu.tree_flatten(tree)
    proto_flat, _ = tu.tree_flatten(protoype_tree)

    # - `zip` would silently truncate to the shorter list, pairing leaves with the wrong selectors
    if len(tree_flat) != len(proto_flat):
        raise ValueError(
            f"`tree` has {len(tree_flat)} leaves but `protoype_tree` has "
            f"{len(proto_flat)}; their structures must match."
        )

    # - `functools.reduce` raises on an empty sequence; with no leaves nothing is selected
    if not tree_flat:
        return null

    def map_or_null(leaf: Any, select: bool) -> Any:
        # - Both branches must return matching types for `jax.lax.cond`; the `0` operand is unused
        return jax.lax.cond(
            select,
            lambda _: map_fun(leaf),
            lambda _: null,
            0,
        )

    # - Map function over leaves
    mapped = [map_or_null(leaf, select) for leaf, select in zip(tree_flat, proto_flat)]

    # - Reduce function over leaves
    return functools.reduce(reduce_fun, mapped)
def tree_map_select(
    tree: Tree, prototype_tree: Tree, map_fun: Callable[[Leaf], Value]
) -> Tree:
    """
    Map a scalar function over a tree, but only matching selected nodes

    Notes:
        `map_fun` must be a scalar function. This means that if the input is of shape ``(N, M)``, the output must also be of shape ``(N, M)``. Otherwise you will get an error, because the mapped and original leaves are selected between with :py:func:`jax.lax.cond`, which requires both branches to return matching types.

    Args:
        tree (Tree): A tree over which to operate
        prototype_tree (Tree): A prototype tree, with structure matching that of `tree`, but with ``bool`` leaves. Only leaf nodes in `tree` with ``True`` in the corresponding prototype tree node will be modified.
        map_fun (Callable[[Leaf], Value]): A scalar function to perform on selected nodes in `tree`

    Returns:
        Tree: A tree with the same structure as `tree`, where selected leaf nodes are replaced with the output of `map_fun()` and unselected leaf nodes are returned unchanged.

    Raises:
        ValueError: If `tree` and `prototype_tree` do not have the same number of leaves.
    """
    # - Flatten both trees
    tree_flat, treedef = tu.tree_flatten(tree)
    proto_flat, _ = tu.tree_flatten(prototype_tree)

    # - `zip` would silently truncate to the shorter list, pairing leaves with the wrong selectors
    if len(tree_flat) != len(proto_flat):
        raise ValueError(
            f"`tree` has {len(tree_flat)} leaves but `prototype_tree` has "
            f"{len(proto_flat)}; their structures must match."
        )

    # - A function that conditionally maps over the tree leaves
    def map_or_original(leaf: Any, select: bool) -> Any:
        # - The `0` operand is unused by either branch
        return jax.lax.cond(
            select,
            lambda _: map_fun(leaf),
            lambda _: leaf,
            0,
        )

    # - Map function over leaves
    mapped = [map_or_original(leaf, select) for leaf, select in zip(tree_flat, proto_flat)]

    # - Return tree
    return tu.tree_unflatten(treedef, mapped)
def tree_map_select_with_rng(
    tree: Tree,
    prototype_tree: Tree,
    map_fun: Callable[[Leaf, JaxRNGKey], Value],
    rng_key: JaxRNGKey,
) -> Tree:
    """
    Map a scalar function over a tree, but only matching selected nodes. Includes jax-compatible random state

    Notes:
        `map_fun()` must be a scalar function. This means that if the input is of shape ``(N, M)``, the output must also be of shape ``(N, M)``. Otherwise you will get an error.

        The signature of `map_fun()` is `map_fun(leaf, rng_key) -> value`.

        `rng_key` is split into one sub-key per leaf (selected or not), so the key passed to a given leaf depends on `rng_key`, the number of leaves and the leaf's position in the flattened tree.

    Args:
        tree (Tree): A tree over which to operate
        prototype_tree (Tree): A prototype tree, with structure matching that of `tree`, but with ``bool`` leaves. Only leaf nodes in `tree` with ``True`` in the corresponding prototype tree node will be modified.
        map_fun (Callable[[Leaf, JaxRNGKey], Value]): A scalar function to perform on selected nodes in `tree`. The second argument is a jax pRNG key to use when generating random state.
        rng_key (JaxRNGKey): An initial jax pRNG key, which is split to provide a sub-key for each leaf.

    Returns:
        Tree: A tree with the same structure as `tree`, where selected leaf nodes are replaced with the output of `map_fun()` and unselected leaf nodes are returned unchanged.

    Raises:
        ValueError: If `tree` and `prototype_tree` do not have the same number of leaves.
    """
    # - Flatten both trees
    tree_flat, treedef = tu.tree_flatten(tree)
    proto_flat, _ = tu.tree_flatten(prototype_tree)

    # - `zip` would silently truncate to the shorter list, pairing leaves with the wrong selectors
    if len(tree_flat) != len(proto_flat):
        raise ValueError(
            f"`tree` has {len(tree_flat)} leaves but `prototype_tree` has "
            f"{len(proto_flat)}; their structures must match."
        )

    # - A function that conditionally maps over the tree leaves
    def map_or_original(leaf: Any, select: bool, leaf_key: Any) -> Any:
        # - The `0` operand is unused by either branch
        return jax.lax.cond(
            select,
            lambda _: map_fun(leaf, leaf_key),
            lambda _: leaf,
            0,
        )

    # - Split one sub-key per leaf; the first key from the split is discarded
    _, *subkeys = jax.random.split(rng_key, len(tree_flat) + 1)

    # - Map function over leaves
    mapped = [
        map_or_original(leaf, select, leaf_key)
        for leaf, select, leaf_key in zip(tree_flat, proto_flat, subkeys)
    ]

    # - Return tree
    return tu.tree_unflatten(treedef, mapped)
def tree_map_with_rng(
    tree: Tree,
    map_fun: Callable[[Value, JaxRNGKey, Any], Value],
    rng_key: JaxRNGKey,
    *rest: Any,
) -> Tree:
    """
    Multimap a function over a tree, supplying a freshly split RNG key to every leaf

    `rng_key` is split into one sub-key per leaf of `tree`. The sub-keys are arranged into a tree with the same structure as `tree`. Then `map_fun` is mapped over `tree`, the key tree and any `rest` trees together, so each call receives ``map_fun(leaf_value, leaf_rng_key, *rest_leaves)``.

    Each tree in `rest` must share the tree structure of `tree`; see :py:func:`jax.tree_util.tree_map` for the exact matching rules.

    Args:
        tree (Tree): The tree whose leaves are mapped over
        map_fun (Callable[[Value, JaxRNGKey, Any], Value]): Function applied per leaf, called as ``map_fun(leaf_value, rng_key, *rest)``
        rng_key (JaxRNGKey): The RNG key from which the per-leaf keys are split
        *rest: Optional additional `tree`-shaped arguments, mapped alongside `tree`

    Returns:
        Tree: A `tree`-shaped result containing the output of `map_fun` for every leaf.
    """
    leaves, structure = tu.tree_flatten(tree)

    # - Split N + 1 keys and drop the first, leaving exactly one key per leaf
    per_leaf_keys = list(jax.random.split(rng_key, len(leaves) + 1)[1:])
    keys_tree = tu.tree_unflatten(structure, per_leaf_keys)

    return tu.tree_map(map_fun, tree, keys_tree, *rest)
```