Module utilities.jax_tree_utils
Utility functions for working with trees.
Functions overview
- tree_map_reduce_select: Perform a map-reduce operation over a tree, but only matching selected nodes.
- tree_map_select: Map a scalar function over a tree, but only matching selected nodes.
- tree_map_select_with_rng: Map a scalar function over a tree, but only matching selected nodes.
- tree_map_with_rng: Perform a multimap over a tree, splitting and inserting an RNG key for each leaf.
Functions
- utilities.jax_tree_utils.tree_map_reduce_select(tree: Iterable | MutableMapping | Mapping, protoype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any], reduce_fun: Callable[[Any, Any], Any], null: Any = nan) -> Any
Perform a map-reduce operation over a tree, but only matching selected nodes
map_fun() is a function that will be mapped over the leaves of a tree. It is only applied to a leaf node in tree when the corresponding leaf node in prototype_tree is True. map_fun() can return whatever you like; however, you must ensure that the null value has the same shape as the value returned by map_fun() applied to a leaf node.
reduce_fun() is a reduction function. It is called using functools.reduce over the mapped outputs of map_fun(). reduce_fun(intermediate_value, map_value) accepts two mapped leaf values and should combine them into a single result. That result is retained as the intermediate computation and passed as the first argument to the next call to reduce_fun().
Examples
>>> map_fun = np.nanmax
>>> def reduce_fun(x, y):
...     return np.nanmax(np.array([x, y]))
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun)
1

>>> def map_fun(leaf):
...     return {'mean': np.mean(leaf),
...             'std': np.std(leaf),
...             }
>>> def reduce_fun(x, y):
...     return {'mean': np.mean(np.array([x['mean'], y['mean']])),
...             'std': x['std'] + y['std'],
...             }
>>> null = map_fun(np.nan)
>>> tree_map_reduce_select({'a': 1}, {'a': True}, map_fun, reduce_fun, null)
{'mean': 1, 'std': NaN}
- Parameters:
tree (Tree) – A tree over which to operate.
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node are passed to map_fun.
map_fun (Callable[[Leaf], Value]) – A function to apply to selected nodes in tree. map_fun has the signature map_fun(leaf_node) -> value.
reduce_fun (Callable[[Value, Value], Value]) – A function that takes two values and returns their combination; it is used to reduce over the outputs of map_fun. reduce_fun has the signature reduce_fun(value, value) -> value.
null (Any) – The “null” value to return from the map operation if the leaf node is not selected in prototype_tree. Default: jax.numpy.nan.
- Returns:
The result of the map-reduce operation over the tree
- Return type:
Value
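The map-reduce behaviour described above can be sketched in a few lines. This is a hypothetical re-implementation (map_reduce_select_sketch is not part of utilities.jax_tree_utils), assuming that leaves of tree and prototype_tree are paired by flattening both trees in the same order and that each unselected leaf contributes null to the reduction; the real function may differ in details such as exactly how null enters the fold.

```python
import functools
import operator

import jax
import jax.numpy as jnp


def map_reduce_select_sketch(tree, prototype_tree, map_fun, reduce_fun, null=jnp.nan):
    """Hypothetical sketch of a selective map-reduce over a pytree."""
    # Both trees must share a structure so that their leaves line up.
    if jax.tree_util.tree_structure(tree) != jax.tree_util.tree_structure(prototype_tree):
        raise ValueError("tree and prototype_tree must have the same structure")
    leaves = jax.tree_util.tree_leaves(tree)
    flags = jax.tree_util.tree_leaves(prototype_tree)
    # Selected leaves go through map_fun; unselected leaves become `null` (assumed).
    mapped = [map_fun(leaf) if flag else null for leaf, flag in zip(leaves, flags)]
    # Left fold: reduce_fun(intermediate, next_value) -> new intermediate.
    return functools.reduce(reduce_fun, mapped)


tree = {"a": 1.0, "b": 3.0, "c": 5.0}
prototype = {"a": True, "b": False, "c": True}

# Largest doubled value among selected leaves; NaN placeholders are ignored by nanmax.
largest = map_reduce_select_sketch(
    tree, prototype, lambda x: 2 * x, lambda x, y: jnp.nanmax(jnp.array([x, y]))
)
# Sum of doubled selected leaves; null=0.0 so unselected leaves add nothing.
total = map_reduce_select_sketch(tree, prototype, lambda x: 2 * x, operator.add, null=0.0)
```

The pairing of nanmax with a NaN null mirrors the first example above: NaN is neutral for nanmax, whereas a sum needs 0.0 as its neutral null.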
- utilities.jax_tree_utils.tree_map_select(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any], Any]) -> Iterable | MutableMapping | Mapping
Map a scalar function over a tree, but only matching selected nodes
Notes
map_fun must be a scalar function. This means that if the input is of shape (N, M), the output must also be of shape (N, M). Otherwise you will get an error.
- Parameters:
tree (Tree) – A tree over which to operate.
prototype_tree (Tree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf], Value]) – A scalar function to apply to selected nodes in tree.
- Returns:
A tree with the same structure as tree, with each selected leaf node replaced by the output of map_fun().
- Return type:
Tree
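A minimal sketch of the selective, shape-preserving map, assuming unselected leaves are returned unchanged. map_select_sketch is a hypothetical name, and the explicit shape check is only an illustration of the "scalar function" contract in the Notes; the library's actual error type and message are not documented here.

```python
import jax
import jax.numpy as jnp


def map_select_sketch(tree, prototype_tree, map_fun):
    """Hypothetical sketch: apply map_fun only where prototype_tree is True."""

    def apply(leaf, selected):
        if not selected:
            return leaf  # Unselected leaves pass through unchanged.
        out = map_fun(leaf)
        # Illustrative check of the "scalar function" contract (assumed, not the
        # library's actual error): the output must keep the input's shape.
        if jnp.shape(out) != jnp.shape(leaf):
            raise ValueError(f"map_fun changed shape {jnp.shape(leaf)} -> {jnp.shape(out)}")
        return out

    # tree_map walks `tree` and `prototype_tree` in lockstep, leaf by leaf.
    return jax.tree_util.tree_map(apply, tree, prototype_tree)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
mask = {"w": True, "b": False}

scaled = map_select_sketch(params, mask, lambda x: 10.0 * x)  # only "w" is scaled

error_message = None
try:
    map_select_sketch(params, mask, jnp.sum)  # (2, 3) -> (): not shape-preserving
except ValueError as err:
    error_message = str(err)
```

Elementwise functions such as lambda x: 10.0 * x satisfy the contract, while reductions such as jnp.sum do not.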
- utilities.jax_tree_utils.tree_map_select_with_rng(tree: Iterable | MutableMapping | Mapping, prototype_tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any], Any], rng_key: Any) -> Iterable | MutableMapping | Mapping
Map a scalar function over a tree, but only matching selected nodes. Includes JAX-compatible random state.
Notes
map_fun() must be a scalar function. This means that if the input is of shape (N, M), the output must also be of shape (N, M). Otherwise you will get an error. The signature of map_fun() is map_fun(leaf, rng_key) -> value.
- Parameters:
tree (PyTree) – A tree over which to operate.
prototype_tree (PyTree) – A prototype tree, with structure matching that of tree, but with bool leaves. Only leaf nodes in tree with True in the corresponding prototype tree node will be modified.
map_fun (Callable[[Leaf, JaxRNGKey], Value]) – A scalar function to apply to selected nodes in tree. The second argument is a JAX PRNG key to use when generating random state.
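A sketch of how a per-leaf PRNG key can be combined with the selective map. The name map_select_with_rng_sketch is hypothetical, and the key-splitting scheme (one jax.random.split call yielding one subkey per leaf, in flattening order, including unselected leaves) is an assumption, so the random values will not necessarily match those produced by tree_map_select_with_rng.

```python
import jax
import jax.numpy as jnp


def map_select_with_rng_sketch(tree, prototype_tree, map_fun, rng_key):
    """Hypothetical sketch: give each leaf its own PRNG key, map only selected leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    flags = jax.tree_util.tree_leaves(prototype_tree)  # same order as `leaves`
    # Split once so every leaf receives an independent subkey (assumed scheme).
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [
        map_fun(leaf, key) if flag else leaf
        for leaf, flag, key in zip(leaves, flags, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


def add_noise(leaf, key):
    # Shape-preserving ("scalar") function that consumes the PRNG key.
    return leaf + jax.random.normal(key, leaf.shape)


params = {"w": jnp.zeros(4), "b": jnp.zeros(2)}
mask = {"w": True, "b": False}
key = jax.random.PRNGKey(0)

noisy = map_select_with_rng_sketch(params, mask, add_noise, key)
noisy_again = map_select_with_rng_sketch(params, mask, add_noise, key)  # same key, same result
```

Because JAX PRNG keys are explicit values, calling the sketch twice with the same rng_key reproduces the same noise.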
- utilities.jax_tree_utils.tree_map_with_rng(tree: Iterable | MutableMapping | Mapping, map_fun: Callable[[Any, Any, Any], Any], rng_key: Any, *rest: Any) -> Iterable | MutableMapping | Mapping
Perform a multimap over a tree, splitting and inserting an RNG key for each leaf
This utility maps a function over the leaves of a tree when the function requires an RNG key to operate. The utility automatically splits the RNG key to generate a new key for each leaf. map_fun is then called for each leaf with the signature map_fun(leaf_value, rng_key, *rest). rest is an optional further series of arguments to map over the tree; each additional argument must have the same tree structure as tree. See the documentation for jax.tree_util.tree_map for more information.
- Parameters:
tree (Tree) – A tree to work over.
map_fun (Callable[[Value, JaxRNGKey, Any], Value]) – A function to map over the tree. The function must have the signature map_fun(leaf_value, rng_key, *rest).
rng_key (JaxRNGKey) – An initial RNG key to split.
*rest – A tuple of additional tree-shaped arguments that will be collectively mapped over tree when calling map_fun.
- Returns:
The tree-shaped result of mapping map_fun over tree.
- Return type:
Tree
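The split-then-multimap procedure described above can be sketched with jax.tree_util.tree_map. map_with_rng_sketch is a hypothetical name, and the splitting scheme (one jax.random.split call producing one subkey per leaf, arranged in a tree shaped like tree) is an assumption; the library may split keys differently, so the random values will not necessarily match tree_map_with_rng.

```python
import jax
import jax.numpy as jnp


def map_with_rng_sketch(tree, map_fun, rng_key, *rest):
    """Hypothetical sketch: split rng_key into one key per leaf, then multimap."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    keys = jax.random.split(rng_key, len(leaves))
    # Arrange the subkeys in a tree with the same structure as `tree`.
    key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
    # tree_map walks `tree`, `key_tree` and every tree in `rest` in lockstep,
    # calling map_fun(leaf_value, rng_key, *rest_leaves) for each leaf.
    return jax.tree_util.tree_map(map_fun, tree, key_tree, *rest)


def perturb(leaf, key, scale):
    # `scale` comes from the extra tree passed through *rest.
    return leaf + scale * jax.random.normal(key, leaf.shape)


params = {"w": jnp.zeros(3), "b": jnp.zeros(2)}
scales = {"w": 1.0, "b": 0.0}  # same structure as `params`

perturbed = map_with_rng_sketch(params, perturb, jax.random.PRNGKey(42), scales)
```

Here scales plays the role of a *rest argument: it shares the structure of params, so each call to perturb receives the matching per-leaf scale.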