Source code for utilities.jax_tree_utils

"""
Utility functions for working with trees.
"""

from warnings import warn

from typing import Tuple, Generator, Any, Callable, List

import copy

import jax
import jax.tree_util as tu

import functools

# - Set up some useful types
from rockpool.typehints import Tree, Leaf, Value, JaxRNGKey


[docs]def branches(tree: Tree, prefix: list = None) -> Generator[Tuple, None, None]: """ Generate all branches (paths from root node to any leaf) in a tree Args: tree (Tree): A nested dictionary tree structure prefix (list): The current branch prefix. Default: `None` (this is the root) Yields: Tuple[str]: """ # - Start from the root node if prefix is None: prefix = [] # - Loop over nodes for k, v in tree.items(): # - Is this a nested dict? if isinstance(v, dict): # - Get branches from subtree yield from branches(v, prefix + [k]) else: # - Yield this branch yield tuple(prefix + [k])
[docs]def get_nested(tree: Tree, branch: Tuple) -> None: """ Get a value from a tree branch, specifying a branch Args: tree (Tree): A nested dictionary tree structure branch (Tuple[str]): A branch: a tuple of indices to walk through the tree """ # - Start from the root node node = tree # - Iterate along the branch for key in branch[:-1]: node = node[key] # - Get the leaf value return node[branch[-1]]
[docs]def set_nested(tree: Tree, branch: Tuple, value: Any) -> None: """ Set a value in a tree branch, specifying a branch The leaf node must already exist in the tree. Args: tree (Tree): A nested dictionary tree structure branch (Tuple[str]): A branch: a tuple of indices to walk through the tree value (Any): The value to set at the tree leaf """ # - Start from the root node node = tree # - Iterate along the branch for key in branch[:-1]: node = node[key] # - Set the leaf value node[branch[-1]] = value
[docs]def set_matching(full_tree: Tree, target_tree: Tree, value: Any) -> None: """ Set the values in a full tree in-place, for branches that match a target tree Args: full_tree (Tree): A tree to search over. The values in this tree will be replaced with ``value`` target_tree (Tree): A tree that defines the target branches to set in ``full_tree``. Matching branches in ``full_tree`` will have their values replaced with ``value`` value (Any): The value to set in ``full_tree``. """ for branch in branches(target_tree): set_nested(full_tree, branch, value)
[docs]def make_prototype_tree(full_tree: Tree, target_tree: Tree) -> Tree: """ Construct a tree with boolean leaves, for nodes that match a target tree Make a prototype tree, indicating which nodes in a large tree should be selected for analysis or processing. This is done on the basis of a smaller "target" tree, which contains only the leaf nodes of interest. Examples: >>> target_tree = {'a': 0, 'b': {'b2': 0}} >>> full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4, 'd': 5} >>> make_prototype_tree(full_tree, target_tree) {'a': True, 'b': {'b1': False, 'b2': True}, 'c': False, 'd': False} Args: full_tree (Tree): A large tree to search through. target_tree (Tree): A tree with only few leaf nodes. These nodes will be identifed within the full tree. Returns: Tree: A nested tree with the same tree structure as `full_tree`, but with ``bool`` leaf nodes. Leaf nodes will be ``True`` for branches matching those specified in `target_tree`, and ``False`` otherwise. """ # - Make a copy of the input tree prototype = copy.deepcopy(full_tree) # - Get a list of target and full branches targets = list(branches(target_tree)) full_branches = list(branches(full_tree)) # - Sanity check the trees if len(full_branches) < len(targets): warn( SyntaxWarning( "make_prototype_tree: The `target` tree has more nodes than the `full` tree. Please check the order of arguments." ) ) # - Loop over all leaf branches in full tree for branch in full_branches: # - Is this a target branch? if branch in targets: # - Assign `True` in the prototype tree set_nested(prototype, branch, True) else: # - Assign `False` in the prototype tree set_nested(prototype, branch, False) # - Return the prototype tree return prototype
[docs]def tree_map_reduce_select( tree: Tree, protoype_tree: Tree, map_fun: Callable[[Leaf], Value], reduce_fun: Callable[[Value, Value], Value], null: Any = jax.numpy.nan, ) -> Value: """ Perform a map-reduce operation over a tree, but only matching selected nodes `map_fun()` is a function that will be mapped over the leaves of a tree. It will only be applied to leaf nodes in `tree` when the corresponding leaf node in `prototype_tree` is ``True``. `map_fun()` can return whatever you like; however, you must ensure that the `null` value has the same shape as the value returned by `map_fun()` applied to a leaf node. `reduce_fun()` is a reduction function. This is called using `functools.reduce` over the mapped outputs of `map_fun()`. `reduce_fun(intermediate_value, map_value)` accepts two mapped leaf values, and should combine them into a single result. The first value will be retained as the intermediate computaiton, and passed to the next call to `reduce()`. Examples: >>> map_fun = np.nanmax >>> >>> def reduce_fun(x, y): >>> return np.nanmax(np.array([x, y])) >>> >>> tree_map_reduce_select({'a': 1}, {'a': 1}, map_fun, reduce_fun) 1 >>> def map_fun(leaf): >>> return {'mean': np.mean(leaf), >>> 'std': np.std(leaf), >>> } >>> >>> def reduce_fun(x, y): >>> return {'mean': np.mean(np.array([x['mean'], y['mean']])), >>> 'std': x['std'] + y['std'], >>> } >>> >>> null = map_fun(np.nan) >>> >>> tree_map_reduce_select({'a': 1}, {'a': 1}, map_fun, reduce_fun, null) {'mean': 1, 'std': NaN} Args: tree (Tree): A tree over which to operate prototype_tree (Tree): A prototype tree, with structure matching that of `tree`, but with ``bool`` leaves. Only leaf nodes in `tree` with ``True`` in the corresponding prototype tree node will be modified. map_fun (Callable[[Leaf], Value]): A function to perform on selected nodes in `tree`. `map_fun` has the signature `map_fun(leaf_node) -> value` reduce_fun (Callable[[Value, Value], Value]): A function that collects two values and returns the combination of the two values, to reduce over the mapped function. `reduce_fun` has the signature `reduce_fun(value, value) -> value`. null (Any): The "null" value to return from the map operation, if the leaf node is not selected in `prototype_tree`. Default: ``jax.numpy.nan`` Returns: Value: The result of the map-reduce operation over the tree """ tree_flat, _ = tu.tree_flatten(tree) proto_flat, _ = tu.tree_flatten(protoype_tree) def map_or_null(leaf: Any, select: bool) -> Any: return jax.lax.cond( select, lambda _: map_fun(leaf), lambda _: null, 0, ) # - Map function over leaves mapped = [map_or_null(*xs) for xs in zip(tree_flat, proto_flat)] # - Reduce function over leaves return functools.reduce(reduce_fun, mapped)
[docs]def tree_map_select( tree: Tree, prototype_tree: Tree, map_fun: Callable[[Leaf], Value] ) -> Tree: """ Map a scalar function over a tree, but only matching selected nodes Notes: `map_fun` must be a scalar function. This means that if the input is of shape ``(N, M)``, the output must also be of shape ``(N, M)``. Otherwise you will get an error. Args: tree (Tree): A tree over which to operate prototype_tree (Tree): A prototype tree, with structure matching that of `tree`, but with ``bool`` leaves. Only leaf nodes in `tree` with ``True`` in the corresponding prototype tree node will be modified. map_fun (Callable[[Leaf], Value]): A scalar function to perform on selected nodes in `tree` Returns: Tree: A tree with the same structure as `tree`, with leaf nodes replaced with the output of `map_fun()` for each leaf. """ # - Flatten both trees tree_flat, treedef = tu.tree_flatten(tree) proto_flat, _ = tu.tree_flatten(prototype_tree) # - A function that conditionally maps over the tree leaves def map_or_original(leaf: Any, select: bool) -> Any: return jax.lax.cond( select, lambda _: map_fun(leaf), lambda _: leaf, 0, ) # - Map function over leaves mapped = [map_or_original(*xs) for xs in zip(tree_flat, proto_flat)] # - Return tree return tu.tree_unflatten(treedef, mapped)
[docs]def tree_map_select_with_rng( tree: Tree, prototype_tree: Tree, map_fun: Callable[[Leaf, JaxRNGKey], Value], rng_key: JaxRNGKey, ) -> Tree: """ Map a scalar function over a tree, but only matching selected nodes. Includes jax-compatible random state Notes: `map_fun()` must be a scalar function. This means that if the input is of shape ``(N, M)``, the output must also be of shape ``(N, M)``. Otherwise you will get an error. The signature of `map_fun()` is `map_fun(leaf, rng_key) -> value`. Args: tree (PyTree): A tree over which to operate prototype_tree (PyTree): A prototype tree, with structure matching that of `tree`, but with ``bool`` leaves. Only leaf nodes in `tree` with ``True`` in the corresponding prototype tree node will be modified. map_fun (Callable[[Leaf, JaxRNGKey], Value]): A scalar function to perform on selected nodes in `tree`. The second argument is a jax pRNG key to use when generating random state. """ # - Flatten both trees tree_flat, treedef = tu.tree_flatten(tree) proto_flat, _ = tu.tree_flatten(prototype_tree) # - A function that conditionally maps over the tree leaves def map_or_original(leaf: Any, select: bool, rng_key: Any) -> Any: return jax.lax.cond( select, lambda _: map_fun(leaf, rng_key), lambda _: leaf, 0, ) # - Map function over leaves _, *subkeys = jax.random.split(rng_key, len(tree_flat) + 1) mapped = [map_or_original(*xs) for xs in zip(tree_flat, proto_flat, subkeys)] # - Return tree return tu.tree_unflatten(treedef, mapped)
[docs]def tree_map_with_rng( tree: Tree, map_fun: Callable[[Value, JaxRNGKey, Any], Value], rng_key: JaxRNGKey, *rest: Any, ) -> Tree: """ Perform a multimap over a tree, splitting and inserting an RNG key for each leaf This utility maps a function over the leaves of a tree, when the function requires an RNG key to operate. The utility will automatically split the RNG key to generate a new key for each leaf. Then `map_fun` will be called for each leaf, with the signature ``map_fun(leaf_value, rng_key, *rest)``. `rest` is an optional further series of arguments to map over the tree, such that each additional argument must have the same tree structure as `tree`. See the documentation for `jax.tree_util.tree_map` for more information. Args: tree (Tree): A tree to work over map_fun (Callable[[Value, JaxRNGKey, Any], Value]): A function to map over the tree. The function must have the signature ``map_fun(leaf_value, rng_key, *rest)`` rng_key (JaxRNGKey): An initial RNG key to split *rest: A tuple of additional `tree`-shaped arguments that will be collectively mapped over `tree` when calling `map_fun`. Returns: Tree: The `tree`-shaped result of mapping `map_fun` over `tree`. """ # - Flatten the input tree tree_flat, treedef = tu.tree_flatten(tree) # - Split RNG keys for each tree leaf _, *subkeys = jax.random.split(rng_key, len(tree_flat) + 1) subkeys_tree = tu.tree_unflatten(treedef, subkeys) # - Map function over the tree and return return tu.tree_map(map_fun, tree, subkeys_tree, *rest)
[docs]def tree_find(tree: Tree) -> List: """ Return the tree branches to tree nodes that evaluate to ``True`` Args: tree (Tree): A tree to examine Returns: list: A list of all tree branches, for which the corresponding tree leaf evaluate to ``True`` """ # - Get a list of all tree branches all_branches = list(branches(tree)) # - Return a list of branches to leaves that evaluate to `True` return [branch for branch in all_branches if get_nested(tree, branch)]