# Source code for utilities.tree_utils

"""
Tree manipulation utilities with no external dependencies

This module provides functions for building and manipulating trees.

A `Tree` is a nested dictionary. A `Leaf` is any non-dictionary value stored in a `Tree`. A `Node` is a non-leaf `Tree` node, i.e. a nested dictionary. A `TreeDef` is a nested dictionary with the same structure as a `Tree` but no data: its leaves are ``None``.
"""

import copy

from warnings import warn

from rockpool.typehints import Tree, Leaf, Node, TreeDef
from typing import Tuple, Any, Dict, Callable, Union, List, Optional, Generator

# - Public API of this module: the names exported by ``from <module> import *``
__all__ = [
    "branches",
    "get_nested",
    "set_nested",
    "set_matching",
    "set_matching_select",
    "make_prototype_tree",
    "tree_map",
    "tree_flatten",
    "tree_unflatten",
    "tree_update",
    "tree_find",
]


def branches(tree: Tree, prefix: list = None) -> Generator[Tuple, None, None]:
    """
    Generate every branch of a tree, i.e. each path from the root node to a leaf

    Branches are produced depth-first, following the insertion order of each
    dictionary. An empty sub-dictionary contains no leaves, so it produces no
    branches.

    Examples:
        >>> list(branches({'a': 1, 'b': {'c': 2}}))
        [('a',), ('b', 'c')]

    Args:
        tree (Tree): A nested dictionary tree structure
        prefix (list): Keys already walked to reach ``tree``; they are prepended to every yielded branch. The list is not modified. Default: ``None`` (``tree`` is the root)

    Yields:
        Tuple: The keys leading from the root to one leaf
    """
    # - No prefix means we are starting from the root node
    path = [] if prefix is None else prefix

    for key, child in tree.items():
        # - Build a fresh list, so the caller's prefix is never mutated
        branch = path + [key]

        # - A non-dictionary value is a leaf, which terminates the branch
        if not isinstance(child, dict):
            yield tuple(branch)
            continue

        # - A nested dictionary is a sub-tree: descend, extending the path
        yield from branches(child, branch)
def get_nested(tree: Tree, branch: Tuple) -> Any:
    """
    Get a value from a tree, specifying a branch

    Args:
        tree (Tree): A nested dictionary tree structure
        branch (Tuple): A branch: a tuple of keys to walk through the tree, starting from the root

    Returns:
        Any: The value stored at the end of ``branch``. This is a leaf if ``branch`` ends at a leaf, or a sub-tree if it ends at a node

    Raises:
        KeyError: If a key along ``branch`` is not present in the tree
        IndexError: If ``branch`` is empty
    """
    # - Start from the root node
    node = tree

    # - Walk down to the parent of the requested entry
    for key in branch[:-1]:
        node = node[key]

    # - Fetch the requested entry from its parent
    return node[branch[-1]]
def set_nested(tree: Tree, branch: Tuple, value: Any, inplace: bool = False) -> Tree:
    """
    Set a value in a tree, specifying a branch

    Every node on the path to the final key of ``branch`` must already exist. The
    final key itself is assigned directly: an existing entry is overwritten, and
    a missing one is created.

    Args:
        tree (Tree): A nested dictionary tree structure
        branch (Tuple): A branch: a tuple of keys to walk through the tree, starting from the root
        value (Any): The value to store at the end of ``branch``. It is stored as-is, not copied
        inplace (bool): If ``False`` (default), a deep copy of ``tree`` is modified and returned. If ``True``, ``tree`` itself is modified and returned

    Returns:
        Tree: The modified tree
    """
    result = tree if inplace else copy.deepcopy(tree)

    # - Separate the path to the parent node from the final key
    parent_path, final_key = branch[:-1], branch[-1]

    parent = result
    for key in parent_path:
        parent = parent[key]

    parent[final_key] = value
    return result
def set_matching(
    full_tree: Tree, target_tree: Tree, value: Any, inplace: bool = False
) -> Tree:
    """
    Set the values in a full tree, for every branch that appears in a target tree

    Each leaf branch of ``target_tree`` is located in ``full_tree`` and its value
    is replaced with ``value``. The leaf values of ``target_tree`` are ignored;
    only its structure matters. The same ``value`` object is stored at every
    matching branch; it is not copied.

    Args:
        full_tree (Tree): The tree whose values will be replaced with ``value``
        target_tree (Tree): A tree defining the branches to set in ``full_tree``
        value (Any): The value to store at each matching branch of ``full_tree``
        inplace (bool): If ``False`` (default), a deep copy of ``full_tree`` is modified and returned. If ``True``, ``full_tree`` itself is modified and returned

    Returns:
        Tree: The modified tree
    """
    result = full_tree if inplace else copy.deepcopy(full_tree)

    # - Each leaf branch of the target selects one location to overwrite
    for target_branch in branches(target_tree):
        set_nested(result, target_branch, value, inplace=True)

    return result
