Simplified PyTree Utilities

Utilities for working with PyTrees.

The optree.pytree namespace contains aliases of optree.tree_* utilities.

>>> import optree.pytree as pytree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = pytree.flatten(tree)
>>> leaves, treespec
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree == pytree.unflatten(treespec, leaves)
True

Added in version 0.14.1.

Re-export PyTree Utilities as a Sub-module

reexport(*, namespace[, module])

Re-export a pytree utility module with the given namespace as default.

optree.pytree.reexport(*, namespace, module=None)[source]

Re-export a pytree utility module with the given namespace as default.

>>> import optree
>>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
>>> pytree.flatten({'a': 1, 'b': 2})
([1, 2], PyTreeSpec({'a': *, 'b': *}))

This function is useful for downstream libraries that want to re-export the pytree utilities with their own namespace:

# foo/__init__.py
import optree
pytree = optree.pytree.reexport(namespace='foo')
del optree

# foo/bar.py
from foo import pytree

@pytree.dataclasses.dataclass
class Bar:
    a: int
    b: float

# User code
In [1]: import foo

In [2]: foo.pytree.flatten({'a': 1, 'b': 2, 'c': foo.bar.Bar(3, 4.0)})
Out[2]:
(
    [1, 2, 3, 4.0],
    PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo')
)

In [3]: foo.pytree.functools.reduce(lambda x, y: x * y, {'a': 1, 'b': 2, 'c': foo.bar.Bar(3, 4.0)})
Out[3]: 24.0

Added in version 0.16.0.

Parameters:
  • namespace (str) – The namespace to use in the re-exported module.

  • module (str, optional) – The name of the re-exported module. If not provided, defaults to <caller_module>.pytree. The caller module is determined by inspecting the stack frame.

Return type:

ModuleType

Returns:

The re-exported module.
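
The default for module is described above as being derived from the caller's stack frame. A minimal sketch of that idea follows; it is not optree's actual implementation. The helper name default_reexport_name is hypothetical, and sys._getframe is a CPython-specific call.

```python
import sys


def default_reexport_name(suffix='pytree'):
    """Hypothetical helper: build '<caller_module>.<suffix>' from the caller's frame."""
    # Frame 0 is this function; frame 1 is whoever called it.
    caller_frame = sys._getframe(1)
    # Every module's globals carry its dotted name under '__name__'.
    caller_module = caller_frame.f_globals.get('__name__', '__main__')
    return f'{caller_module}.{suffix}'


# Called at module level, this yields '<this module's name>.pytree'.
name = default_reexport_name()
```

Passing module explicitly, as in the first example above, avoids relying on frame inspection.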

Tree Operations

See the sections Tree Manipulation Functions and Tree Reduce Functions for more detailed documentation.

dict_insertion_ordered(mode, /, *, namespace)

Context manager to temporarily set the dictionary sorting mode.

flatten(tree, /[, is_leaf, none_is_leaf, ...])

Flatten a pytree.

flatten_with_path(tree, /[, is_leaf, ...])

Flatten a pytree and additionally record the paths.

flatten_with_accessor(tree, /[, is_leaf, ...])

Flatten a pytree and additionally record the accessors.

unflatten(treespec, leaves)

Reconstruct a pytree from the treespec and the leaves.

iter(tree, /[, is_leaf, none_is_leaf, namespace])

Get an iterator over the leaves of a pytree.

leaves(tree, /[, is_leaf, none_is_leaf, ...])

Get the leaves of a pytree.

structure(tree, /[, is_leaf, none_is_leaf, ...])

Get the treespec for a pytree.

paths(tree, /[, is_leaf, none_is_leaf, ...])

Get the path entries to the leaves of a pytree.

accessors(tree, /[, is_leaf, none_is_leaf, ...])

Get the accessors to the leaves of a pytree.

is_leaf(tree, /[, is_leaf, none_is_leaf, ...])

Test whether the given object is a leaf node.

map(func, tree, /, *rests[, is_leaf, ...])

Map a multi-input function over pytree args to produce a new pytree.

map_(func, tree, /, *rests[, is_leaf, ...])

Like tree_map(), but do an in-place call on each leaf and return the original tree.

map_with_path(func, tree, /, *rests[, ...])

Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.

map_with_path_(func, tree, /, *rests[, ...])

Like tree_map_with_path(), but do an in-place call on each leaf and return the original tree.

map_with_accessor(func, tree, /, *rests[, ...])

Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree.

map_with_accessor_(func, tree, /, *rests[, ...])

Like tree_map_with_accessor(), but do an in-place call on each leaf and return the original tree.

replace_nones(sentinel, tree, /[, namespace])

Replace None in tree with sentinel.

partition(predicate, tree, /[, is_leaf, ...])

Partition a tree into left and right parts by the given predicate function.

transpose(outer_treespec, inner_treespec, ...)

Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

transpose_map(func, tree, /, *rests[, ...])

Map a multi-input function over pytree args to produce a new pytree with transposed structure.

transpose_map_with_path(func, tree, /, *rests)

Map a multi-input function over pytree args as well as the tree paths to produce a new pytree with transposed structure.

transpose_map_with_accessor(func, tree, /, ...)

Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree with transposed structure.

broadcast_prefix(prefix_tree, full_tree, /)

Return a pytree with the same structure as full_tree, with broadcasted subtrees from prefix_tree.

broadcast_common(tree, other_tree, /[, ...])

Return two pytrees sharing the common suffix structure of tree and other_tree, with broadcasted subtrees.

broadcast_map(func, tree, /, *rests[, ...])

Map a multi-input function over pytree args to produce a new pytree.

broadcast_map_with_path(func, tree, /, *rests)

Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.

broadcast_map_with_accessor(func, tree, /, ...)

Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree.

reduce(func, tree, /[, initial, is_leaf, ...])

Traverse a pytree and reduce the leaves in left-to-right depth-first order.

sum(tree, /[, start, is_leaf, none_is_leaf, ...])

Sum start and leaf values in tree in left-to-right depth-first order and return the total.

max(tree, /, *[, default, key, is_leaf, ...])

Return the maximum leaf value in tree.

min(tree, /, *[, default, key, is_leaf, ...])

Return the minimum leaf value in tree.

all(tree, /, *[, is_leaf, none_is_leaf, ...])

Test whether all leaves in tree are true (or if tree is empty).

any(tree, /, *[, is_leaf, none_is_leaf, ...])

Test whether any leaf in tree is true (or False if tree is empty).

flatten_one_level(tree, /[, is_leaf, ...])

Flatten the pytree one level, returning a 4-tuple of children, metadata, path entries, and an unflatten function.
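
The traversal order and the empty-tree behavior of the reduce functions listed above can be illustrated with a small pure-Python sketch. It is a toy model, not optree's implementation: the helper sketch_leaves is hypothetical, handles only dicts, lists, and tuples, visits dict keys in sorted order, and treats None as a node without leaves, matching the flatten example at the top of this page.

```python
import operator
from functools import reduce


def sketch_leaves(tree):
    """Toy depth-first, left-to-right leaf traversal (hypothetical helper)."""
    if tree is None:
        return []  # None contributes no leaves in this sketch
    if isinstance(tree, dict):
        # Visit dict values in sorted-key order.
        return [leaf for key in sorted(tree) for leaf in sketch_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in sketch_leaves(child)]
    return [tree]  # anything else is a leaf


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
leaves = sketch_leaves(tree)            # [1, 2, 3, 4, 5]
total = reduce(operator.add, leaves)    # 15
product = reduce(operator.mul, leaves)  # 120

# On a tree with no leaves, "all" is vacuously true and "any" is false.
all_on_empty = all(sketch_leaves({}))   # True
any_on_empty = any(sketch_leaves({}))   # False
```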

Node Registration

See the section PyTree Node Registration for more detailed documentation.

register_node(cls, /, flatten_func, ...[, ...])

Extend the set of types that are considered internal nodes in pytrees.

register_node_class([cls, path_entry_type, ...])

Extend the set of types that are considered internal nodes in pytrees.

unregister_node(cls, /, *, namespace)

Remove a type from the pytree node registry.