Module utilities.jax_tree_utils
Utility functions for working with trees.
Functions overview
- tree_map_reduce_select – Perform a map-reduce operation over a tree, but only matching selected nodes
- tree_map_select – Map a scalar function over a tree, but only matching selected nodes
- tree_map_select_with_rng – Map a scalar function over a tree, but only matching selected nodes. Includes JAX-compatible random state
- tree_map_with_rng – Perform a multimap over a tree, splitting and inserting an RNG key for each leaf
Functions
- utilities.jax_tree_utils.tree_map_reduce_select(tree: Iterable | MutableMapping | Mapping, protoype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any], reduce_fun: Callable[[Any, Any], Any], null: Any = nan) -> Any
Perform a map-reduce operation over a tree, but only matching selected nodes
map_fun() is a function that will be mapped over the leaves of a tree. It is only applied to a leaf node in tree when the corresponding leaf node in prototype_tree is True; unselected leaves contribute the null value instead. map_fun() can return whatever you like; however, you must ensure that the null value has the same shape as the value returned by map_fun() applied to a leaf node.

reduce_fun() is a reduction function. It is called using functools.reduce over the mapped outputs of map_fun(). reduce_fun(intermediate_value, map_value) accepts two mapped leaf values and should combine them into a single result. That result is retained as the intermediate computation and passed as the first argument to the next call to reduce_fun().
Examples
>>> import numpy as np
>>> from utilities.jax_tree_utils import tree_map_reduce_select
>>> map_fun = np.nanmax
>>> def reduce_fun(x, y):
...     return np.nanmax(np.array([x, y]))
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun)
1
>>> def map_fun(leaf):
...     return {'mean': np.mean(leaf),
...             'std': np.std(leaf),
...             }
>>> def reduce_fun(x, y):
...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
...             'std': x['std'] + y['std'],
...             }
>>> null = map_fun(np.nan)
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun, null)
{'mean': 1, 'std': NaN}
- Parameters:
tree (Tree) – A tree over which to operate
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node are passed to map_fun.
map_fun (Callable[[Leaf], Value]) – A function to perform on selected nodes in tree. map_fun has the signature map_fun(leaf_node) -> value.
reduce_fun (Callable[[Value, Value], Value]) – A function that takes two values and returns their combination; it is used to reduce over the outputs of map_fun. reduce_fun has the signature reduce_fun(value, value) -> value.
null (Any) – The “null” value returned by the map operation when the leaf node is not selected in prototype_tree. Default: jax.numpy.nan
- Returns:
The result of the map-reduce operation over the tree
- Return type:
Value
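The map-reduce behaviour described above can be sketched with jax.tree_util.tree_leaves and functools.reduce. This is a hypothetical re-implementation (map_reduce_select_sketch is not part of this module), assuming, as the description of null suggests, that every unselected leaf contributes null to the reduction. The real function's leaf ordering and edge-case handling may differ.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def map_reduce_select_sketch(tree, prototype_tree, map_fun, reduce_fun, null=jnp.nan):
    """Hypothetical sketch of the documented semantics, not the library's code."""
    # Flatten both trees; matching structures give matching leaf orders.
    leaves = jax.tree_util.tree_leaves(tree)
    selected = jax.tree_util.tree_leaves(prototype_tree)
    # Selected leaves go through map_fun; unselected ones contribute `null`.
    mapped = [map_fun(leaf) if sel else null for leaf, sel in zip(leaves, selected)]
    # Combine the mapped values pairwise, left to right.
    return functools.reduce(reduce_fun, mapped)


tree = {"a": np.array([1.0, 5.0]), "b": np.array([9.0]), "c": np.array([3.0])}
proto = {"a": True, "b": False, "c": True}
result = map_reduce_select_sketch(
    tree, proto, np.nanmax, lambda x, y: np.nanmax(np.array([x, y]))
)
# "b" is not selected, so it contributes nan, which nanmax ignores.
```

Here the mapped values are [5.0, nan, 3.0], so the reduction yields 5.0 even though the unselected leaf "b" holds the largest value. Pairing a nan null with nan-aware reductions is what lets unselected leaves drop out.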
- utilities.jax_tree_utils.tree_map_select(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any]) -> Iterable | MutableMapping | Mapping
Map a scalar function over a tree, but only matching selected nodes
Notes
map_fun must be a scalar function, meaning it must preserve the shape of its input: if the input is of shape (N, M), the output must also be of shape (N, M). Otherwise you will get an error.
- Parameters:
tree (Tree) – A tree over which to operate
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf], Value]) – A scalar function to perform on selected nodes in tree.
- Returns:
A tree with the same structure as tree, with each selected leaf node replaced by the output of map_fun() for that leaf.
- Return type:
Tree
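A minimal sketch of the selection logic, assuming that unselected leaves are returned unchanged, as "will be modified" above implies. map_select_sketch is a hypothetical name, not this module's implementation. It walks tree and prototype_tree together with jax.tree_util.tree_map and branches on each bool leaf in Python, which works outside jit because the prototype leaves are plain Python bools.

```python
import jax
import jax.numpy as jnp


def map_select_sketch(tree, prototype_tree, map_fun):
    """Hypothetical sketch: apply map_fun to selected leaves, keep the rest as-is."""
    return jax.tree_util.tree_map(
        lambda leaf, sel: map_fun(leaf) if sel else leaf,
        tree,
        prototype_tree,
    )


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
mask = {"w": True, "b": False}
# An elementwise (shape-preserving) function, as the Notes above require.
out = map_select_sketch(params, mask, lambda x: 2.0 * x)
```

Only "w" is doubled; "b" passes through untouched, and both leaves keep their original shapes.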
- utilities.jax_tree_utils.tree_map_select_with_rng(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any], Any], rng_key: Any) -> Iterable | MutableMapping | Mapping
Map a scalar function over a tree, but only matching selected nodes. Includes JAX-compatible random state.
Notes
map_fun() must be a scalar function, meaning it must preserve the shape of its input: if the input is of shape (N, M), the output must also be of shape (N, M). Otherwise you will get an error. The signature of map_fun() is map_fun(leaf, rng_key) -> value.
- Parameters:
tree (PyTree) – A tree over which to operate
prototype_tree (PyTree) – A prototype tree, with structure matching that of
tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf, JaxRNGKey], Value]) – A scalar function to perform on selected nodes in tree. The second argument is a JAX PRNG key to use when generating random state.
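The following sketch assumes the key is split once per leaf with jax.random.split so that no key is reused, and that unselected leaves are returned unchanged; the entry above does not state how rng_key is consumed, so treat this as an illustration rather than the module's behaviour. map_select_with_rng_sketch and add_noise are hypothetical names.

```python
import jax
import jax.numpy as jnp


def map_select_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Hypothetical sketch: one fresh subkey per leaf, map_fun only on selected leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    selected = jax.tree_util.tree_leaves(prototype_tree)
    # Assumption: split the key into one subkey per leaf.
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [
        map_fun(leaf, key) if sel else leaf
        for leaf, sel, key in zip(leaves, selected, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


def add_noise(leaf, key):
    # Shape-preserving: Gaussian noise with the same shape as the leaf.
    return leaf + 0.1 * jax.random.normal(key, leaf.shape)


params = {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}
mask = {"w": True, "b": False}
noisy = map_select_with_rng_sketch(params, mask, add_noise, jax.random.PRNGKey(0))
```

Only "w" receives noise. Splitting per leaf, rather than passing the same key to every call, keeps the noise on different leaves statistically independent.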
- utilities.jax_tree_utils.tree_map_with_rng(tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any, Any], Any], rng_key: Any, *rest: Any) -> Iterable | MutableMapping | Mapping
Perform a multimap over a tree, splitting and inserting an RNG key for each leaf
This utility maps a function over the leaves of a tree when the function requires an RNG key to operate. The utility automatically splits the RNG key to generate a new key for each leaf, then calls map_fun for each leaf with the signature map_fun(leaf_value, rng_key, *rest). rest is an optional further series of arguments to map over the tree; each additional argument must have the same tree structure as tree. See the documentation for jax.tree_util.tree_map for more information.
- Parameters:
tree (Tree) – A tree to work over
map_fun (Callable[[Value, JaxRNGKey, Any], Value]) – A function to map over the tree. The function must have the signature map_fun(leaf_value, rng_key, *rest).
rng_key (JaxRNGKey) – An initial RNG key to split.
*rest – A tuple of additional tree-shaped arguments that will be collectively mapped over tree when calling map_fun.
- Returns:
The tree-shaped result of mapping map_fun over tree.
- Return type:
Tree
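The split-then-multimap behaviour described above can be sketched by building a tree of subkeys with the same structure as tree and handing it to jax.tree_util.tree_map alongside *rest. map_with_rng_sketch and the dropout-style dropout function are hypothetical illustrations, not this module's code, and the real function may split the key differently.

```python
import jax
import jax.numpy as jnp


def map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Hypothetical sketch: split rng_key per leaf, then multimap with *rest."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    keys = jax.random.split(rng_key, len(leaves))
    # A tree of subkeys with the same structure as `tree`.
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    # Calls map_fun(leaf_value, rng_key, *rest_leaves) for every leaf.
    return jax.tree_util.tree_map(map_fun, tree, key_tree, *rest)


def dropout(leaf, key, rate):
    # Keep each entry with probability 1 - rate and rescale the survivors.
    keep = jax.random.bernoulli(key, 1.0 - rate, leaf.shape)
    return jnp.where(keep, leaf / (1.0 - rate), 0.0)


params = {"w": jnp.ones((4, 4)), "b": jnp.ones(4)}
rates = {"w": 0.5, "b": 0.0}  # a per-leaf argument passed through *rest
out = map_with_rng_sketch(params, dropout, jax.random.PRNGKey(42), rates)
```

Because rates has the same structure as params, each leaf receives its own rate as the third argument. With rate 0.0, "b" is kept intact, while every entry of "w" becomes either 0.0 or 2.0.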