def set_matching_select(
    full_tree: Tree, target_tree: Tree, value: Any, inplace: bool = False
) -> Tree:
    """
    Set the values in a full tree, for target-tree branches whose leaves evaluate to ``True``

    Like :py:func:`set_matching`, but a branch of ``target_tree`` is only applied
    when its leaf value is truthy. Branches with falsy leaves leave
    ``full_tree`` untouched.

    Args:
        full_tree (Tree): The tree whose values will be replaced with ``value``
        target_tree (Tree): A tree defining the candidate branches to set in ``full_tree``. A branch is used only if its leaf in ``target_tree`` evaluates to ``True``
        value (Any): The value to store at each selected branch of ``full_tree``
        inplace (bool): If ``False`` (default), a deep copy of ``full_tree`` is modified and returned. If ``True``, ``full_tree`` itself is modified and returned

    Returns:
        Tree: The modified tree
    """
    result = full_tree if inplace else copy.deepcopy(full_tree)

    # - Lazily filter the target branches, keeping only those with truthy leaves
    selected = (
        candidate
        for candidate in branches(target_tree)
        if get_nested(target_tree, candidate)
    )

    for target_branch in selected:
        set_nested(result, target_branch, value, inplace=True)

    return result
def make_prototype_tree(full_tree: Tree, target_tree: Tree) -> Tree:
    """
    Construct a tree with boolean leaves, for nodes that match a target tree

    Make a prototype tree, indicating which nodes in a large tree should be
    selected for analysis or processing. This is done on the basis of a smaller
    "target" tree, which contains only the leaf nodes of interest.

    The leaf values of ``full_tree`` are never copied or inspected; only its
    structure is used.

    Examples:
        >>> target_tree = {'a': 0, 'b': {'b2': 0}}
        >>> full_tree = {'a': 1, 'b': {'b1': 2, 'b2': 3}, 'c': 4, 'd': 5}
        >>> make_prototype_tree(full_tree, target_tree)
        {'a': True, 'b': {'b1': False, 'b2': True}, 'c': False, 'd': False}

    Args:
        full_tree (Tree): A large tree to search through.
        target_tree (Tree): A tree with only a few leaf nodes. These nodes will be identified within the full tree. Branches of ``target_tree`` that do not end at a leaf of ``full_tree`` are ignored.

    Returns:
        Tree: A new nested tree with the same structure as ``full_tree``, but with ``bool`` leaf nodes. Leaf nodes are ``True`` for branches matching those specified in ``target_tree``, and ``False`` otherwise.
    """
    # - Collect the target branches once, as a set for O(1) membership tests
    targets = set(branches(target_tree))

    # - Sanity check the trees
    n_full_leaves = sum(1 for _ in branches(full_tree))
    if n_full_leaves < len(targets):
        warn(
            SyntaxWarning(
                "make_prototype_tree: The `target` tree has more nodes than the `full` tree. Please check the order of arguments."
            )
        )

    def _mark(node: Tree, prefix: Tuple) -> Tree:
        # - Rebuild `node`, replacing each leaf by whether its branch is a target
        marked = {}
        for key, value in node.items():
            branch = prefix + (key,)
            if isinstance(value, dict):
                marked[key] = _mark(value, branch)
            else:
                marked[key] = branch in targets
        return marked

    # - Build the prototype directly, instead of deep-copying (possibly large)
    #   leaf values that would only be overwritten
    return _mark(full_tree, ())
def tree_map(tree: Tree, f: Callable) -> Tree:
    """
    Map a function over the leaves of a tree

    The tree is traversed recursively, depth-first, in dictionary insertion
    order, and ``f`` is called once per leaf in that order. A new tree is
    returned; ``tree`` is not modified.

    Args:
        tree (Tree): A tree to traverse
        f (Callable): A function called on each leaf of the tree, with signature ``Callable[[Leaf], Any]``

    Returns:
        Tree: A new tree with the same structure as ``tree``, where each leaf is replaced by the result of calling ``f`` on it
    """
    # - Nested dictionaries are rebuilt recursively; any other value is a leaf
    return {
        key: tree_map(child, f) if isinstance(child, dict) else f(child)
        for key, child in tree.items()
    }
def tree_flatten(
    tree: Tree, leaves: Union[List[Any], None] = None
) -> Tuple[List[Any], TreeDef]:
    """
    Flatten a tree into a linear list of leaf nodes and a tree description

    This function operates similarly to ``jax.tree_util.tree_flatten``, but is
    *not* directly compatible with it.

    The `Tree` ``tree`` is serialised depth-first, in dictionary insertion
    order, into a simple list of leaf nodes that can be manipulated
    conveniently. A `TreeDef` is also returned: a nested dictionary with the
    same structure as ``tree``, whose leaves are all ``None``. The function
    :py:func:`.tree_unflatten` performs the reverse operation.

    Args:
        tree (Tree): A tree to flatten
        leaves (Optional[List[Any]]): Used recursively. Should be left as ``None`` by the user. If a list is supplied, leaves are appended to it and the same list is returned

    Returns:
        Tuple[List[Any], TreeDef]: A list of leaf nodes from the flattened tree, and a tree definition
    """
    # - Only the root call creates the shared leaf accumulator
    collected = leaves if leaves is not None else []
    structure = {}

    for key, value in tree.items():
        # - Leaves are recorded in order; the treedef only marks their position
        if not isinstance(value, dict):
            collected.append(value)
            structure[key] = None
            continue

        # - Sub-trees append to the same accumulator; keep only their treedef
        structure[key] = tree_flatten(value, collected)[1]

    return collected, structure
