API References
OpTree: Optimized PyTree Utilities.
- optree.MAX_RECURSION_DEPTH: int = 1000
Maximum recursion depth for pytree traversal.
This limit prevents infinite recursion from causing an overflow of the C stack and crashing Python.
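The sketch below is plain Python, not optree's C++ traversal. It shows why a fixed limit helps. It counts the leaves of nested dict/list/tuple containers and raises RecursionError once a depth budget is exceeded, so the traversal never runs until the stack overflows. The helper names count_leaves and nested are hypothetical, and so is the default budget of 100. That budget is kept well below Python's own recursion limit so that the sketch's check fires first.

```python
def count_leaves(tree, max_depth=100, _depth=0):
    """Count the leaves of nested dict/list/tuple containers (hypothetical helper).

    Raises RecursionError once nesting exceeds max_depth, mimicking the idea
    behind a fixed traversal depth limit.
    """
    if _depth > max_depth:
        raise RecursionError(f'maximum recursion depth {max_depth} exceeded')
    if isinstance(tree, dict):
        children = list(tree.values())
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return 1  # anything that is not a container counts as a single leaf
    total = 0
    for child in children:
        total += count_leaves(child, max_depth, _depth + 1)
    return total


def nested(depth):
    """Build a list nested `depth` levels deep around the value 0, e.g. [[[0]]]."""
    tree = 0
    for _ in range(depth):
        tree = [tree]
    return tree


print(count_leaves(nested(5)))  # 1
try:
    count_leaves(nested(20), max_depth=10)
except RecursionError as exc:
    print(exc)  # maximum recursion depth 10 exceeded
```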
- optree.tree_flatten(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Flatten a pytree.
See also
tree_flatten_with_path() and tree_unflatten().
The flattening order (i.e., the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten(tree)
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree_flatten(tree, none_is_leaf=True)
(
    [1, 2, 3, 4, None, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
)
>>> tree_flatten(1)
([1], PyTreeSpec(*))
>>> tree_flatten(None)
([], PyTreeSpec(None))
>>> tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(*, NoneIsLeaf))
For unordered dictionaries, dict and collections.defaultdict, the order depends on the sorted keys of the dictionary. Use collections.OrderedDict if you want to keep the keys in insertion order.
>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
>>> tree_flatten(tree)
(
    [2, 3, 4, 1, 5],
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten(tree, none_is_leaf=True)
(
    [2, 3, 4, 1, None, 5],
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
tuple[list[TypeVar(T)], PyTreeSpec]
- Returns:
A pair (leaves, treespec) where the first element is a list of leaf values and the second element is a treespec representing the structure of the pytree.
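The ordering rules above can be modeled with a short recursive function. This is a simplified plain-Python sketch, not optree's implementation. It assumes that only dict, collections.OrderedDict, list and tuple are internal nodes. The name simple_flatten is hypothetical, and the function returns only the leaves, not a treespec.

```python
from collections import OrderedDict


def simple_flatten(tree, none_is_leaf=False):
    """Return the leaves of `tree` in left-to-right depth-first order."""
    leaves = []

    def visit(node):
        if node is None and not none_is_leaf:
            return  # None is an arity-0 internal node: it contributes no leaves
        if isinstance(node, dict):
            # OrderedDict keeps insertion order; other dicts are visited by sorted key
            keys = list(node) if isinstance(node, OrderedDict) else sorted(node)
            for key in keys:
                visit(node[key])
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)
        else:
            leaves.append(node)  # anything else is a leaf

    visit(tree)
    return leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(simple_flatten(tree))                     # [1, 2, 3, 4, 5]
print(simple_flatten(tree, none_is_leaf=True))  # [1, 2, 3, 4, None, 5]
```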
- optree.tree_flatten_with_path(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Flatten a pytree and additionally record the paths.
See also
tree_flatten(), tree_paths(), and treespec_paths().
The flattening order (i.e., the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten_with_path(tree)
(
    [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)],
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree_flatten_with_path(tree, none_is_leaf=True)
(
    [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)],
    [1, 2, 3, 4, None, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
)
>>> tree_flatten_with_path(1)
([()], [1], PyTreeSpec(*))
>>> tree_flatten_with_path(None)
([], [], PyTreeSpec(None))
>>> tree_flatten_with_path(None, none_is_leaf=True)
([()], [None], PyTreeSpec(*, NoneIsLeaf))
For unordered dictionaries, dict and collections.defaultdict, the order depends on the sorted keys of the dictionary. Use collections.OrderedDict if you want to keep the keys in insertion order.
>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
>>> tree_flatten_with_path(tree)
(
    [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('d',)],
    [2, 3, 4, 1, 5],
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten_with_path(tree, none_is_leaf=True)
(
    [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)],
    [2, 3, 4, 1, None, 5],
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A triple (paths, leaves, treespec). The first element is a list of paths to the leaf values, where each path is a tuple of indices or keys. The second element is a list of leaf values and the last element is a treespec representing the structure of the pytree.
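A path is simply the sequence of keys and indices taken from the root to reach a leaf. The plain-Python sketch below records each path alongside its leaf. The name simple_flatten_with_path is hypothetical. The sketch handles only dict, list, tuple and None, and it visits dicts in sorted-key order.

```python
def simple_flatten_with_path(tree, none_is_leaf=False):
    """Return (paths, leaves), where each path is a tuple of keys/indices from the root."""
    paths, leaves = [], []

    def visit(node, path):
        if node is None and not none_is_leaf:
            return  # None contributes no leaf and therefore no path
        if isinstance(node, dict):
            for key in sorted(node):
                visit(node[key], path + (key,))
        elif isinstance(node, (list, tuple)):
            for index, child in enumerate(node):
                visit(child, path + (index,))
        else:
            paths.append(path)
            leaves.append(node)

    visit(tree, ())
    return paths, leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
paths, leaves = simple_flatten_with_path(tree)
print(paths)   # [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)]
print(leaves)  # [1, 2, 3, 4, 5]
```

A leaf at the root gets the empty path (), which matches the tree_flatten_with_path(1) example.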
- optree.tree_flatten_with_accessor(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Flatten a pytree and additionally record the accessors.
See also
tree_flatten(), tree_accessors(), and treespec_accessors().
The flattening order (i.e., the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten_with_accessor(tree)
(
    [
        PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
        PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
        PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
        PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
    ],
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree_flatten_with_accessor(tree, none_is_leaf=True)
(
    [
        PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
        PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
        PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
        PyTreeAccessor(*['c'], (MappingEntry(key='c', type=<class 'dict'>),)),
        PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
    ],
    [1, 2, 3, 4, None, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
)
>>> tree_flatten_with_accessor(1)
([PyTreeAccessor(*, ())], [1], PyTreeSpec(*))
>>> tree_flatten_with_accessor(None)
([], [], PyTreeSpec(None))
>>> tree_flatten_with_accessor(None, none_is_leaf=True)
([PyTreeAccessor(*, ())], [None], PyTreeSpec(*, NoneIsLeaf))
For unordered dictionaries, dict and collections.defaultdict, the order depends on the sorted keys of the dictionary. Use collections.OrderedDict if you want to keep the keys in insertion order.
>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
>>> tree_flatten_with_accessor(tree)
(
    [
        PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
        PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
        PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>),)),
        PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'collections.OrderedDict'>),))
    ],
    [2, 3, 4, 1, 5],
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
)
>>> tree_flatten_with_accessor(tree, none_is_leaf=True)
(
    [
        PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
        PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
        PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>),)),
        PyTreeAccessor(*['c'], (MappingEntry(key='c', type=<class 'collections.OrderedDict'>),)),
        PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'collections.OrderedDict'>),))
    ],
    [2, 3, 4, 1, None, 5],
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
)
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
tuple[list[PyTreeAccessor], list[TypeVar(T)], PyTreeSpec]
- Returns:
A triple (accessors, leaves, treespec). The first element is a list of accessors to the leaf values. The second element is a list of leaf values and the last element is a treespec representing the structure of the pytree.
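An accessor adds the type of each container along the path, and you can apply it to the tree to fetch the leaf. The sketch below models that idea with two hypothetical dataclasses, Entry and Accessor. They are illustrative stand-ins, not optree's PyTreeAccessor, MappingEntry or SequenceEntry classes. The function simple_flatten_with_accessor is also hypothetical. It handles only dict, list and tuple nodes, visits dicts in sorted-key order and skips None.

```python
from dataclasses import dataclass
from functools import reduce
from operator import getitem


@dataclass(frozen=True)
class Entry:
    """One step of an accessor: the key or index used, and the container type it indexes."""
    key: object
    container: type


@dataclass(frozen=True)
class Accessor:
    entries: tuple

    def __call__(self, obj):
        # Apply each step in turn: obj[k0][k1]...[kn]
        return reduce(getitem, (entry.key for entry in self.entries), obj)

    def codify(self, root='*'):
        # Render the accessor as Python subscription source text
        return root + ''.join(f'[{entry.key!r}]' for entry in self.entries)


def simple_flatten_with_accessor(tree, prefix=()):
    """Yield (accessor, leaf) pairs in left-to-right depth-first order."""
    if tree is None:
        return  # None is a structural node without leaves
    if isinstance(tree, dict):
        children = [(key, tree[key]) for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(enumerate(tree))
    else:
        yield Accessor(prefix), tree
        return
    for key, child in children:
        yield from simple_flatten_with_accessor(child, prefix + (Entry(key, type(tree)),))


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
for accessor, leaf in simple_flatten_with_accessor(tree):
    print(accessor.codify('tree'), '=', leaf, '| fetched:', accessor(tree))
```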
- optree.tree_unflatten(treespec, leaves)
Reconstruct a pytree from the treespec and the leaves.
The inverse of tree_flatten().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = tree_flatten(tree)
>>> tree == tree_unflatten(treespec, leaves)
True
- Parameters:
treespec (PyTreeSpec) – The treespec to reconstruct.
leaves (iterable) – The list of leaves to use for reconstruction. The list must match the number of leaves of the treespec.
- Return type:
- Returns:
The reconstructed pytree, containing the leaves placed in the structure described by treespec.
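The plain-Python sketch below makes the inverse relationship concrete. The hypothetical simple_flatten returns a nested "spec" in which every leaf position is replaced by a placeholder object. The hypothetical simple_unflatten then walks that spec and fills the placeholders from the leaves in order. The sketch handles only plain dict, list, tuple and None, and optree's PyTreeSpec is not implemented this way.

```python
_LEAF = object()  # placeholder marking a leaf position inside the spec


def simple_flatten(tree):
    """Return (leaves, spec) for plain dict/list/tuple/None containers."""
    leaves = []

    def build(node):
        if node is None:
            return None  # None is kept in the spec, not in the leaves
        if type(node) is dict:
            return {key: build(node[key]) for key in sorted(node)}
        if type(node) in (list, tuple):
            return type(node)([build(child) for child in node])
        leaves.append(node)
        return _LEAF

    spec = build(tree)
    return leaves, spec


def simple_unflatten(spec, leaves):
    """Rebuild a tree from `spec`, consuming `leaves` in order."""
    leaves = list(leaves)
    position = 0

    def fill(node):
        nonlocal position
        if node is _LEAF:
            if position >= len(leaves):
                raise ValueError('too few leaves for this structure')
            position += 1
            return leaves[position - 1]
        if node is None:
            return None
        if type(node) is dict:
            return {key: fill(child) for key, child in node.items()}
        return type(node)([fill(child) for child in node])

    tree = fill(spec)
    if position != len(leaves):
        raise ValueError('too many leaves for this structure')
    return tree


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
leaves, spec = simple_flatten(tree)
print(simple_unflatten(spec, leaves) == tree)  # True
print(simple_unflatten(spec, [10, 20, 30, 40, 50]))
# {'a': 10, 'b': (20, [30, 40]), 'c': None, 'd': 50}
```

The two length checks mirror the documented requirement that the number of leaves must match the treespec.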
- optree.tree_iter(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Get an iterator over the leaves of a pytree.
See also
tree_flatten() and tree_leaves().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> list(tree_iter(tree))
[1, 2, 3, 4, 5]
>>> list(tree_iter(tree, none_is_leaf=True))
[1, 2, 3, 4, None, 5]
>>> list(tree_iter(1))
[1]
>>> list(tree_iter(None))
[]
>>> list(tree_iter(None, none_is_leaf=True))
[None]
- Parameters:
tree (pytree) – A pytree to iterate over.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
An iterator over the leaf values.
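An iterator lets a caller stop early without materializing every leaf. The generator below is a plain-Python sketch of that behavior. The name iter_leaves is hypothetical, and the sketch handles only dict, list, tuple and None, visiting dicts in sorted-key order. It makes no claim about how optree's tree_iter is implemented internally.

```python
from itertools import islice


def iter_leaves(tree, none_is_leaf=False):
    """Lazily yield leaves in left-to-right depth-first order."""
    if tree is None and not none_is_leaf:
        return  # None is an empty internal node: nothing to yield
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from iter_leaves(tree[key], none_is_leaf)
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            yield from iter_leaves(child, none_is_leaf)
    else:
        yield tree


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(list(iter_leaves(tree)))               # [1, 2, 3, 4, 5]
print(list(islice(iter_leaves(tree), 2)))    # [1, 2]: traversal stops after two leaves
```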
- optree.tree_leaves(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Get the leaves of a pytree.
See also
tree_flatten() and tree_iter().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_leaves(tree)
[1, 2, 3, 4, 5]
>>> tree_leaves(tree, none_is_leaf=True)
[1, 2, 3, 4, None, 5]
>>> tree_leaves(1)
[1]
>>> tree_leaves(None)
[]
>>> tree_leaves(None, none_is_leaf=True)
[None]
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A list of leaf values.
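The is_leaf callback is checked before a node is expanded, so returning True keeps an entire subtree together as one leaf. The plain-Python sketch below follows that rule. The name simple_leaves is hypothetical. The sketch handles only dict, list, tuple and None, and it treats None as an empty node.

```python
def simple_leaves(tree, is_leaf=None):
    """Return leaves in depth-first order, honoring an optional is_leaf predicate."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]  # stop here: the whole subtree becomes a single leaf
    if tree is None:
        return []  # None is an empty internal node in this sketch
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in simple_leaves(tree[key], is_leaf)]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in simple_leaves(child, is_leaf)]
    return [tree]


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(simple_leaves(tree))                                          # [1, 2, 3, 4, 5]
print(simple_leaves(tree, is_leaf=lambda x: isinstance(x, list)))   # [1, 2, [3, 4], 5]
print(simple_leaves(tree, is_leaf=lambda x: isinstance(x, tuple)))  # [1, (2, [3, 4]), 5]
```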
- optree.tree_structure(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Get the treespec for a pytree.
See also
tree_flatten().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_structure(tree)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
>>> tree_structure(tree, none_is_leaf=True)
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
>>> tree_structure(1)
PyTreeSpec(*)
>>> tree_structure(None)
PyTreeSpec(None)
>>> tree_structure(None, none_is_leaf=True)
PyTreeSpec(*, NoneIsLeaf)
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec object representing the structure of the pytree.
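In these examples, the PyTreeSpec representation marks every leaf with * and keeps internal nodes, including None, in place. The plain-Python function below reproduces that textual form for the example tree. The name structure_repr is hypothetical. It handles only dict, list, tuple and None, and shows dicts in sorted-key order. It does not build a real PyTreeSpec, and it omits both the PyTreeSpec(...) wrapper and the NoneIsLeaf marker.

```python
def structure_repr(tree, none_is_leaf=False):
    """Render the shape of `tree`, writing '*' for every leaf."""
    if tree is None and not none_is_leaf:
        return 'None'  # None stays visible as structure
    if isinstance(tree, dict):
        items = ', '.join(f'{key!r}: {structure_repr(tree[key], none_is_leaf)}' for key in sorted(tree))
        return '{' + items + '}'
    if isinstance(tree, tuple):
        inner = ', '.join(structure_repr(child, none_is_leaf) for child in tree)
        return '(' + inner + (',)' if len(tree) == 1 else ')')  # keep the 1-tuple comma
    if isinstance(tree, list):
        return '[' + ', '.join(structure_repr(child, none_is_leaf) for child in tree) + ']'
    return '*'


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
print(structure_repr(tree))                     # {'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}
print(structure_repr(tree, none_is_leaf=True))  # {'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}
```

Two trees that differ only in their leaf values produce the same string, which is the sense in which they share a structure.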
- optree.tree_paths(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Get the path entries to the leaves of a pytree.
See also
tree_flatten(), tree_flatten_with_path(), and treespec_paths().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_paths(tree)
[('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)]
>>> tree_paths(tree, none_is_leaf=True)
[('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)]
>>> tree_paths(1)
[()]
>>> tree_paths(None)
[]
>>> tree_paths(None, none_is_leaf=True)
[()]
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A list of paths to the leaf values, where each path is a tuple of indices or keys.
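Each path is a tuple of the keys and indices taken at each level. For built-in containers, you can therefore resolve a path back to its leaf by repeated subscription. The sketch below uses a hypothetical helper, get_by_path, on the paths from the example above. Custom node types may record path entries that are not valid subscripts, so the sketch only targets plain dict, list and tuple nodes.

```python
from functools import reduce
from operator import getitem


def get_by_path(tree, path):
    """Follow a path of keys/indices from the root: tree[p0][p1]...[pn]."""
    return reduce(getitem, path, tree)


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
# Paths copied from the tree_paths(tree) output in the example above.
paths = [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)]
print([get_by_path(tree, path) for path in paths])  # [1, 2, 3, 4, 5]
```

The empty path () resolves to the tree itself, which is consistent with tree_paths(1) returning [()].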
- optree.tree_accessors(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Get the accessors to the leaves of a pytree.
See also
tree_flatten(), tree_flatten_with_accessor(), and treespec_accessors().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_accessors(tree)
[
    PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
    PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
    PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
    PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
    PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
]
>>> tree_accessors(tree, none_is_leaf=True)
[
    PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
    PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
    PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
    PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
    PyTreeAccessor(*['c'], (MappingEntry(key='c', type=<class 'dict'>),)),
    PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
]
>>> tree_accessors(1)
[PyTreeAccessor(*, ())]
>>> tree_accessors(None)
[]
>>> tree_accessors(None, none_is_leaf=True)
[PyTreeAccessor(*, ())]
- Parameters:
tree (pytree) – A pytree to flatten.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
list[PyTreeAccessor]
- Returns:
A list of accessors to the leaf values.
- optree.tree_is_leaf(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Test whether the given object is a leaf node.
See also
tree_flatten(), tree_leaves(), and all_leaves().
>>> tree_is_leaf(1)
True
>>> tree_is_leaf(None)
False
>>> tree_is_leaf(None, none_is_leaf=True)
True
>>> tree_is_leaf({'a': 1, 'b': (2, 3)})
False
- Parameters:
tree (pytree) – A pytree to check whether it is a leaf node.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than being a leaf. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
bool
- Returns:
A boolean indicating whether the given object is a leaf node.
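The plain-Python sketch below is one way to model this decision. It applies a custom is_leaf check first, then the special case for None, then a test for whether the object is a container type. The name is_leaf_node is hypothetical. The fixed container set (dict, list, tuple) is an assumption that stands in for the pytree registry and namespaces.

```python
CONTAINER_TYPES = (dict, list, tuple)  # assumption: stand-in for the registered node types


def is_leaf_node(obj, is_leaf=None, none_is_leaf=False):
    """Return True if `obj` would be treated as a leaf in this simplified model."""
    if is_leaf is not None and is_leaf(obj):
        return True  # the caller's predicate overrides everything else
    if obj is None:
        return none_is_leaf  # None is an arity-0 node unless none_is_leaf=True
    return not isinstance(obj, CONTAINER_TYPES)


print(is_leaf_node(1))                                 # True
print(is_leaf_node(None))                              # False
print(is_leaf_node(None, none_is_leaf=True))           # True
print(is_leaf_node({'a': 1, 'b': (2, 3)}))             # False
print(is_leaf_node([1], is_leaf=lambda x: isinstance(x, list)))  # True
```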
- optree.all_leaves(iterable, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Test whether all elements in the given iterable are leaves.
See also
tree_flatten(), tree_leaves(), and tree_is_leaf().
>>> tree = {'a': [1, 2, 3]}
>>> all_leaves(tree_leaves(tree))
True
>>> all_leaves([tree])
False
>>> all_leaves([1, 2, None, 3])
False
>>> all_leaves([1, 2, None, 3], none_is_leaf=True)
True
Note that this function iterates over the input with the iter() function and checks each element. For a dictionary d, iter(d) iterates over the keys of the dictionary, not the values.
>>> list({'a': 1, 'b': (2, 3)})
['a', 'b']
>>> all_leaves({'a': 1, 'b': (2, 3)})
True
This function is useful in advanced cases. For example, if a library allows arbitrary map operations on a flat list of leaves, it may want to check whether the result is still a flat list of leaves.
- Parameters:
iterable (iterable) – An iterable of objects.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than being a leaf. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
bool
- Returns:
A boolean indicating whether all elements in the input iterable are leaves.
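The dictionary caveat follows directly from iterating with iter(). The plain-Python sketch below shows the effect. The name all_leaves_sketch is hypothetical, and the sketch treats dict, list and tuple as the only container types.

```python
def all_leaves_sketch(iterable, none_is_leaf=False):
    """Return True if every element produced by iter(iterable) is a leaf."""
    def is_leaf(obj):
        if obj is None:
            return none_is_leaf
        return not isinstance(obj, (dict, list, tuple))

    return all(is_leaf(element) for element in iterable)


print(all_leaves_sketch([1, 2, None, 3]))                     # False
print(all_leaves_sketch([1, 2, None, 3], none_is_leaf=True))  # True
print(all_leaves_sketch({'a': 1, 'b': (2, 3)}))               # True: only the keys 'a' and 'b' are checked
print(all_leaves_sketch({'a': 1, 'b': (2, 3)}.values()))      # False: (2, 3) is a tuple
```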
- optree.tree_map(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')
Map a multi-input function over pytree args to produce a new pytree.
See also
tree_map_(), tree_map_with_path(), tree_map_with_path_(), and tree_broadcast_map().
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (43, 65), 'z': None}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
{'x': False, 'y': (False, False), 'z': None}
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
{'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, the structure of the tree is taken from the first input; subsequent inputs need only have tree as a prefix:
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]
- Parameters:
func (callable) – A function that takes 1 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees.
tree (pytree) – A pytree to be mapped over, with each leaf providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new pytree with the same structure as tree but with the value at each leaf given by func(x, *xs) where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
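The plain-Python sketch below shows the prefix rule for multiple inputs. The name simple_tree_map is hypothetical, and the sketch handles only dict, list, tuple and None. The recursion follows the structure of the first tree. Whenever it reaches a leaf there, it passes the corresponding subtrees of the other inputs to func whole. Unlike optree, the sketch does not check that the other inputs actually match that structure.

```python
def simple_tree_map(func, tree, *rests, none_is_leaf=False):
    """Map func over the leaves of `tree`; each of `rests` must have `tree` as a prefix."""
    if tree is None and not none_is_leaf:
        return None  # None is structure here, so it is kept as-is
    if isinstance(tree, dict):
        return {
            key: simple_tree_map(func, tree[key], *(rest[key] for rest in rests), none_is_leaf=none_is_leaf)
            for key in tree
        }
    if isinstance(tree, (list, tuple)):
        return type(tree)([
            simple_tree_map(func, child, *(rest[index] for rest in rests), none_is_leaf=none_is_leaf)
            for index, child in enumerate(tree)
        ])
    # `tree` is a leaf: the matching values in `rests` may be whole subtrees
    return func(tree, *rests)


print(simple_tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None}))
# {'x': 8, 'y': (43, 65), 'z': None}
print(simple_tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]))
# [[5, 7, 9], [6, 1, 2]]
```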
- optree.tree_map_(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')
Like tree_map(), but do an in-place call on each leaf and return the original tree.
See also
tree_map(), tree_map_with_path(), and tree_map_with_path_().
- Parameters:
func (callable) – A function that takes 1 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees.
tree (pytree) – A pytree to be mapped over, with each leaf providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
The original tree with the values at each leaf modified by the side effect of function func(x, *xs) (not the return value) where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
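Only the side effects of func matter here, so the tree changes only when its leaves are mutable objects. The plain-Python sketch below uses the hypothetical names simple_tree_map_, Box and double. It supports a single input only and handles only dict, list, tuple and None.

```python
def simple_tree_map_(func, tree):
    """Call func on every leaf for its side effects and return `tree` itself."""
    def visit(node):
        if node is None:
            return  # None has no leaves
        if isinstance(node, dict):
            for key in sorted(node):
                visit(node[key])
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)
        else:
            func(node)  # the return value is discarded

    visit(tree)
    return tree


class Box:
    """A mutable leaf type, so that in-place updates are observable."""

    def __init__(self, value):
        self.value = value


def double(box):
    box.value *= 2  # mutate the leaf; nothing is returned


boxes = {'x': Box(1), 'y': (Box(2), Box(3))}
result = simple_tree_map_(double, boxes)
print(result is boxes)                                      # True
print([boxes['x'].value] + [b.value for b in boxes['y']])   # [2, 4, 6]
```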
- optree.tree_map_with_path(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')
Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.
See also
tree_map(), tree_map_(), and tree_map_with_path_().
>>> tree_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
{'x': (1, 7), 'y': ((2, 42), (2, 64))}
>>> tree_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
{'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
>>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
{'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra paths.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding path providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new pytree with the same structure as tree but with the value at each leaf given by func(p, x, *xs) where (p, x) are the path and value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
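The plain-Python sketch below models the path-aware variant. The name simple_tree_map_with_path is hypothetical. It supports a single input only and handles only dict, list, tuple and None. At each level the sketch extends the path with a key or index, and it passes the path to func together with the leaf.

```python
def simple_tree_map_with_path(func, tree):
    """Map func(path, leaf) over the leaves of `tree`, keeping its structure."""
    def visit(node, path):
        if node is None:
            return None  # None is kept as structure; func is not called on it
        if isinstance(node, dict):
            return {key: visit(child, path + (key,)) for key, child in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)([visit(child, path + (index,)) for index, child in enumerate(node)])
        return func(path, node)

    return visit(tree, ())


print(simple_tree_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)}))
# {'x': (1, 7), 'y': ((2, 42), (2, 64))}
```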
- optree.tree_map_with_path_(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')
Like tree_map_with_path(), but do an in-place call on each leaf and return the original tree.
See also
tree_map(), tree_map_(), and tree_map_with_path().
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra paths.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding path providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
The original tree with the values at each leaf modified by the side effect of function func(p, x, *xs) (not the return value) where (p, x) are the path and value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
- optree.tree_map_with_accessor(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')
Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree.
See also
tree_map(), tree_map_(), and tree_map_with_accessor_().
>>> tree_map_with_accessor(lambda a, x: f'{a.codify("tree")} = {x!r}', {'x': 7, 'y': (42, 64)})
{'x': "tree['x'] = 7", 'y': ("tree['y'][0] = 42", "tree['y'][1] = 64")}
>>> tree_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_map_with_accessor(
...     lambda a, x: a,
...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
... )
{
    'x': PyTreeAccessor(*['x'], ...),
    'y': (
        PyTreeAccessor(*['y'][0], ...),
        PyTreeAccessor(*['y'][1], ...)
    ),
    'z': {1.5: None}
}
>>> tree_map_with_accessor(
...     lambda a, x: a,
...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
...     none_is_leaf=True,
... )
{
    'x': PyTreeAccessor(*['x'], ...),
    'y': (
        PyTreeAccessor(*['y'][0], ...),
        PyTreeAccessor(*['y'][1], ...)
    ),
    'z': {
        1.5: PyTreeAccessor(*['z'][1.5], ...)
    }
}
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra accessors.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding accessor providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new pytree with the same structure as tree but with the value at each leaf given by func(a, x, *xs) where (a, x) are the accessor and value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
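The plain-Python sketch below approximates the accessor variant. The names Accessor and simple_tree_map_with_accessor are hypothetical. The sketch supports a single input only and handles only dict, list, tuple and None. It models the accessor as a tuple subclass of keys and indices with a codify method. optree's PyTreeAccessor carries richer entry objects, so treat this only as an approximation of its behavior.

```python
class Accessor(tuple):
    """Hypothetical stand-in for an accessor: a tuple of keys/indices from the root."""

    def codify(self, root='*'):
        # e.g. Accessor(('y', 0)).codify('tree') -> "tree['y'][0]"
        return root + ''.join(f'[{entry!r}]' for entry in self)


def simple_tree_map_with_accessor(func, tree):
    """Map func(accessor, leaf) over the leaves of `tree`, keeping its structure."""
    def visit(node, accessor):
        if node is None:
            return None  # None is kept as structure; func is not called on it
        if isinstance(node, dict):
            return {key: visit(child, Accessor((*accessor, key))) for key, child in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)([visit(child, Accessor((*accessor, index))) for index, child in enumerate(node)])
        return func(accessor, node)

    return visit(tree, Accessor())


print(simple_tree_map_with_accessor(lambda a, x: f'{a.codify("tree")} = {x!r}', {'x': 7, 'y': (42, 64)}))
# {'x': "tree['x'] = 7", 'y': ("tree['y'][0] = 42", "tree['y'][1] = 64")}
```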
- optree.tree_map_with_accessor_(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')[source]
Like tree_map_with_accessor(), but call func on each leaf for its side effects only and return the original tree.
See also
tree_map(), tree_map_(), and tree_map_with_accessor().
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra accessors.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding accessor providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
The original tree with the values at each leaf modified by the side effect of function func(a, x, *xs) (not the return value) where (a, x) are the accessor and value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
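To make the in-place contract concrete, here is a small pure-Python model; it is not optree's implementation. It assumes only dict, list, and tuple containers, visits plain dicts in sorted-key order, skips None as an arity-0 node, and represents each accessor as a code string such as tree['b'][0]. The name visit_with_code is hypothetical.

```python
from typing import Any, Callable


def visit_with_code(func: Callable[[str, Any], Any], tree: Any, name: str = 'tree') -> Any:
    """Hypothetical model: call func(code, leaf) for its side effects and return tree itself."""

    def walk(node: Any, code: str) -> None:
        if node is None:  # None is a node with arity 0 by default: no leaf to visit
            return
        if isinstance(node, dict):
            for key in sorted(node):  # plain dicts are traversed in sorted-key order
                walk(node[key], f'{code}[{key!r}]')
        elif isinstance(node, (list, tuple)):
            for index, child in enumerate(node):
                walk(child, f'{code}[{index}]')
        else:
            func(code, node)  # the return value is deliberately discarded

    walk(tree, name)
    return tree  # the very same object, not a copy


log = []
tree = {'b': (2, [3, 4]), 'a': 1, 'c': None}
result = visit_with_code(lambda code, x: log.append(f'{code} = {x}'), tree)
```

The model returns the object it received. Any useful output of func must therefore come from a side effect, which here is appending to log.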
- optree.tree_replace_nones(sentinel, tree, /, namespace='')[source]
Replace None in tree with sentinel.
See also
tree_flatten() and tree_map().
>>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)})
{'a': 1, 'b': 0, 'c': (2, 0)}
>>> tree_replace_nones(0, None)
0
- Parameters:
- Returns:
A new pytree with the same structure as tree but with None replaced by sentinel.
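Conceptually, the replacement is a structural recursion that swaps every None node for sentinel. The sketch below makes several assumptions. replace_nones is a hypothetical helper, it handles only dict, list, and tuple containers, and it ignores the pytree registry and namespace.

```python
from typing import Any


def replace_nones(sentinel: Any, tree: Any) -> Any:
    """Hypothetical model of tree_replace_nones for dict/list/tuple trees."""
    if tree is None:
        return sentinel
    if isinstance(tree, dict):
        return {key: replace_nones(sentinel, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [replace_nones(sentinel, child) for child in tree]
    if isinstance(tree, tuple):
        return tuple(replace_nones(sentinel, child) for child in tree)
    return tree  # any other object is a leaf and is kept as-is
```

The model builds new containers rather than mutating the input, which matches the documented "new pytree" return value.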
- optree.tree_partition(predicate, tree, /, is_leaf=None, *, fillvalue=None, none_is_leaf=False, namespace='')[source]
Partition a tree into the left and right parts by the given predicate function.
See also
tree_transpose_map().
>>> left, right = tree_partition(lambda x: x > 10, {'x': 7, 'y': (42, 64)})
>>> left
{'x': None, 'y': (42, 64)}
>>> right
{'x': 7, 'y': (None, None)}
Instead of None, one can also use a different sentinel value:
>>> sentinel = object()
>>> left, right = tree_partition(lambda x: x > 10, {'x': 7, 'y': (42, 64)}, fillvalue=sentinel)
>>> left
{'x': <object object at ...>, 'y': (42, 64)}
>>> right
{'x': 7, 'y': (<object object at ...>, <object object at ...>)}
- Parameters:
predicate (callable) – A function that takes a leaf value as its argument and assigns the leaf to the left or right tree based on the return value of predicate.
tree (pytree) – A pytree to be split, with each leaf providing the first positional argument to function predicate.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
fillvalue (object, optional) – A sentinel value to retain the tree structure. (default: None)
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Returns:
Two pytrees with the same structure as tree but with complementary leaves, split by the predicate function. The first pytree contains all leaves where predicate evaluates to True, and the second contains all leaves where it evaluates to False. The positions of the removed leaves in both trees are filled with fillvalue to keep the original tree structure.
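The partition rule can be sketched in plain Python as follows. This is a simplified model, not the library code. partition is a hypothetical name, and only dict, list, and tuple containers are supported. The model does not implement is_leaf, none_is_leaf, or namespace, and it keeps None as an arity-0 node on both sides.

```python
from typing import Any, Callable, Tuple


def partition(predicate: Callable[[Any], bool], tree: Any, fillvalue: Any = None) -> Tuple[Any, Any]:
    """Hypothetical model of tree_partition for dict/list/tuple trees."""
    if tree is None:  # None is a node with arity 0: it appears unchanged on both sides
        return None, None
    if isinstance(tree, dict):
        pairs = {key: partition(predicate, value, fillvalue) for key, value in tree.items()}
        return {k: p[0] for k, p in pairs.items()}, {k: p[1] for k, p in pairs.items()}
    if isinstance(tree, (list, tuple)):
        pairs = [partition(predicate, child, fillvalue) for child in tree]
        cls = type(tree)
        return cls(p[0] for p in pairs), cls(p[1] for p in pairs)
    # Leaf: it goes to the left tree if predicate is true, otherwise to the right tree.
    return (tree, fillvalue) if predicate(tree) else (fillvalue, tree)
```

Each leaf lands on exactly one side, so the two results always have complementary leaves and identical structure.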
- optree.tree_transpose(outer_treespec, inner_treespec, tree, /, is_leaf=None)[source]
Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
See also
tree_flatten(), tree_structure(), and tree_transpose_map().
>>> outer_treespec = tree_structure({'a': 1, 'b': 2, 'c': (3, 4)})
>>> outer_treespec
PyTreeSpec({'a': *, 'b': *, 'c': (*, *)})
>>> inner_treespec = tree_structure((1, 2))
>>> inner_treespec
PyTreeSpec((*, *))
>>> tree = {'a': (1, 2), 'b': (3, 4), 'c': ((5, 6), (7, 8))}
>>> tree_transpose(outer_treespec, inner_treespec, tree)
({'a': 1, 'b': 3, 'c': (5, 7)}, {'a': 2, 'b': 4, 'c': (6, 8)})
For performance reasons, this function only checks the number of leaves in the input pytree, not its structure. The leaves of tree are enumerated in their original order, and the transposition depends only on the number of leaves in the structure (inner, outer). The caller is responsible for ensuring that the input pytree has a prefix structure of outer_treespec followed by a prefix structure of inner_treespec. Otherwise, the result may be incorrect.
>>> tree_transpose(outer_treespec, inner_treespec, list(range(1, 9)))
({'a': 1, 'b': 3, 'c': (5, 7)}, {'a': 2, 'b': 4, 'c': (6, 8)})
- Parameters:
outer_treespec (PyTreeSpec) – A treespec object representing the outer structure of the pytree.
inner_treespec (PyTreeSpec) – A treespec object representing the inner structure of the pytree.
tree (pytree) – A pytree to be transposed.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
- Return type:
- Returns:
A new pytree with the same structure as inner_treespec but with the value at each leaf having the same structure as outer_treespec.
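The leaf-count rule in the note above can be modeled on a flat list of leaves. In this sketch, transpose_leaves and rebuild_outer are hypothetical helpers. rebuild_outer is hard-coded to stand in for unflattening with the outer treespec {'a': *, 'b': *, 'c': (*, *)} from the example.

```python
from typing import Any, Dict, List


def transpose_leaves(num_outer: int, num_inner: int, leaves: List[Any]) -> List[List[Any]]:
    """Hypothetical model: regroup a flat leaf list from (outer, inner) to (inner, outer) order."""
    # Only the leaf count is checked; the structure that produced the leaves is never inspected.
    if len(leaves) != num_outer * num_inner:
        raise ValueError(f'expected {num_outer * num_inner} leaves, got {len(leaves)}')
    # Leaf i * num_inner + j is inner position j of outer position i.
    return [[leaves[i * num_inner + j] for i in range(num_outer)] for j in range(num_inner)]


def rebuild_outer(ls: List[Any]) -> Dict[str, Any]:
    """Hypothetical stand-in for unflattening with {'a': *, 'b': *, 'c': (*, *)}."""
    return {'a': ls[0], 'b': ls[1], 'c': (ls[2], ls[3])}


columns = transpose_leaves(4, 2, list(range(1, 9)))
result = tuple(rebuild_outer(col) for col in columns)  # inner structure (*, *) is a 2-tuple
```

The model only checks len(leaves), so the same regrouping applies whether the leaves came from the nested tree or from list(range(1, 9)). This is why both calls in the examples above return the same result.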
- optree.tree_transpose_map(func, tree, /, *rests, inner_treespec=None, is_leaf=None, none_is_leaf=False, namespace='')[source]
Map a multi-input function over pytree args to produce a new pytree with transposed structure.
See also
tree_map(), tree_map_with_path(), and tree_transpose().
>>> comp = {'a': 1, 'b': (6j, -3 + 4j), 'c': [-5.0, 2j]}
>>> real, imag, mod = tree_transpose_map(lambda z: (z.real, z.imag, abs(z)), comp)
>>> real
{'a': 1, 'b': (0.0, -3.0), 'c': [-5.0, 0.0]}
>>> imag
{'a': 0, 'b': (6.0, 4.0), 'c': [0.0, 2.0]}
>>> mod
{'a': 1, 'b': (6.0, 5.0), 'c': [5.0, 2.0]}
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
>>> tree_transpose_map(
...     lambda x: {'identity': x, 'double': 2 * x},
...     tree,
... )
{
    'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
    'double': {'b': (4, [6, 8]), 'a': 2, 'c': (10, 12)}
}
>>> tree_transpose_map(
...     lambda x: {'identity': x, 'double': (x, x)},
...     tree,
... )
{
    'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
    'double': (
        {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    )
}
>>> tree_transpose_map(
...     lambda x: {'identity': x, 'double': (x, x)},
...     tree,
...     inner_treespec=tree_structure({'identity': 0, 'double': 0}),
... )
{
    'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
    'double': {'b': ((2, 2), [(3, 3), (4, 4)]), 'a': (1, 1), 'c': ((5, 5), (6, 6))}
}
- Parameters:
func (callable) – A function that takes 1 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees.
tree (pytree) – A pytree to be mapped over, with each leaf providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
inner_treespec (PyTreeSpec, optional) – The treespec object representing the inner structure of the result pytree. If not specified, the inner structure is inferred from the result of the function func on the first leaf. (default: None)
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new nested pytree with the same structure as inner_treespec but with the value at each leaf having the same structure as tree. The subtree at each leaf is given by the result of function func(x, *xs) where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
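One way to picture the transposition is in three steps. First, flatten tree once. Second, call func on every leaf. Third, unflatten each field of the results with the original structure. The sketch below assumes func returns a flat dict, so the inner structure is a single dict level, and that tree has at least one leaf. flatten and transpose_map are hypothetical helpers that handle only dict, list, and tuple containers.

```python
from typing import Any, Callable, Dict, List, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Hypothetical helper: return the leaves and a function that rebuilds the same structure."""
    if tree is None:  # None is a node with arity 0 by default
        return [], lambda leaves: None
    if isinstance(tree, dict):
        keys = sorted(tree)  # plain dicts are flattened in sorted-key order
        children = [tree[key] for key in keys]

        def make(values: List[Any]) -> Any:
            return dict(zip(keys, values))
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
        make = type(tree)
    else:
        return [tree], lambda leaves: leaves[0]
    parts = [flatten(child) for child in children]
    leaves = [leaf for sub_leaves, _ in parts for leaf in sub_leaves]

    def unflatten(new_leaves: List[Any]) -> Any:
        values, start = [], 0
        for sub_leaves, sub_unflatten in parts:
            stop = start + len(sub_leaves)  # each child consumes as many leaves as it had
            values.append(sub_unflatten(new_leaves[start:stop]))
            start = stop
        return make(values)

    return leaves, unflatten


def transpose_map(func: Callable[[Any], Dict[str, Any]], tree: Any) -> Dict[str, Any]:
    """Hypothetical model of tree_transpose_map when func returns a flat dict of leaves."""
    leaves, unflatten = flatten(tree)
    outputs = [func(leaf) for leaf in leaves]
    inner_keys = list(outputs[0])  # inner structure inferred from the first leaf's result
    return {key: unflatten([out[key] for out in outputs]) for key in inner_keys}
```

Calling func exactly once per leaf and reusing one unflatten closure is what makes the transposition cheap. The outer structure is computed once and shared by every field of the inner structure.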
- optree.tree_transpose_map_with_path(func, tree, /, *rests, inner_treespec=None, is_leaf=None, none_is_leaf=False, namespace='')[source]
Map a multi-input function over pytree args as well as the tree paths to produce a new pytree with transposed structure.
See also
tree_map_with_path(), tree_transpose_map(), and tree_transpose().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
>>> tree_transpose_map_with_path(
...     lambda p, x: {'depth': len(p), 'value': x},
...     tree,
... )
{
    'depth': {'b': (2, [3, 3]), 'a': 1, 'c': (2, 2)},
    'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
}
>>> tree_transpose_map_with_path(
...     lambda p, x: {'path': p, 'value': x},
...     tree,
...     inner_treespec=tree_structure({'path': 0, 'value': 0}),
... )
{
    'path': {
        'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]),
        'a': ('a',),
        'c': (('c', 0), ('c', 1))
    },
    'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
}
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra paths.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding path providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
inner_treespec (PyTreeSpec, optional) – The treespec object representing the inner structure of the result pytree. If not specified, the inner structure is inferred from the result of the function func on the first leaf. (default: None)
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new nested pytree with the same structure as inner_treespec but with the value at each leaf having the same structure as tree. The subtree at each leaf is given by the result of function func(p, x, *xs) where (p, x) are the path and value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
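The paths passed as p are tuples of keys and indices from the root to each leaf, so len(p) is the depth of that leaf. The small model below assumes dict, list, and tuple containers only. It visits plain dicts in sorted-key order and skips None as an arity-0 node. leaf_paths is a hypothetical helper.

```python
from typing import Any, List, Tuple


def leaf_paths(tree: Any, prefix: Tuple[Any, ...] = ()) -> List[Tuple[Tuple[Any, ...], Any]]:
    """Hypothetical model: list (path, leaf) pairs in left-to-right depth-first order."""
    if tree is None:  # None is a node with arity 0 by default, so it has no leaf paths
        return []
    if isinstance(tree, dict):
        entries = [(key, tree[key]) for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        entries = list(enumerate(tree))
    else:
        return [(prefix, tree)]  # a leaf: the accumulated keys form its path
    result = []
    for key, child in entries:
        result.extend(leaf_paths(child, prefix + (key,)))
    return result


pairs = leaf_paths({'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)})
depths = [len(path) for path, _ in pairs]
```

The depths [1, 2, 3, 3, 2, 2] belong to the leaves at a, b[0], b[1][0], b[1][1], c[0], and c[1]. These are the same values that appear in the 'depth' entry of the first example above.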
- optree.tree_transpose_map_with_accessor(func, tree, /, *rests, inner_treespec=None, is_leaf=None, none_is_leaf=False, namespace='')[source]
Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree with transposed structure.
See also
tree_map_with_accessor(), tree_transpose_map(), and tree_transpose().
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
>>> tree_transpose_map_with_accessor(
...     lambda a, x: {'depth': len(a), 'code': a.codify('tree'), 'value': x},
...     tree,
... )
{
    'depth': {
        'b': (2, [3, 3]),
        'a': 1,
        'c': (2, 2)
    },
    'code': {
        'b': ("tree['b'][0]", ["tree['b'][1][0]", "tree['b'][1][1]"]),
        'a': "tree['a']",
        'c': ("tree['c'][0]", "tree['c'][1]")
    },
    'value': {
        'b': (2, [3, 4]),
        'a': 1,
        'c': (5, 6)
    }
}
>>> tree_transpose_map_with_accessor(
...     lambda a, x: {'path': a.path, 'accessor': a, 'value': x},
...     tree,
...     inner_treespec=tree_structure({'path': 0, 'accessor': 0, 'value': 0}),
... )
{
    'path': {
        'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]),
        'a': ('a',),
        'c': (('c', 0), ('c', 1))
    },
    'accessor': {
        'b': (
            PyTreeAccessor(*['b'][0], ...),
            [
                PyTreeAccessor(*['b'][1][0], ...),
                PyTreeAccessor(*['b'][1][1], ...)
            ]
        ),
        'a': PyTreeAccessor(*['a'], ...),
        'c': (
            PyTreeAccessor(*['c'][0], ...),
            PyTreeAccessor(*['c'][1], ...)
        )
    },
    'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
}
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra accessors.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding accessor providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
inner_treespec (PyTreeSpec, optional) – The treespec object representing the inner structure of the result pytree. If not specified, the inner structure is inferred from the result of the function func on the first leaf. (default: None)
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new nested pytree with the same structure as inner_treespec but with the value at each leaf having the same structure as tree. The subtree at each leaf is given by the result of function func(a, x, *xs) where (a, x) are the accessor and value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rests.
- optree.tree_broadcast_prefix(prefix_tree, full_tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')[source]
Return a pytree with the same structure as full_tree, with broadcasted subtrees in prefix_tree.
See also
broadcast_prefix(), tree_broadcast_common(), and treespec_is_prefix().
If prefix_tree is a prefix of full_tree, then full_tree can be constructed by replacing the leaves of prefix_tree with appropriate subtrees.
This function returns a pytree with the same size as full_tree. The leaves are replicated from prefix_tree. The number of replicas is determined by the corresponding subtree in full_tree.
>>> tree_broadcast_prefix(1, [2, 3, 4])
[1, 1, 1]
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6])
[1, 2, 3]
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
Traceback (most recent call last):
    ...
ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
[1, 2, (3, 3)]
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
[1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}]
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
[1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]
- Parameters:
prefix_tree (pytree) – A pytree with the prefix structure of full_tree.
full_tree (pytree) – A pytree with the suffix structure of prefix_tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A pytree with the same structure as full_tree, with broadcasted subtrees in prefix_tree.
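The replication rule can be sketched as a recursion that walks both trees in lockstep. Whenever prefix_tree reaches a leaf, the recursion copies that leaf over the whole corresponding subtree of full_tree. This is a simplified model with several assumptions. broadcast_prefix_tree and replicate are hypothetical names, only dict, list, and tuple containers are handled, and none_is_leaf=False is assumed. Its error messages only loosely imitate the real ones.

```python
from typing import Any


def replicate(value: Any, full: Any) -> Any:
    """Replace every leaf of full with value, keeping containers and None nodes."""
    if full is None:  # None is a node with arity 0 by default, so nothing to fill
        return None
    if isinstance(full, dict):
        return {key: replicate(value, child) for key, child in full.items()}
    if isinstance(full, (list, tuple)):
        return type(full)(replicate(value, child) for child in full)
    return value


def broadcast_prefix_tree(prefix: Any, full: Any) -> Any:
    """Hypothetical model of tree_broadcast_prefix(prefix, full) for dict/list/tuple trees."""
    if prefix is None:  # a None node in the prefix must match a None node in the full tree
        if full is not None:
            raise ValueError(f'expected None, got {full!r}')
        return None
    if isinstance(prefix, (dict, list, tuple)):
        if type(prefix) is not type(full):
            raise ValueError(f'expected {type(prefix).__name__}, got {type(full).__name__}')
        if len(prefix) != len(full):
            raise ValueError(
                f'{type(prefix).__name__} arity mismatch; expected: {len(prefix)}, got: {len(full)}'
            )
        if isinstance(prefix, dict):
            if set(prefix) != set(full):
                raise ValueError('dict key mismatch')
            return {key: broadcast_prefix_tree(prefix[key], full[key]) for key in full}
        return type(full)(broadcast_prefix_tree(p, f) for p, f in zip(prefix, full))
    # prefix is a leaf: replicate it over the whole corresponding subtree of full
    return replicate(prefix, full)
```

Flattening the result of this model would give the leaf list that broadcast_prefix() returns for the same inputs.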
- optree.broadcast_prefix(prefix_tree, full_tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')[source]
Return a list of broadcasted leaves in prefix_tree to match the number of leaves in full_tree.
See also
tree_broadcast_prefix(), broadcast_common(), and treespec_is_prefix().
If prefix_tree is a prefix of full_tree, then full_tree can be constructed by replacing the leaves of prefix_tree with appropriate subtrees.
This function returns a list of leaves with the same size as full_tree. The leaves are replicated from prefix_tree. The number of replicas is determined by the corresponding subtree in full_tree.
>>> broadcast_prefix(1, [2, 3, 4])
[1, 1, 1]
>>> broadcast_prefix([1, 2, 3], [4, 5, 6])
[1, 2, 3]
>>> broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
Traceback (most recent call last):
    ...
ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
>>> broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
[1, 2, 3, 3]
>>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
[1, 2, 3, 3, 3]
>>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
[1, 2, 3, 3, 3, 3]
- Parameters:
prefix_tree (pytree) – A pytree with the prefix structure of full_tree.
full_tree (pytree) – A pytree with the suffix structure of prefix_tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A list of leaves in prefix_tree broadcasted to match the number of leaves in full_tree.
- optree.tree_broadcast_common(tree, other_tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')[source]
Return two pytrees with the common suffix structure of tree and other_tree, with broadcasted subtrees.
See also
broadcast_common(), tree_broadcast_prefix(), and treespec_is_prefix().
If suffix_tree is a suffix of tree, then suffix_tree can be constructed by replacing the leaves of tree with appropriate subtrees.
This function returns two pytrees with the same structure. The tree structure is the common suffix structure of tree and other_tree. The leaves are replicated from tree and other_tree. The number of replicas is determined by the corresponding subtree in the suffix structure.
>>> tree_broadcast_common(1, [2, 3, 4])
([1, 1, 1], [2, 3, 4])
>>> tree_broadcast_common([1, 2, 3], [4, 5, 6])
([1, 2, 3], [4, 5, 6])
>>> tree_broadcast_common([1, 2, 3], [4, 5, 6, 7])
Traceback (most recent call last):
    ...
ValueError: list arity mismatch; expected: 3, got: 4.
>>> tree_broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)])
([1, (2, 3), (4, 4)], [5, (6, 6), (7, 8)])
>>> tree_broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}])
([1, {'a': (2, 3)}, {'a': 4, 'b': 4, 'c': (None, 4)}], [5, {'a': (6, 6)}, {'a': 7, 'b': 8, 'c': (None, 9)}])
>>> tree_broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], none_is_leaf=True)
([1, {'a': (2, 3)}, {'a': 4, 'b': 4, 'c': (4, 4)}], [5, {'a': (6, 6)}, {'a': 7, 'b': 8, 'c': (None, 9)}])
>>> tree_broadcast_common([1, None], [None, 2])
([None, None], [None, None])
>>> tree_broadcast_common([1, None], [None, 2], none_is_leaf=True)
([1, None], [None, 2])
- Parameters:
tree (pytree) – A pytree that has a common suffix structure with other_tree.
other_tree (pytree) – A pytree that has a common suffix structure with tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
Two pytrees with the common suffix structure of tree and other_tree, with broadcasted subtrees.
- optree.broadcast_common(tree, other_tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')[source]
Return two lists of broadcasted leaves in tree and other_tree to match the number of leaves in the common suffix structure.
See also
tree_broadcast_common(), broadcast_prefix(), and treespec_is_prefix().
If suffix_tree is a suffix of tree, then suffix_tree can be constructed by replacing the leaves of tree with appropriate subtrees.
This function returns two lists of leaves of equal length, namely the number of leaves in the common suffix structure of tree and other_tree. The leaves are replicated from tree and other_tree. The number of replicas is determined by the corresponding subtree in the suffix structure.
>>> broadcast_common(1, [2, 3, 4])
([1, 1, 1], [2, 3, 4])
>>> broadcast_common([1, 2, 3], [4, 5, 6])
([1, 2, 3], [4, 5, 6])
>>> broadcast_common([1, 2, 3], [4, 5, 6, 7])
Traceback (most recent call last):
    ...
ValueError: list arity mismatch; expected: 3, got: 4.
>>> broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)])
([1, 2, 3, 4, 4], [5, 6, 6, 7, 8])
>>> broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}])
([1, 2, 3, 4, 4, 4], [5, 6, 6, 7, 8, 9])
>>> broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], none_is_leaf=True)
([1, 2, 3, 4, 4, 4, 4], [5, 6, 6, 7, 8, None, 9])
>>> broadcast_common([1, None], [None, 2])
([], [])
>>> broadcast_common([1, None], [None, 2], none_is_leaf=True)
([1, None], [None, 2])
- Parameters:
tree (pytree) – A pytree that has a common suffix structure with other_tree.
other_tree (pytree) – A pytree that has a common suffix structure with tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
Two lists of leaves in tree and other_tree broadcasted to match the number of leaves in the common suffix structure.
- optree.tree_broadcast_map(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')[source]
Map a multi-input function over pytree args to produce a new pytree.
See also
tree_broadcast_map_with_path(), tree_map(), tree_map_(), and tree_map_with_path().
If only one input is provided, this function is the same as tree_map():
>>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (43, 65), 'z': None}
>>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
{'x': False, 'y': (False, False), 'z': None}
>>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
{'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, all input trees will be broadcasted to the common suffix structure of all inputs:
>>> tree_broadcast_map(lambda x, y: x * y, [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8])
[{'a': 35, 'b': 45}, [6, 12], (24, 32)]
- Parameters:
func (callable) – A function that takes 1 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees.
tree (pytree) – A pytree to be mapped over, with each leaf providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which should have a common suffix structure with the others and with tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new pytree with the common suffix structure of tree and rests, but with the value at each leaf given by func(x, *xs) where x is the value at the corresponding leaf (possibly broadcasted) in tree and xs is the tuple of values at corresponding leaves (possibly broadcasted) in rests.
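For two inputs, broadcasting to the common suffix structure works as follows. The recursion continues while both sides are containers. As soon as one side is a leaf, that leaf is paired with every leaf of the other side's subtree. The minimal two-input model below assumes dict, list, and tuple containers and no None nodes. broadcast_map2 and _is_node are hypothetical helpers, not part of optree.

```python
from typing import Any, Callable


def _is_node(x: Any) -> bool:
    return isinstance(x, (dict, list, tuple))


def broadcast_map2(func: Callable[[Any, Any], Any], x: Any, y: Any) -> Any:
    """Hypothetical model of tree_broadcast_map(func, x, y) for dict/list/tuple trees without None."""
    if _is_node(x) and _is_node(y):
        # Both sides are containers: they must match node by node.
        if type(x) is not type(y) or len(x) != len(y):
            raise ValueError(f'structure mismatch: {x!r} vs. {y!r}')
        if isinstance(x, dict):
            if set(x) != set(y):
                raise ValueError(f'dict key mismatch: {list(x)} vs. {list(y)}')
            return {key: broadcast_map2(func, x[key], y[key]) for key in x}
        return type(x)(broadcast_map2(func, a, b) for a, b in zip(x, y))
    if _is_node(x):  # y is a leaf: broadcast it over the subtree of x
        if isinstance(x, dict):
            return {key: broadcast_map2(func, child, y) for key, child in x.items()}
        return type(x)(broadcast_map2(func, child, y) for child in x)
    if _is_node(y):  # x is a leaf: broadcast it over the subtree of y
        if isinstance(y, dict):
            return {key: broadcast_map2(func, x, child) for key, child in y.items()}
        return type(y)(broadcast_map2(func, x, child) for child in y)
    return func(x, y)  # both sides are leaves
```

In the documented example, 5 is spread over {'a': 7, 'b': 9}, 6 over [1, 2], and 8 over (3, 4). The result therefore takes the deeper structure at each position.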
- optree.tree_broadcast_map_with_path(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')[source]
Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.
See also
tree_broadcast_map(), tree_map(), tree_map_(), and tree_map_with_path().
If only one input is provided, this function is the same as tree_map_with_path():
>>> tree_broadcast_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
{'x': (1, 7), 'y': ((2, 42), (2, 64))}
>>> tree_broadcast_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
{'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
>>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
{'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}
If multiple inputs are given, all input trees will be broadcasted to the common suffix structure of all inputs:
>>> tree_broadcast_map_with_path(
...     lambda p, x, y: (p, x * y),
...     [5, 6, (3, 4)],
...     [{'a': 7, 'b': 9}, [1, 2], 8],
... )
[
    {'a': ((0, 'a'), 35), 'b': ((0, 'b'), 45)},
    [((1, 0), 6), ((1, 1), 12)],
    (((2, 0), 24), ((2, 1), 32))
]
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra paths.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding path providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which should have a common suffix structure with the others and with tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new pytree with the common suffix structure of tree and rests, but with the value at each leaf given by func(p, x, *xs) where (p, x) are the path and value at the corresponding leaf (possibly broadcasted) in tree and xs is the tuple of values at corresponding leaves (possibly broadcasted) in rests.
- optree.tree_broadcast_map_with_accessor(func, tree, /, *rests, is_leaf=None, none_is_leaf=False, namespace='')[source]
Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree.
See also
tree_broadcast_map(), tree_map(), tree_map_(), and tree_map_with_accessor().
If only one input is provided, this function is the same as tree_map_with_accessor():
>>> tree_broadcast_map_with_accessor(lambda a, x: (len(a), x), {'x': 7, 'y': (42, 64)})
{'x': (1, 7), 'y': ((2, 42), (2, 64))}
>>> tree_broadcast_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_broadcast_map_with_accessor(
...     lambda a, x: a.codify('tree'),
...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
... )
{
    'x': "tree['x']",
    'y': ("tree['y'][0]", "tree['y'][1]"),
    'z': {1.5: None}
}
>>> tree_broadcast_map_with_accessor(
...     lambda a, x: a.codify('tree'),
...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
...     none_is_leaf=True,
... )
{
    'x': "tree['x']",
    'y': ("tree['y'][0]", "tree['y'][1]"),
    'z': {1.5: "tree['z'][1.5]"}
}
If multiple inputs are given, all input trees will be broadcasted to the common suffix structure of all inputs:
>>> tree_broadcast_map_with_accessor(
...     lambda a, x, y: f'{a.codify("tree")} = {x * y}',
...     [5, 6, (3, 4)],
...     [{'a': 7, 'b': 9}, [1, 2], 8],
... )
[
    {'a': "tree[0]['a'] = 35", 'b': "tree[0]['b'] = 45"},
    ['tree[1][0] = 6', 'tree[1][1] = 12'],
    ('tree[2][0] = 24', 'tree[2][1] = 32')
]
- Parameters:
func (callable) – A function that takes 2 + len(rests) arguments, to be applied at the corresponding leaves of the pytrees with extra accessors.
tree (pytree) – A pytree to be mapped over, with each leaf providing the second positional argument and the corresponding accessor providing the first positional argument to function func.
rests (tuple of pytree) – A tuple of pytrees, each of which should have a common suffix structure with the others and with tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A new pytree with the common suffix structure of tree and rests, but with the value at each leaf given by func(a, x, *xs) where (a, x) are the accessor and value at the corresponding leaf (possibly broadcasted) in tree and xs is the tuple of values at corresponding leaves (possibly broadcasted) in rests.
- optree.tree_reduce(func, tree, /, initial=<MISSING>, *, is_leaf=None, none_is_leaf=False, namespace='')[source]
Traverse a pytree and reduce its leaves in left-to-right depth-first order.
See also
tree_leaves() and tree_sum().
>>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)})
6
>>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
6
>>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3})
3
>>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
None
- Parameters:
func (callable) – A function that takes two arguments and returns a value of the same type.
tree (pytree) – A pytree to be traversed.
initial (object, optional) – An initial value to be used for the reduction. If not provided, the first leaf value is used as the initial value.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
TypeVar(T)- Returns:
The result of reducing the leaves of the pytree using
func.
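The traversal order can be made concrete with a small pure-Python sketch. It is not OpTree's implementation. It handles only built-in dict, list, tuple, and None nodes and visits dict children in sorted-key order, as the examples above do. The names flatten_leaves and reduce_leaves are hypothetical.

```python
from functools import reduce

_MISSING = object()  # hypothetical sentinel meaning "no initial value given"


def flatten_leaves(tree, none_is_leaf=False):
    """Collect leaves left-to-right, depth-first (built-in containers only)."""
    if tree is None:
        # By default None is a node with arity 0, so it contributes no leaves.
        return [None] if none_is_leaf else []
    if type(tree) is dict:
        # Children of a plain dict are visited in sorted-key order.
        return [leaf for key in sorted(tree) for leaf in flatten_leaves(tree[key], none_is_leaf)]
    if type(tree) in (list, tuple):
        return [leaf for child in tree for leaf in flatten_leaves(child, none_is_leaf)]
    return [tree]  # anything else is treated as a leaf


def reduce_leaves(func, tree, initial=_MISSING, none_is_leaf=False):
    """Fold func over the leaves; without initial, the first leaf seeds the fold."""
    leaves = flatten_leaves(tree, none_is_leaf)
    if initial is _MISSING:
        return reduce(func, leaves)  # raises TypeError if there are no leaves
    return reduce(func, leaves, initial)


tree = {'x': 1, 'y': (2, None), 'z': 3}
print(reduce_leaves(lambda x, y: x and y, tree))                     # 3
print(reduce_leaves(lambda x, y: x and y, tree, none_is_leaf=True))  # None
```

functools.reduce raises TypeError on an empty sequence when no initial value is given, so the sketch does too. This reference does not say whether tree_reduce raises the same error for a leafless tree.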
- optree.tree_sum(tree, /, start=0, *, is_leaf=None, none_is_leaf=False, namespace='')
Sum start and the leaf values in tree in left-to-right depth-first order and return the total.
See also
tree_leaves() and tree_reduce().
>>> tree_sum({'x': 1, 'y': (2, 3)})
6
>>> tree_sum({'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
6
>>> tree_sum({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
Traceback (most recent call last):
    ...
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
>>> tree_sum({'x': 'a', 'y': ('b', None), 'z': 'c'}, start='')
'abc'
>>> tree_sum({'x': [1], 'y': ([2], [None]), 'z': [3]}, start=[], is_leaf=lambda x: isinstance(x, list))
[1, 2, None, 3]
- Parameters:
tree (pytree) – A pytree to be traversed.
start (object, optional) – An initial value to be used for the sum. (default: 0)
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
TypeVar(T)
- Returns:
The total sum of start and the leaf values in tree.
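The string and list examples above rely only on applying + repeatedly, starting from start. The built-in sum() refuses a str start value, so the hypothetical sum_leaves below folds with + explicitly. It reproduces the documented results for built-in containers, but it is an illustration, not OpTree's code.

```python
def flatten_leaves(tree, is_leaf=None):
    """Leaves in left-to-right depth-first order; None is a childless node."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]  # is_leaf stops the traversal at this subtree
    if tree is None:
        return []
    if type(tree) is dict:
        return [leaf for key in sorted(tree) for leaf in flatten_leaves(tree[key], is_leaf)]
    if type(tree) in (list, tuple):
        return [leaf for child in tree for leaf in flatten_leaves(child, is_leaf)]
    return [tree]


def sum_leaves(tree, start=0, is_leaf=None):
    """Return start + leaf_0 + leaf_1 + ... (hypothetical helper)."""
    total = start
    for leaf in flatten_leaves(tree, is_leaf):
        total = total + leaf  # builds a new object, so start itself is not mutated
    return total


print(sum_leaves({'x': 'a', 'y': ('b', None), 'z': 'c'}, start=''))  # abc
print(sum_leaves({'x': [1], 'y': ([2], [None]), 'z': [3]}, start=[],
                 is_leaf=lambda x: isinstance(x, list)))            # [1, 2, None, 3]
```

The sketch writes `total = total + leaf` rather than `total += leaf` on purpose. With a list start, `+=` would extend the caller's list in place.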
- optree.tree_max(tree, /, *, default=<MISSING>, key=None, is_leaf=None, none_is_leaf=False, namespace='')
Return the maximum leaf value in tree.
See also
tree_leaves() and tree_min().
>>> tree_max({})
Traceback (most recent call last):
    ...
ValueError: max() iterable argument is empty
>>> tree_max({}, default=0)
0
>>> tree_max({'x': 0, 'y': (2, 1)})
2
>>> tree_max({'x': 0, 'y': (2, 1)}, key=lambda x: -x)
0
>>> tree_max({'a': None})  # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
    ...
ValueError: max() iterable argument is empty
>>> tree_max({'a': None}, default=0)  # `None` is a non-leaf node with arity 0 by default
0
>>> tree_max({'a': None}, none_is_leaf=True)
None
>>> tree_max(None)  # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
    ...
ValueError: max() iterable argument is empty
>>> tree_max(None, default=0)
0
>>> tree_max(None, none_is_leaf=True)
None
- Parameters:
tree (pytree) – A pytree to be traversed.
default (object, optional) – The default value to return if tree is empty. If the tree is empty and default is not specified, raise a ValueError.
key (callable or None, optional) – A one-argument ordering function like that used for list.sort().
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
TypeVar(T)
- Returns:
The maximum leaf value in tree.
- optree.tree_min(tree, /, *, default=<MISSING>, key=None, is_leaf=None, none_is_leaf=False, namespace='')
Return the minimum leaf value in tree.
See also
tree_leaves() and tree_max().
>>> tree_min({})
Traceback (most recent call last):
    ...
ValueError: min() iterable argument is empty
>>> tree_min({}, default=0)
0
>>> tree_min({'x': 0, 'y': (2, 1)})
0
>>> tree_min({'x': 0, 'y': (2, 1)}, key=lambda x: -x)
2
>>> tree_min({'a': None})  # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
    ...
ValueError: min() iterable argument is empty
>>> tree_min({'a': None}, default=0)  # `None` is a non-leaf node with arity 0 by default
0
>>> tree_min({'a': None}, none_is_leaf=True)
None
>>> tree_min(None)  # `None` is a non-leaf node with arity 0 by default
Traceback (most recent call last):
    ...
ValueError: min() iterable argument is empty
>>> tree_min(None, default=0)
0
>>> tree_min(None, none_is_leaf=True)
None
- Parameters:
tree (pytree) – A pytree to be traversed.
default (object, optional) – The default value to return if tree is empty. If the tree is empty and default is not specified, raise a ValueError.
key (callable or None, optional) – A one-argument ordering function like that used for list.sort().
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
TypeVar(T)
- Returns:
The minimum leaf value in tree.
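The ValueError messages above use the same wording as the built-in max() and min(). That suggests these functions pass the flattened leaves, together with default and key, to the built-ins. Below is a minimal sketch under that assumption, for built-in containers only. The names max_leaf, min_leaf, and flatten_leaves are hypothetical.

```python
_MISSING = object()  # hypothetical sentinel meaning "no default given"


def flatten_leaves(tree, none_is_leaf=False):
    """Leaves in left-to-right depth-first order (dict keys sorted)."""
    if tree is None:
        return [None] if none_is_leaf else []
    if type(tree) is dict:
        return [leaf for key in sorted(tree) for leaf in flatten_leaves(tree[key], none_is_leaf)]
    if type(tree) in (list, tuple):
        return [leaf for child in tree for leaf in flatten_leaves(child, none_is_leaf)]
    return [tree]


def max_leaf(tree, default=_MISSING, key=None, none_is_leaf=False):
    leaves = flatten_leaves(tree, none_is_leaf)
    if default is _MISSING:
        return max(leaves, key=key)  # ValueError when there are no leaves
    return max(leaves, default=default, key=key)


def min_leaf(tree, default=_MISSING, key=None, none_is_leaf=False):
    leaves = flatten_leaves(tree, none_is_leaf)
    if default is _MISSING:
        return min(leaves, key=key)  # ValueError when there are no leaves
    return min(leaves, default=default, key=key)


print(max_leaf({'x': 0, 'y': (2, 1)}), min_leaf({'x': 0, 'y': (2, 1)}))  # 2 0
print(max_leaf({'a': None}, default=0))                                  # 0
```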
- optree.tree_all(tree, /, *, is_leaf=None, none_is_leaf=False, namespace='')
Test whether all leaves in tree are true (or if tree is empty).
See also
tree_leaves() and tree_any().
>>> tree_all({})
True
>>> tree_all({'x': 1, 'y': (2, 3)})
True
>>> tree_all({'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
True
>>> tree_all({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
False
>>> tree_all(None)  # `None` is a non-leaf node with arity 0 by default
True
>>> tree_all(None, none_is_leaf=True)
False
- Parameters:
tree (pytree) – A pytree to be traversed.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
bool
- Returns:
True if all leaves in tree are true, or if tree is empty. Otherwise, False.
- optree.tree_any(tree, /, *, is_leaf=None, none_is_leaf=False, namespace='')
Test whether any leaves in tree are true (or False if tree is empty).
See also
tree_leaves() and tree_all().
>>> tree_any({})
False
>>> tree_any({'x': 0, 'y': (2, 0)})
True
>>> tree_any({'a': None})  # `None` is a non-leaf node with arity 0 by default
False
>>> tree_any({'a': None}, none_is_leaf=True)  # `None` is evaluated as false
False
>>> tree_any(None)  # `None` is a non-leaf node with arity 0 by default
False
>>> tree_any(None, none_is_leaf=True)  # `None` is evaluated as false
False
- Parameters:
tree (pytree) – A pytree to be traversed.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
bool
- Returns:
True if any leaves in tree are true. Otherwise, False. If tree is empty, return False.
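tree_all() and tree_any() behave like the built-in all() and any() applied to the leaf sequence. In particular, with none_is_leaf=True a None leaf counts as false. The sketch below handles built-in containers only, and the names iter_leaves, all_leaves, and any_leaves are hypothetical. It uses a generator so that all() and any() can stop at the first leaf that decides the result. This reference does not say whether OpTree short-circuits the same way.

```python
def iter_leaves(tree, none_is_leaf=False):
    """Yield leaves left-to-right, depth-first (built-in containers only)."""
    if tree is None:
        if none_is_leaf:
            yield None  # None is a (falsy) leaf
        return  # otherwise None is a node with no children
    if type(tree) is dict:
        for key in sorted(tree):
            yield from iter_leaves(tree[key], none_is_leaf)
    elif type(tree) in (list, tuple):
        for child in tree:
            yield from iter_leaves(child, none_is_leaf)
    else:
        yield tree


def all_leaves(tree, none_is_leaf=False):
    return all(iter_leaves(tree, none_is_leaf))  # True for a leafless tree


def any_leaves(tree, none_is_leaf=False):
    return any(iter_leaves(tree, none_is_leaf))  # False for a leafless tree


tree = {'x': 1, 'y': (2, None), 'z': 3}
print(all_leaves(tree), all_leaves(tree, none_is_leaf=True))  # True False
print(all_leaves({}), any_leaves({}))                         # True False
```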
- optree.tree_flatten_one_level(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Flatten the pytree one level, returning a 4-tuple of children, metadata, path entries, and an unflatten function.
See also
tree_flatten() and tree_flatten_with_path().
>>> children, metadata, entries, unflatten_func = tree_flatten_one_level({'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5})
>>> children, metadata, entries
([1, (2, [3, 4]), None, 5], ['a', 'b', 'c', 'd'], ('a', 'b', 'c', 'd'))
>>> unflatten_func(metadata, children)
{'a': 1, 'b': (2, [3, 4]), 'c': None, 'd': 5}
>>> children, metadata, entries, unflatten_func = tree_flatten_one_level([{'a': 1, 'b': (2, 3)}, (4, 5)])
>>> children, metadata, entries
([{'a': 1, 'b': (2, 3)}, (4, 5)], None, (0, 1))
>>> unflatten_func(metadata, children)
[{'a': 1, 'b': (2, 3)}, (4, 5)]
- Parameters:
tree (pytree) – A pytree to be traversed.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
FlattenOneLevelOutputEx[TypeVar(T)]
- Returns:
A 4-tuple (children, metadata, entries, unflatten_func). The first element is a list of one-level children of the pytree node. The second element is the metadata used to reconstruct the pytree node. The third element is a tuple of path entries to the children. The fourth element is a function that can be used to unflatten the metadata and children back to the pytree node.
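The following hypothetical sketch builds the 4-tuple for the built-in container types. It assumes the dict handler sorts its keys and rebuilds the node with dict(zip(metadata, children)), which is consistent with the metadata shown in the examples above. OpTree's real handlers live in its node registry and may differ in detail.

```python
def flatten_one_level(tree):
    """Return (children, metadata, entries, unflatten_func) for one node."""
    if type(tree) is dict:
        keys = sorted(tree)  # assumption: dict children in sorted-key order
        children = [tree[key] for key in keys]
        return children, list(keys), tuple(keys), lambda meta, kids: dict(zip(meta, kids))
    if type(tree) in (list, tuple):
        node_type = type(tree)
        children = list(tree)
        # Sequences need no metadata; their entries are the child indices.
        return children, None, tuple(range(len(children))), lambda meta, kids: node_type(kids)
    raise ValueError(f'cannot flatten a leaf of type {type(tree).__name__}')


children, metadata, entries, unflatten = flatten_one_level({'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5})
print(children, metadata, entries)
# [1, (2, [3, 4]), None, 5] ['a', 'b', 'c', 'd'] ('a', 'b', 'c', 'd')
print(unflatten(metadata, children) == {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5})  # True
```

Only the top level is split. The child (2, [3, 4]) comes back intact, which is what distinguishes this function from a full tree_flatten().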
- optree.prefix_errors(prefix_tree, full_tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Return a list of errors that would be raised by broadcast_prefix().
See also
broadcast_prefix() and tree_broadcast_prefix().
- Parameters:
prefix_tree (pytree) – A pytree with the prefix structure of full_tree.
full_tree (pytree) – A pytree with the suffix structure of prefix_tree.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
list[Callable[[str], ValueError]]
- Returns:
A list of callables that take a name string and return a ValueError describing the structure mismatch. An empty list indicates that prefix_tree is a valid prefix of full_tree.
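The return contract can be sketched in plain Python: a list of callables that each turn a name into a ValueError, where an empty list means success. The helper below is hypothetical, handles only built-in dict, list, tuple, and None nodes, and uses invented error wording that is not OpTree's.

```python
from typing import Any, Callable, List, Tuple

ErrorFactory = Callable[[str], ValueError]


def _make_error(path: Tuple[Any, ...], message: str) -> ErrorFactory:
    # Bind the path and message now; the caller supplies the name later.
    location = ''.join(f'[{entry!r}]' for entry in path)
    return lambda name: ValueError(f'{name}{location}: {message}')


def prefix_errors_sketch(prefix: Any, full: Any, path: Tuple[Any, ...] = ()) -> List[ErrorFactory]:
    is_node = prefix is None or type(prefix) in (dict, list, tuple)
    if not is_node:
        return []  # a leaf in the prefix may stand for any subtree of full
    if type(prefix) is not type(full):
        return [_make_error(path, f'expected {type(prefix).__name__}, got {type(full).__name__}')]
    if prefix is None:
        return []  # None matches only None
    if type(prefix) is dict:
        if set(prefix) != set(full):
            return [_make_error(path, f'dict keys differ: {sorted(prefix)} vs {sorted(full)}')]
        pairs = [(key, prefix[key], full[key]) for key in sorted(prefix)]
    else:
        if len(prefix) != len(full):
            return [_make_error(path, f'length {len(prefix)} vs {len(full)}')]
        pairs = [(index, p, f) for index, (p, f) in enumerate(zip(prefix, full))]
    errors: List[ErrorFactory] = []
    for entry, p, f in pairs:
        errors.extend(prefix_errors_sketch(p, f, path + (entry,)))
    return errors


print(prefix_errors_sketch((1, {'a': 2}), (0, {'a': [1, 2]})))  # []
errors = prefix_errors_sketch({'x': (1, 2)}, {'x': (1, 2, 3)})
print(errors[0]('tree'))  # tree['x']: length 2 vs 3
```

Returning factories instead of exceptions lets the caller decide, at raise time, which argument name to put in the message.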
- optree.treespec_paths(treespec, /)
Return a list of paths to the leaves of a treespec.
See also
tree_flatten_with_path(), tree_paths(), and PyTreeSpec.paths().
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_paths(treespec)
[('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)]
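For nested built-in containers, each path above is the sequence of dict keys and sequence indices met on the way down to a leaf. The hypothetical leaf_paths below works on a raw Python value rather than on a PyTreeSpec. It assumes sorted dict keys and treats None as a node with no children.

```python
def leaf_paths(tree, path=()):
    """Return the path (tuple of entries) to every leaf, left-to-right depth-first."""
    if tree is None:
        return []  # None is a node with no children, so it has no leaf path
    if type(tree) is dict:
        return [p for key in sorted(tree) for p in leaf_paths(tree[key], path + (key,))]
    if type(tree) in (list, tuple):
        return [p for index, child in enumerate(tree) for p in leaf_paths(child, path + (index,))]
    return [path]  # a leaf: its path is the entries collected so far


print(leaf_paths({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)}))
# [('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)]
```

A bare leaf gets the empty path (), and None contributes no path at all. This matches how the treespec above lists ('c', 0) but nothing for the None inside 'c'.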
- optree.treespec_accessors(treespec, /)
Return a list of accessors to the leaves of a treespec.
See also
tree_flatten_with_accessor(), tree_accessors(), and PyTreeSpec.accessors().
- Return type:
list[PyTreeAccessor]
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_accessors(treespec)
[
    PyTreeAccessor(*['a'][0], ...),
    PyTreeAccessor(*['a'][1][0], ...),
    PyTreeAccessor(*['a'][1][1], ...),
    PyTreeAccessor(*['b'], ...),
    PyTreeAccessor(*['c'][0], ...)
]
>>> treespec_accessors(treespec_leaf())
[PyTreeAccessor(*, ())]
>>> treespec_accessors(treespec_none())
[]
- optree.treespec_entries(treespec, /)
Return a list of one-level entries of a treespec to its children.
See also
treespec_entry(), treespec_paths(), treespec_children(), and PyTreeSpec.entries().
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_entries(treespec)
['a', 'b', 'c']
- optree.treespec_entry(treespec, index, /)
Return the entry of a treespec at the given index.
See also
treespec_entries(), treespec_children(), and PyTreeSpec.entry().
- Return type:
- optree.treespec_children(treespec, /)
Return a list of treespecs for the children of a treespec.
See also
treespec_child(), treespec_paths(), treespec_entries(), treespec_one_level(), and PyTreeSpec.children().
- Return type:
list[PyTreeSpec]
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_children(treespec)
[PyTreeSpec((*, [*, *])), PyTreeSpec(*), PyTreeSpec((*, None))]
- optree.treespec_child(treespec, index, /)
Return the treespec of the child of a treespec at the given index.
See also
treespec_children(), treespec_entries(), and PyTreeSpec.child().
- Return type:
PyTreeSpec
- optree.treespec_one_level(treespec, /)
Return the one-level tree structure of the treespec, or None if the treespec is a leaf.
See also
treespec_children(), treespec_is_one_level(), and PyTreeSpec.one_level().
- Return type:
PyTreeSpec | None
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_one_level(treespec)
PyTreeSpec({'a': *, 'b': *, 'c': *})
- optree.treespec_transform(treespec, /, f_node=None, f_leaf=None)
Transform a treespec by applying functions to its nodes and leaves.
See also
treespec_children(), treespec_is_leaf(), and PyTreeSpec.transform().
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_transform(treespec, lambda spec: treespec_dict(zip(spec.entries(), spec.children())))
PyTreeSpec({'a': {0: *, 1: {0: *, 1: *}}, 'b': *, 'c': {0: *, 1: {}}})
>>> treespec_transform(
...     treespec,
...     lambda spec: (
...         treespec_ordereddict(zip(spec.entries(), spec.children()))
...         if spec.type is dict
...         else spec
...     ),
... )
PyTreeSpec(OrderedDict({'a': (*, [*, *]), 'b': *, 'c': (*, None)}))
>>> treespec_transform(
...     treespec,
...     lambda spec: (
...         treespec_ordereddict(tree_unflatten(spec, spec.children()))
...         if spec.type is dict
...         else spec
...     ),
... )
PyTreeSpec(OrderedDict({'b': *, 'a': (*, [*, *]), 'c': (*, None)}))
>>> treespec_transform(treespec, lambda spec: treespec_tuple(spec.children()))
PyTreeSpec(((*, (*, *)), *, (*, ())))
>>> treespec_transform(
...     treespec,
...     lambda spec: (
...         treespec_list(spec.children())
...         if spec.type is tuple
...         else spec
...     ),
... )
PyTreeSpec({'a': [*, [*, *]], 'b': *, 'c': [*, None]})
>>> treespec_transform(treespec, None, lambda spec: tree_structure((1, [2])))
PyTreeSpec({'a': ((*, [*]), [(*, [*]), (*, [*])]), 'b': (*, [*]), 'c': ((*, [*]), None)})
- Parameters:
treespec (PyTreeSpec) – A treespec to transform.
f_node (callable, optional) – A function to apply to each non-leaf node. It takes a treespec and returns a new treespec. If None, the node is left unchanged. (default: None)
f_leaf (callable, optional) – A function to apply to each leaf node. It takes a treespec and returns a new treespec. If None, the leaf is left unchanged. (default: None)
- Return type:
PyTreeSpec
- Returns:
A new treespec with the transformations applied.
- optree.treespec_is_leaf(treespec, /, *, strict=True)
Return whether the treespec is a leaf that has no children.
See also
treespec_is_strict_leaf() and PyTreeSpec.is_leaf().
This function is equivalent to treespec.is_leaf(strict=strict). If strict=True, it returns True if and only if the treespec represents a strict leaf. If strict=False, it returns True if the treespec represents a strict leaf, None, or an empty container (e.g., an empty tuple).
>>> treespec_is_leaf(tree_structure(1))
True
>>> treespec_is_leaf(tree_structure((1, 2)))
False
>>> treespec_is_leaf(tree_structure(None))
False
>>> treespec_is_leaf(tree_structure(None), strict=False)
True
>>> treespec_is_leaf(tree_structure(None, none_is_leaf=False))
False
>>> treespec_is_leaf(tree_structure(None, none_is_leaf=True))
True
>>> treespec_is_leaf(tree_structure(()))
False
>>> treespec_is_leaf(tree_structure(()), strict=False)
True
>>> treespec_is_leaf(tree_structure([]))
False
>>> treespec_is_leaf(tree_structure([]), strict=False)
True
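The strict and non-strict rules can be restated on raw Python values. The is_leaf_sketch function below is hypothetical and not part of OpTree. It handles only built-in dict, list, tuple, and None, and it reproduces the True/False pattern of the examples above.

```python
def is_leaf_sketch(tree, strict=True, none_is_leaf=False):
    """Classify a raw value using the strict / non-strict leaf rules."""
    is_none_node = tree is None and not none_is_leaf
    is_container = type(tree) in (dict, list, tuple)
    if not (is_none_node or is_container):
        return True  # a strict leaf (this includes None when none_is_leaf=True)
    if strict:
        return False  # a node is never a strict leaf
    # Non-strict: a node with no children (None or an empty container) also counts.
    return is_none_node or len(tree) == 0


print(is_leaf_sketch(None), is_leaf_sketch(None, strict=False))  # False True
print(is_leaf_sketch(()), is_leaf_sketch((), strict=False))      # False True
```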
- optree.treespec_is_strict_leaf(treespec, /)
Return whether the treespec is a strict leaf.
See also
treespec_is_leaf() and PyTreeSpec.is_leaf().
This function respects the none_is_leaf setting in the treespec. It is equivalent to treespec.is_leaf(strict=True). It returns True if and only if the treespec represents a strict leaf.
>>> treespec_is_strict_leaf(tree_structure(1))
True
>>> treespec_is_strict_leaf(tree_structure((1, 2)))
False
>>> treespec_is_strict_leaf(tree_structure(None))
False
>>> treespec_is_strict_leaf(tree_structure(None, none_is_leaf=False))
False
>>> treespec_is_strict_leaf(tree_structure(None, none_is_leaf=True))
True
>>> treespec_is_strict_leaf(tree_structure(()))
False
>>> treespec_is_strict_leaf(tree_structure([]))
False
- Parameters:
treespec (PyTreeSpec) – A treespec.
- Return type:
bool
- Returns:
True if the treespec represents a strict leaf, otherwise False.
- optree.treespec_is_one_level(treespec, /)
Return whether the treespec is a one-level tree structure.
See also
treespec_is_leaf(), treespec_one_level(), and PyTreeSpec.is_one_level().
- Return type:
bool
>>> treespec_is_one_level(tree_structure(1))
False
>>> treespec_is_one_level(tree_structure((1, 2)))
True
>>> treespec_is_one_level(tree_structure({'a': 1, 'b': 2, 'c': 3}))
True
>>> treespec_is_one_level(tree_structure({'a': 1, 'b': (2, 3), 'c': 4}))
False
>>> treespec_is_one_level(tree_structure(None))
True
- optree.treespec_is_prefix(treespec, other_treespec, /, *, strict=False)
Return whether treespec is a prefix of other_treespec.
See also
treespec_is_suffix() and PyTreeSpec.is_prefix().
- Parameters:
treespec (PyTreeSpec) – A treespec.
other_treespec (PyTreeSpec) – Another treespec to compare against.
strict (bool, optional) – If True, the treespec must be a strict prefix (not equal). (default: False)
- Return type:
bool
- Returns:
True if treespec is a prefix of other_treespec, otherwise False.
- optree.treespec_is_suffix(treespec, other_treespec, /, *, strict=False)
Return whether treespec is a suffix of other_treespec.
See also
treespec_is_prefix() and PyTreeSpec.is_suffix().
- Parameters:
treespec (PyTreeSpec) – A treespec.
other_treespec (PyTreeSpec) – Another treespec to compare against.
strict (bool, optional) – If True, the treespec must be a strict suffix (not equal). (default: False)
- Return type:
bool
- Returns:
True if treespec is a suffix of other_treespec, otherwise False.
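In prefix matching, a leaf in the prefix can stand for an entire subtree, while every node above it must agree in type and in keys or length. The sketch below works on raw built-in containers instead of PyTreeSpec objects. All helper names are hypothetical, and treating "suffix" as the reverse of "prefix" is an assumption made here.

```python
def _is_node(x):
    return x is None or type(x) in (dict, list, tuple)


def _skeleton(tree):
    """Replace every leaf by '*' so that two structures can be compared with ==."""
    if type(tree) is dict:
        return {key: _skeleton(value) for key, value in tree.items()}
    if type(tree) in (list, tuple):
        return type(tree)(_skeleton(child) for child in tree)
    return None if tree is None else '*'


def _prefix_match(prefix, full):
    if not _is_node(prefix):
        return True  # a prefix leaf may stand for any subtree
    if type(prefix) is not type(full):
        return False
    if prefix is None:
        return True  # None matches only None
    if type(prefix) is dict:
        return set(prefix) == set(full) and all(_prefix_match(prefix[k], full[k]) for k in prefix)
    return len(prefix) == len(full) and all(_prefix_match(p, f) for p, f in zip(prefix, full))


def is_prefix_sketch(prefix, full, strict=False):
    if not _prefix_match(prefix, full):
        return False
    # strict=True additionally rules out two identical structures.
    return not (strict and _skeleton(prefix) == _skeleton(full))


def is_suffix_sketch(tree, other, strict=False):
    # Assumption: "tree is a suffix of other" means "other is a prefix of tree".
    return is_prefix_sketch(other, tree, strict)


print(is_prefix_sketch((1, 2), (1, [2, 3])))          # True
print(is_prefix_sketch((1, 2), (1, 2), strict=True))  # False
print(is_suffix_sketch((1, [2, 3]), (1, 2)))          # True
```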
- optree.treespec_leaf(*, none_is_leaf=False, namespace='')
Make a treespec representing a leaf node.
See also
tree_structure(), treespec_none(), and treespec_tuple().
>>> treespec_leaf()
PyTreeSpec(*)
>>> treespec_leaf(none_is_leaf=True)
PyTreeSpec(*, NoneIsLeaf)
>>> treespec_leaf(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
False
>>> treespec_leaf() == tree_structure(1)
True
>>> treespec_leaf(none_is_leaf=True) == tree_structure(1, none_is_leaf=True)
True
>>> treespec_leaf(none_is_leaf=True) == tree_structure(None, none_is_leaf=True)
True
>>> treespec_leaf(none_is_leaf=True) == tree_structure(None, none_is_leaf=False)
False
>>> treespec_leaf(none_is_leaf=True) == treespec_none(none_is_leaf=True)
True
>>> treespec_leaf(none_is_leaf=True) == treespec_none(none_is_leaf=False)
False
>>> treespec_leaf(none_is_leaf=False) == treespec_none(none_is_leaf=True)
False
>>> treespec_leaf(none_is_leaf=False) == treespec_none(none_is_leaf=False)
False
- Parameters:
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing a leaf node.
- optree.treespec_none(*, none_is_leaf=False, namespace='')
Make a treespec representing a None node.
See also
tree_structure(), treespec_leaf(), and treespec_tuple().
>>> treespec_none()
PyTreeSpec(None)
>>> treespec_none(none_is_leaf=True)
PyTreeSpec(*, NoneIsLeaf)
>>> treespec_none(none_is_leaf=False) == treespec_none(none_is_leaf=True)
False
>>> treespec_none() == tree_structure(None)
True
>>> treespec_none() == tree_structure(1)
False
>>> treespec_none(none_is_leaf=True) == tree_structure(1, none_is_leaf=True)
True
>>> treespec_none(none_is_leaf=True) == tree_structure(None, none_is_leaf=True)
True
>>> treespec_none(none_is_leaf=True) == tree_structure(None, none_is_leaf=False)
False
>>> treespec_none(none_is_leaf=True) == treespec_leaf(none_is_leaf=True)
True
>>> treespec_none(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
False
>>> treespec_none(none_is_leaf=True) == treespec_leaf(none_is_leaf=False)
False
>>> treespec_none(none_is_leaf=False) == treespec_leaf(none_is_leaf=False)
False
- Parameters:
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing a None node.
- optree.treespec_tuple(iterable=(), /, *, none_is_leaf=False, namespace='')
Make a tuple treespec from an iterable of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> treespec_tuple([treespec_leaf(), treespec_leaf()])
PyTreeSpec((*, *))
>>> treespec_tuple([treespec_leaf(), treespec_leaf(), treespec_none()])
PyTreeSpec((*, *, None))
>>> treespec_tuple()
PyTreeSpec(())
>>> treespec_tuple([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
PyTreeSpec((*, (*, *)))
>>> treespec_tuple([treespec_leaf(), tree_structure({'a': 1, 'b': 2})])
PyTreeSpec((*, {'a': *, 'b': *}))
>>> treespec_tuple([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
iterable (iterable of PyTreeSpec, optional) – An iterable of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing a tuple node with the given children.
- optree.treespec_list(iterable=(), /, *, none_is_leaf=False, namespace='')
Make a list treespec from an iterable of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> treespec_list([treespec_leaf(), treespec_leaf()])
PyTreeSpec([*, *])
>>> treespec_list([treespec_leaf(), treespec_leaf(), treespec_none()])
PyTreeSpec([*, *, None])
>>> treespec_list()
PyTreeSpec([])
>>> treespec_list([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
PyTreeSpec([*, (*, *)])
>>> treespec_list([treespec_leaf(), tree_structure({'a': 1, 'b': 2})])
PyTreeSpec([*, {'a': *, 'b': *}])
>>> treespec_list([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
iterable (iterable of PyTreeSpec, optional) – An iterable of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing a list node with the given children.
- optree.treespec_dict(mapping=(), /, *, none_is_leaf=False, namespace='', **kwargs)
Make a dict treespec from a dict of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> treespec_dict({'a': treespec_leaf(), 'b': treespec_leaf()})
PyTreeSpec({'a': *, 'b': *})
>>> treespec_dict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
PyTreeSpec({'a': None, 'b': *, 'c': *})
>>> treespec_dict()
PyTreeSpec({})
>>> treespec_dict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
PyTreeSpec({'a': *, 'b': (*, *)})
>>> treespec_dict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
PyTreeSpec({'a': *, 'b': [*, *]})
>>> treespec_dict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
mapping (mapping of PyTreeSpec, optional) – A mapping of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
**kwargs (PyTreeSpec, optional) – Additional child treespecs to add to the mapping.
- Return type:
PyTreeSpec
- Returns:
A treespec representing a dict node with the given children.
- optree.treespec_namedtuple(namedtuple, /, *, none_is_leaf=False, namespace='')
Make a namedtuple treespec from a namedtuple of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> from collections import namedtuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> treespec_namedtuple(Point(x=treespec_leaf(), y=treespec_leaf()))
PyTreeSpec(Point(x=*, y=*))
>>> treespec_namedtuple(Point(x=treespec_leaf(), y=treespec_tuple([treespec_leaf(), treespec_leaf()])))
PyTreeSpec(Point(x=*, y=(*, *)))
>>> treespec_namedtuple(Point(x=treespec_leaf(), y=tree_structure([1, 2])))
PyTreeSpec(Point(x=*, y=[*, *]))
>>> treespec_namedtuple(Point(x=treespec_leaf(), y=tree_structure([1, 2], none_is_leaf=True)))
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
namedtuple (namedtuple of PyTreeSpec) – A namedtuple of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Returns:
A treespec representing a namedtuple node with the given children.
- optree.treespec_ordereddict(mapping=(), /, *, none_is_leaf=False, namespace='', **kwargs)
Make an OrderedDict treespec from an OrderedDict of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': treespec_leaf()})
PyTreeSpec(OrderedDict({'a': *, 'b': *}))
>>> treespec_ordereddict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
PyTreeSpec(OrderedDict({'b': *, 'c': *, 'a': None}))
>>> treespec_ordereddict()
PyTreeSpec(OrderedDict())
>>> treespec_ordereddict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
PyTreeSpec(OrderedDict({'a': *, 'b': (*, *)}))
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
PyTreeSpec(OrderedDict({'a': *, 'b': [*, *]}))
>>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
mapping (mapping of PyTreeSpec, optional) – A mapping of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
**kwargs (PyTreeSpec, optional) – Additional child treespecs to add to the mapping.
- Return type:
PyTreeSpec
- Returns:
A treespec representing an OrderedDict node with the given children.
- optree.treespec_defaultdict(default_factory=None, mapping=(), /, *, none_is_leaf=False, namespace='', **kwargs)[source]
Make a defaultdict treespec from a defaultdict of child treespecs.
See also
tree_structure(),treespec_leaf(), andtreespec_none().>>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': treespec_leaf()}) PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': *})) >>> treespec_defaultdict(int, [('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())]) PyTreeSpec(defaultdict(<class 'int'>, {'a': None, 'b': *, 'c': *})) >>> treespec_defaultdict() PyTreeSpec(defaultdict(None, {})) >>> treespec_defaultdict(int) PyTreeSpec(defaultdict(<class 'int'>, {})) >>> treespec_defaultdict(int, a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()])) PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': (*, *)})) >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': tree_structure([1, 2])}) PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': [*, *]})) >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)}) Traceback (most recent call last): ... ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
default_factory (callable or None, optional) – A factory function that will be used to create a missing value. (default:
None)mapping (mapping of PyTreeSpec, optional) – A mapping of child treespecs. They must have the same
none_is_leafandnamespacevalues.none_is_leaf (bool, optional) – Whether to treat
Noneas a leaf. IfFalse,Noneis a non-leaf node with arity 0. ThusNoneis contained in the treespec rather than in the leaves list andNonewill remain in the result pytree. (default:False)namespace (str, optional) – The registry namespace used for custom pytree node types. (default:
'', i.e., the global namespace)**kwargs (PyTreeSpec, optional) – Additional child treespecs to add to the mapping.
- Return type:
PyTreeSpec
- Returns:
A treespec representing a defaultdict node with the given children.
- optree.treespec_deque(iterable=(), /, maxlen=None, *, none_is_leaf=False, namespace='')[source]
Make a deque treespec from a deque of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> treespec_deque([treespec_leaf(), treespec_leaf()])
PyTreeSpec(deque([*, *]))
>>> treespec_deque([treespec_leaf(), treespec_leaf(), treespec_none()], maxlen=5)
PyTreeSpec(deque([*, *, None], maxlen=5))
>>> treespec_deque()
PyTreeSpec(deque())
>>> treespec_deque([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
PyTreeSpec(deque([*, (*, *)]))
>>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5)
PyTreeSpec(deque([*, {'a': *, 'b': *}], maxlen=5))
>>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)], maxlen=5)
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
iterable (iterable of PyTreeSpec, optional) – An iterable of child treespecs. They must have the same none_is_leaf and namespace values.
maxlen (int or None, optional) – The maximum size of the deque, or None if unbounded. (default: None)
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing a deque node with the given children.
- optree.treespec_structseq(structseq, /, *, none_is_leaf=False, namespace='')[source]
Make a PyStructSequence treespec from a PyStructSequence of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
- Parameters:
structseq (PyStructSequence of PyTreeSpec) – A PyStructSequence of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing a PyStructSequence node with the given children.
- optree.treespec_from_collection(collection, /, *, none_is_leaf=False, namespace='')[source]
Make a treespec from a collection of child treespecs.
See also
tree_structure(), treespec_leaf(), and treespec_none().
>>> treespec_from_collection(None)
PyTreeSpec(None)
>>> treespec_from_collection(None, none_is_leaf=True)
PyTreeSpec(*, NoneIsLeaf)
>>> treespec_from_collection(object())
PyTreeSpec(*)
>>> treespec_from_collection([treespec_leaf(), treespec_none()])
PyTreeSpec([*, None])
>>> treespec_from_collection({'a': treespec_leaf(), 'b': treespec_none()})
PyTreeSpec({'a': *, 'b': None})
>>> treespec_from_collection(deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5))
PyTreeSpec(deque([*, {'a': *, 'b': *}], maxlen=5))
>>> treespec_from_collection({'a': treespec_leaf(), 'b': (treespec_leaf(), treespec_none())})
Traceback (most recent call last):
    ...
ValueError: Expected a(n) dict of PyTreeSpec(s), got {'a': PyTreeSpec(*), 'b': (PyTreeSpec(*), PyTreeSpec(None))}.
>>> treespec_from_collection([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
Traceback (most recent call last):
    ...
ValueError: Expected treespec(s) with `none_is_leaf=False`.
- Parameters:
collection (collection of PyTreeSpec) – A collection of child treespecs. They must have the same none_is_leaf and namespace values.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
PyTreeSpec
- Returns:
A treespec representing the same structure of the collection with the given children.
- class optree.PyTreeEntry(entry, type, kind)
Bases: object
Base class for path entries.
- entry: Any
- type: builtins.type
- kind: PyTreeKind
- __add__(other, /)[source]
Join the path entry with another path entry or accessor.
- Return type:
PyTreeAccessor
- __dataclass_fields__ = {'entry': Field(name='entry',type='Any',default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'kind': Field(name='kind',type='PyTreeKind',default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD), 'type': Field(name='type',type='builtins.type',default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=False,eq=False,order=False,unsafe_hash=False,frozen=True,match_args=True,kw_only=False,slots=True,weakref_slot=False)
- __delattr__(name)
Implement delattr(self, name).
- __init__(entry, type, kind)
- __match_args__ = ('entry', 'type', 'kind')
- __replace__(**changes)
- __setattr__(name, value)
Implement setattr(self, name, value).
- class optree.GetAttrEntry(entry, type, kind)
Bases: PyTreeEntry
A generic path entry class for nodes that access their children by __getattr__().
- property name: str
Get the attribute name.
- class optree.GetItemEntry(entry, type, kind)
Bases: PyTreeEntry
A generic path entry class for nodes that access their children by __getitem__().
- class optree.FlattenedEntry(entry, type, kind)
Bases: PyTreeEntry
A fallback path entry class for flattened objects.
- class optree.AutoEntry(entry, type, kind)
Bases: PyTreeEntry
A generic path entry class that automatically determines the entry type on creation.
- static __new__(cls, /, entry, type, kind)[source]
Create a new path entry.
- Return type:
PyTreeEntry
- class optree.SequenceEntry(entry, type, kind)
Bases: GetItemEntry, Generic[_T_co]
A path entry class for sequences.
- property index: int
Get the index.
- class optree.MappingEntry(entry, type, kind)
Bases: GetItemEntry, Generic[_KT_co, _VT_co]
A path entry class for mappings.
- property key: _KT_co
Get the key.
- class optree.NamedTupleEntry(entry, type, kind)
Bases: SequenceEntry[_T]
A path entry class for namedtuple objects.
- property field: str
Get the field name.
- class optree.StructSequenceEntry(entry, type, kind)
Bases: SequenceEntry[_T]
A path entry class for PyStructSequence objects.
- property field: str
Get the field name.
- class optree.DataclassEntry(entry, type, kind)
Bases: GetAttrEntry
A path entry class for dataclasses.
- property field: str
Get the field name.
- property name: str
Get the attribute name.
- class optree.PyTreeAccessor(path: Iterable[PyTreeEntry] = ())
Bases: tuple[PyTreeEntry, …]
A path class for PyTrees.
- __add__(other, /)[source]
Join the accessor with another path entry or accessor.
- Return type:
Self
- __getitem__(index, /)[source]
Get the child path entry or an accessor for a subpath.
- Return type:
PyTreeEntry|Self
- __mul__(value, /)[source]
Repeat the accessor.
- Return type:
Self
- static __new__(cls, /, path=())[source]
Create a new accessor instance.
- Return type:
Self
- __rmul__(value, /)[source]
Repeat the accessor.
- Return type:
Self
- optree.register_pytree_node(cls, /, flatten_func, unflatten_func, *, path_entry_type=<class 'optree.AutoEntry'>, namespace)[source]
Extend the set of types that are considered internal nodes in pytrees.
See also
register_pytree_node_class() and unregister_pytree_node().
The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces can also be used to register the same class with different behaviors in different namespaces for different use cases.
Warning
For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.
- Parameters:
cls (type) – A Python type to treat as an internal pytree node.
flatten_func (callable) – A function to be used during flattening, taking an instance of cls and returning a triple or optionally a pair, with (1) an iterable for the children to be flattened recursively, (2) some hashable metadata to be stored in the treespec and to be passed to the unflatten_func, and (3) (optional) an iterable for the tree path entries to the corresponding children. If the entries are not provided or given by None, then range(len(children)) will be used.
unflatten_func (callable) – A function taking two arguments: the metadata that was returned by flatten_func and stored in the treespec, and the unflattened children. The function should return an instance of cls.
path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)
namespace (str) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
- Return type:
type[Collection[TypeVar(T)]]
- Returns:
The same type as the input cls.
- Raises:
TypeError – If the input type is not a class.
TypeError – If the path entry class is not a subclass of PyTreeEntry.
TypeError – If the namespace is not a string.
ValueError – If the namespace is an empty string.
ValueError – If the type is already registered in the registry.
Added in version 0.12.0: The path_entry_type argument to specify the path entry type used in PyTreeSpec.accessors() and tree_flatten_with_accessor(). If not provided, AutoEntry will be used.
Examples
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='set',
... )
<class 'set'>
>>> # Register a custom type into a namespace with accessor support
>>> import types
>>> # This can be whatever your container type is.
>>> class MyContainer(types.SimpleNamespace):
...     pass
>>> # (Optional) Define a custom path entry type for accessor support.
>>> # Here we showcase how to define one. In practice, you can use the built-in ``GetAttrEntry``.
>>> class MyContainerEntry(PyTreeEntry):
...     def __call__(self, obj):
...         return getattr(obj, self.entry)
...     def codify(self, node=''):
...         return f'{node}.{self.entry}'
>>> register_pytree_node(
...     MyContainer,
...     flatten_func=lambda ct: (
...         list(vars(ct).values()),
...         list(vars(ct).keys()),
...         list(vars(ct).keys()),
...     ),
...     unflatten_func=lambda keys, values: MyContainer(**dict(zip(keys, values))),
...     path_entry_type=MyContainerEntry,
...     namespace='mycontainer',
... )
<class '...MyContainer'>
>>> tree = {'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree)  # `MyContainer`s are leaf nodes
([MyContainer(lr=0.01, momentum=0.9), 1000], PyTreeSpec({'config': *, 'steps': *}))
>>> # Flatten with the namespace
>>> leaves, treespec = tree_flatten(tree, namespace='mycontainer')
>>> leaves, treespec
([0.01, 0.9, 1000], PyTreeSpec({'config': CustomTreeNode(MyContainer[['lr', 'momentum']], [*, *]), 'steps': *}, namespace='mycontainer'))
>>> # Custom ``entries`` are defined as attribute names
>>> tree_paths(tree, namespace='mycontainer')
[('config', 'lr'), ('config', 'momentum'), ('steps',)]
>>> # Custom path entry type defines the pytree access behavior
>>> accessors = tree_accessors(tree, namespace='mycontainer')
>>> accessors[0].codify()
"*['config'].lr"
>>> accessors[0](tree)
0.01
>>> # Unflatten back to a copy of the original object
>>> tree_unflatten(treespec, leaves)
{'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}
- optree.register_pytree_node_class(cls=None, /, *, path_entry_type=None, namespace=None)[source]
Extend the set of types that are considered internal nodes in pytrees.
See also
register_pytree_node() and unregister_pytree_node().
The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces can also be used to register the same class with different behaviors in different namespaces for different use cases.
Warning
For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.
- Parameters:
cls (type, optional) – A Python type to treat as an internal pytree node.
path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)
namespace (str, optional) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
- Return type:
TypeVar(CustomTreeNodeType, bound=type[CustomTreeNode]) | Callable[[TypeVar(CustomTreeNodeType, bound=type[CustomTreeNode])], TypeVar(CustomTreeNodeType, bound=type[CustomTreeNode])]
- Returns:
The same type as the input cls if the argument is given. Otherwise, a decorator function that registers the class as a pytree node.
- Raises:
TypeError – If the path entry class is not a subclass of PyTreeEntry.
TypeError – If the namespace is not a string.
TypeError – If the class does not define the required method pairs.
ValueError – If the namespace is an empty string.
ValueError – If the type is already registered in the registry.
Added in version 0.12.0: The TREE_PATH_ENTRY_TYPE class variable to specify the path entry type used in PyTreeSpec.accessors() and tree_flatten_with_accessor(). If not provided, AutoEntry will be used.
Added in version 0.18.0: Previously, this function looked for methods named tree_flatten and tree_unflatten for the given class. Since version 0.18.0, it prefers methods named __tree_flatten__ and __tree_unflatten__ instead. The old method names are still supported for backward compatibility, but it is recommended to use the new method names. The method resolution follows this priority:
1. If both __tree_flatten__ and __tree_unflatten__ are defined, use them directly.
2. If both tree_flatten and tree_unflatten are defined, wrap them as dunder methods.
3. If neither complete pair is available, raise a TypeError suggesting the new method names.
This function is a thin wrapper around register_pytree_node(), and provides a class-oriented interface:
from collections import UserList

from optree import GetAttrEntry, SequenceEntry, register_pytree_node_class

@register_pytree_node_class(namespace='foo')
class Special:
    TREE_PATH_ENTRY_TYPE = GetAttrEntry

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __tree_flatten__(self):
        return ((self.x, self.y), None, ('x', 'y'))

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        return cls(*children)

@register_pytree_node_class('mylist')
class MyList(UserList):
    TREE_PATH_ENTRY_TYPE = SequenceEntry

    def __tree_flatten__(self):
        return self.data, None, None

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        # UserList takes a single iterable, so pass the children as one argument
        return cls(children)

# Legacy style (still supported but not recommended)
@register_pytree_node_class(namespace='legacy')
class LegacyStyleMyList(UserList):
    def tree_flatten(self):
        # Implementation automatically wrapped as __tree_flatten__
        return self.data, None, None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # Implementation automatically wrapped as __tree_unflatten__
        return cls(children)
- optree.unregister_pytree_node(cls, /, *, namespace)[source]
Remove a type from the pytree node registry.
See also
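register_pytree_node() and register_pytree_node_class().
This function is the inverse operation of function register_pytree_node().
- Parameters:
cls (type) – A Python type to remove from the pytree node registry.
namespace (str) – The namespace of the pytree node registry from which to remove the type.
- Return type:
PyTreeNodeRegistryEntry
- Returns: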
The removed registry entry.
- Raises:
TypeError – If the input type is not a class.
TypeError – If the namespace is not a string.
ValueError – If the namespace is an empty string.
ValueError – If the type is a built-in type that cannot be unregistered.
ValueError – If the type is not found in the registry.
Examples
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='temp',
... )
<class 'set'>
>>> # Unregister the Python type
>>> unregister_pytree_node(set, namespace='temp')
- optree.dict_insertion_ordered(mode, /, *, namespace)[source]
Context manager to temporarily set the dictionary sorting mode.
This context manager is used to temporarily set the dictionary sorting mode for a specific namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary should be sorted or keep the insertion order when flattening a pytree.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten(tree)
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> with dict_insertion_ordered(True, namespace='some-namespace'):
...     tree_flatten(tree, namespace='some-namespace')
(
    [2, 3, 4, 1, 5],
    PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
)
Warning
The dictionary sorting mode is a global setting and is not thread-safe. It is recommended to use this context manager in a single-threaded environment.
- class optree.PyTreeSpec
Bases: pybind11_object
Represents the structure of a pytree.
- __eq__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) bool
Test for equality to another object.
- __ge__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) bool
Test whether this treespec is a suffix of another treespec.
- __gt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) bool
Test whether this treespec is a strict suffix of another treespec.
- __hash__(self: optree.PyTreeSpec, /) int
Return the hash of the treespec.
- __init__(*args, **kwargs)
- __le__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) bool
Test whether this treespec is a prefix of another treespec.
- __len__(self: optree.PyTreeSpec, /) int
Number of leaves in the tree.
- __lt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) bool
Test whether this treespec is a strict prefix of another treespec.
- __ne__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) bool
Test for inequality to another object.
- accessors(self: optree.PyTreeSpec, /) list[object]
Return a list of accessors to the leaves in the treespec.
- broadcast_to_common_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) optree.PyTreeSpec
Broadcast to the common suffix of this treespec and other treespec.
- child(self: optree.PyTreeSpec, index: SupportsInt | SupportsIndex, /) optree.PyTreeSpec
Return the treespec for the child at the given index.
- children(self: optree.PyTreeSpec, /) list[optree.PyTreeSpec]
Return a list of treespecs for the children.
- compose(self: optree.PyTreeSpec, inner: optree.PyTreeSpec, /) optree.PyTreeSpec
Compose two treespecs. Constructs the inner treespec as a subtree at each leaf node.
- entries(self: optree.PyTreeSpec, /) list
Return a list of one-level entries to the children.
- entry(self: optree.PyTreeSpec, index: SupportsInt | SupportsIndex, /) object
Return the entry at the given index.
- flatten_up_to(self: optree.PyTreeSpec, tree: object, /) list
Flatten the subtrees in tree up to the structure of this treespec and return a list of subtrees.
- is_leaf(self: optree.PyTreeSpec, /, *, strict: bool = True) bool
Test whether the treespec represents a leaf.
- is_one_level(self: optree.PyTreeSpec, /) bool
Test whether the treespec represents a one-level tree.
- is_prefix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /, *, strict: bool = False) bool
Test whether this treespec is a prefix of the given treespec.
- is_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /, *, strict: bool = False) bool
Test whether this treespec is a suffix of the given treespec.
- property kind
The kind of the root node.
- property namespace
The registry namespace used to resolve the custom pytree node types.
- property none_is_leaf
Whether to treat None as a leaf. If false, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list.
- property num_children
Number of children of the root node. Note that a leaf is also a node but has no children.
- property num_leaves
Number of leaves in the tree.
- property num_nodes
Number of nodes in the tree. Note that a leaf is also a node but has no children.
- one_level(self: optree.PyTreeSpec, /) optree.PyTreeSpec | None
Return the one-level structure of the root node. Return None if the root node represents a leaf.
- paths(self: optree.PyTreeSpec, /) list[tuple]
Return a list of paths to the leaves of the treespec.
- transform(self: optree.PyTreeSpec, /, f_node: collections.abc.Callable | None = None, f_leaf: collections.abc.Callable | None = None) optree.PyTreeSpec
Transform the pytree structure by applying f_node(nodespec) at nodes and f_leaf(leafspec) at leaves.
- traverse(self: optree.PyTreeSpec, leaves: collections.abc.Iterable, /, f_node: collections.abc.Callable | None = None, f_leaf: collections.abc.Callable | None = None) object
Walk over the pytree structure, calling f_leaf(leaf) at leaves, and f_node(node) at reconstructed non-leaf nodes.
- property type
The type of the root node. Return None if the root node is a leaf.
- unflatten(self: optree.PyTreeSpec, leaves: collections.abc.Iterable, /) object
Reconstruct a pytree from the leaves.
- walk(self: optree.PyTreeSpec, leaves: collections.abc.Iterable, /, f_node: collections.abc.Callable | None = None, f_leaf: collections.abc.Callable | None = None) object
Walk over the pytree structure, calling f_leaf(leaf) at leaves, and f_node(node_type, node_data, children) at non-leaf nodes.
- optree.PyTreeDef
alias of PyTreeSpec
- class optree.PyTreeKind(*values)
Bases: IntEnum
The kind of a pytree node.
- __format__(format_spec, /)
Convert to a string according to format_spec.
- __new__(value)
- CUSTOM = 0
- LEAF = 1
- NONE = 2
- TUPLE = 3
- LIST = 4
- DICT = 5
- NAMEDTUPLE = 6
- ORDEREDDICT = 7
- DEFAULTDICT = 8
- DEQUE = 9
- STRUCTSEQUENCE = 10
- NUM_KINDS = 11
- class optree.PyTree[source]
Bases: Generic[T]
Generic PyTree type.
>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor, tuple[ForwardRef('PyTree[torch.Tensor]'), ...], list[ForwardRef('PyTree[torch.Tensor]')], dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')], collections.deque[ForwardRef('PyTree[torch.Tensor]')], optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
- classmethod __class_getitem__(cls, item)[source]
Instantiate a PyTree type with the given type.
- Return type:
- classmethod __init_subclass__(*args, **kwargs)[source]
Prohibit subclassing.
- Return type:
Never
- __iter__()[source]
Emulate collection-like behavior.
- static __new__(cls, /)[source]
Prohibit instantiation.
- Return type:
Never
- get(key, default=None, /)[source]
Emulate mapping-like behavior.
- values()[source]
Emulate mapping-like behavior.
- Return type:
ValuesView[PyTree[TypeVar(T)]]
- class optree.PyTreeTypeVar(name: str, param: type | TypeAliasType)[source]
Bases: object
Type variable for PyTree.
>>> import torch
>>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
>>> TensorTree
typing.Union[torch.Tensor, tuple[ForwardRef('TensorTree'), ...], list[ForwardRef('TensorTree')], dict[typing.Any, ForwardRef('TensorTree')], collections.deque[ForwardRef('TensorTree')], optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
- classmethod __init_subclass__(*args, **kwargs)[source]
Prohibit subclassing.
- Return type:
Never
- static __new__(cls, /, name, param)[source]
Instantiate a PyTree type variable with the given name and parameter.
- Return type:
- class optree.CustomTreeNode(*args, **kwargs)[source]
Bases: Protocol[T]
The abstract base class for custom pytree nodes.
- __init__(*args, **kwargs)
- __tree_flatten__()[source]
Flatten the custom pytree node into children and metadata.
- classmethod __tree_unflatten__(metadata, children, /)[source]
Unflatten the children and metadata into the custom pytree node.
- Return type:
Self
- class optree.FlattenFunc(*args, **kwargs)[source]
Bases: Protocol[T]
The type stub class for flatten functions.
- abstractmethod __call__(container, /)[source]
Flatten the container into children and metadata.
- __init__(*args, **kwargs)
- class optree.UnflattenFunc(*args, **kwargs)[source]
Bases: Protocol[T]
The type stub class for unflatten functions.
- abstractmethod __call__(metadata, children, /)[source]
Unflatten the children and metadata back into the container.
- Return type:
Collection[TypeVar(T)]
- __init__(*args, **kwargs)
- optree.is_namedtuple(obj, /)[source]
Return whether the object is an instance of namedtuple or a subclass of namedtuple.
- Return type:
bool
- optree.is_namedtuple_class(cls, /)[source]
Return whether the class is a subclass of namedtuple.
- Return type:
bool
- optree.is_namedtuple_instance(obj, /)[source]
Return whether the object is an instance of namedtuple.
- Return type:
bool
- optree.namedtuple_fields(obj, /)[source]
Return the field names of a namedtuple.
- optree.is_structseq(obj, /)[source]
Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.
- Return type:
bool
- optree.is_structseq_class(cls, /)[source]
Return whether the class is a class of PyStructSequence.
- Return type:
bool
- optree.is_structseq_instance(obj, /)[source]
Return whether the object is an instance of PyStructSequence.
- Return type:
bool