def tree_unflatten(
    treedef: TreeDef, leaves: List, leaves_tail: Optional[List[Any]] = None
) -> Tree:
    """
    Build a tree from a flat list of leaves, plus a tree definition

    This function takes a flattened tree representation, as built by
    :py:func:`.tree_flatten`, and reconstructs a matching `Tree` structure.
    Leaves are consumed in order, matching the depth-first, insertion-ordered
    traversal of ``treedef``. Surplus leaves are ignored.

    Args:
        treedef (TreeDef): A tree definition as returned by :py:func:`.tree_flatten`
        leaves (Sequence[Any]): The leaf nodes to use in constructing the tree. They are deep-copied, and ``leaves`` itself is not modified
        leaves_tail (Optional[List[Any]]): Used recursively. Should be left as ``None`` by the end user. If supplied, leaves are taken from this list instead of ``leaves``, and the consumed items are removed from its front

    Returns:
        Tree: The reconstructed tree, with leaves taken from ``leaves``

    Raises:
        IndexError: If there are fewer leaves than ``treedef`` has leaf positions
    """
    # - Work on a private deep copy of the leaves, unless a tail list was supplied.
    #   `list()` first, so any sequence (e.g. a tuple) is accepted.
    if leaves_tail is None:
        leaves_tail = copy.deepcopy(list(leaves))

    # - Index of the next unconsumed leaf; an index avoids the O(n^2) cost of
    #   repeatedly calling `pop(0)` on a list
    consumed = 0

    def _fill(node: TreeDef) -> Tree:
        # - Build a fresh node mirroring `node`, taking leaves in order from `leaves_tail`
        nonlocal consumed
        filled = {}
        for key, sub in node.items():
            if isinstance(sub, dict):
                filled[key] = _fill(sub)
            else:
                if consumed >= len(leaves_tail):
                    raise IndexError(
                        "tree_unflatten: `leaves` has fewer entries than `treedef` has leaves"
                    )
                filled[key] = leaves_tail[consumed]
                consumed += 1
        return filled

    tree = _fill(treedef)

    # - Remove consumed leaves from the front, as the former pop-based version
    #   did, so a caller-supplied `leaves_tail` is left in the same state
    del leaves_tail[:consumed]
    return tree
def tree_update(target: Tree, additional: Tree, inplace: bool = False) -> Tree:
    """
    Perform a recursive update of a tree to insert or replace nodes from a second tree

    Requires a ``target`` `Tree` and a source `Tree` ``additional``, which
    provides the data to merge into ``target``. Both trees are traversed
    depth-first simultaneously:

    - Entries that exist in ``target`` but not in ``additional`` are not modified.
    - Leaves of ``additional`` are assigned into ``target`` at the corresponding location, inserting or overwriting as with ``dict.update``.
    - Sub-trees of ``additional`` are merged recursively into the matching sub-tree of ``target``. Where ``target`` has no entry, or only a leaf, at that location, a new sub-tree is created there.

    The dictionary structure taken from ``additional`` is always rebuilt, never
    shared, so later in-place edits of the result cannot modify ``additional``.
    Leaf values themselves are stored by reference, not copied.

    Args:
        target (Tree): The tree to update.
        additional (Tree): The source tree to insert / replace nodes from, into ``target``. Will not be modified.
        inplace (bool): If ``False`` (default), a deep copy of ``target`` is updated and returned. If ``True``, ``target`` itself is updated and returned.

    Returns:
        Tree: The modified target tree
    """
    if not inplace:
        target = copy.deepcopy(target)

    for key, value in additional.items():
        if isinstance(value, dict):
            # - Make sure there is a sub-tree to merge into. This replaces a
            #   leaf with a sub-tree, rather than recursing into the leaf, and
            #   avoids aliasing `additional`'s dictionaries into `target`.
            if not isinstance(target.get(key), dict):
                target[key] = {}
            tree_update(target[key], value, inplace=True)
        else:
            # - Leaves are inserted or overwritten directly
            target[key] = value

    return target
def tree_find(tree: Tree) -> Generator[Tuple, None, None]:
    """
    Generate the branches of a tree whose leaves evaluate to ``True``

    Branches are produced lazily, in the same depth-first order as
    :py:func:`branches`.

    Args:
        tree (Tree): A tree to examine

    Yields:
        Tuple: Each branch whose corresponding leaf evaluates to ``True``
    """
    # - Keep only branches whose leaf value is truthy
    yield from (
        path for path in branches(tree) if get_nested(tree, path)
    )