Typing Support
Name | Description
PyTreeSpec | Represents the structure of a pytree.
PyTreeDef | Alias of PyTreeSpec.
PyTreeKind | The kind of a pytree node.
PyTree | Generic PyTree type.
PyTreeTypeVar | Type variable for PyTree.
CustomTreeNode | The abstract base class for custom pytree nodes.
is_namedtuple | Return whether the object is an instance of namedtuple or a subclass of namedtuple.
is_namedtuple_instance | Return whether the object is an instance of namedtuple.
is_namedtuple_class | Return whether the class is a subclass of namedtuple.
 | Return the field names of a namedtuple.
is_structseq | Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.
is_structseq_instance | Return whether the object is an instance of PyStructSequence.
 | Return whether the object is a class of PyStructSequence.
 | Return the field names of a PyStructSequence.
- class optree.PyTreeSpec
Bases: pybind11_object
Represents the structure of a pytree.
- __annotations__ = {}
- __delattr__(name, /)
Implement delattr(self, name).
- __eq__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool
Test for equality to another object.
- __ge__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool
Test whether this treespec is a suffix of another treespec.
- __getattribute__(name, /)
Return getattr(self, name).
- __getstate__(self: optree.PyTreeSpec) -> object
- __gt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool
Test whether this treespec is a strict suffix of another treespec.
- __hash__(self: optree.PyTreeSpec) -> int
Return the hash of the treespec.
- __init__(*args, **kwargs)
- __le__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool
Test whether this treespec is a prefix of another treespec.
- __len__(self: optree.PyTreeSpec) -> int
Number of leaves in the tree.
- __lt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool
Test whether this treespec is a strict prefix of another treespec.
- __ne__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool
Test for inequality to another object.
- __new__(**kwargs)
- __setattr__(name, value, /)
Implement setattr(self, name, value).
- __setstate__(self: optree.PyTreeSpec, state: object) -> None
Serialization support for PyTreeSpec.
- accessors(self: optree.PyTreeSpec) -> list[object]
Return a list of accessors to the leaves in the treespec.
- broadcast_to_common_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> optree.PyTreeSpec
Broadcast to the common suffix of this treespec and the other treespec.
- child(self: optree.PyTreeSpec, index: int) -> optree.PyTreeSpec
Return the treespec for the child at the given index.
- children(self: optree.PyTreeSpec) -> list[optree.PyTreeSpec]
Return a list of treespecs for the children.
- compose(self: optree.PyTreeSpec, inner_treespec: optree.PyTreeSpec) -> optree.PyTreeSpec
Compose two treespecs by constructing the inner treespec as a subtree at each leaf node of this treespec.
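To make the effect of compose concrete, here is a minimal stdlib-only toy model, not optree's implementation. It assumes a "spec" is a nested tuple, list, or dict whose leaf positions hold a hypothetical LEAF sentinel; the helpers num_leaves and compose are illustrative names. Composing replaces every leaf of the outer spec with the inner spec, so the leaf counts multiply.

```python
from typing import Any

# Hypothetical sentinel marking a leaf position in a toy spec (not an optree object).
LEAF = object()


def num_leaves(spec: Any) -> int:
    """Count the leaf positions in a toy spec built from tuple, list, and dict."""
    if spec is LEAF:
        return 1
    if type(spec) is dict:
        return sum(num_leaves(child) for child in spec.values())
    return sum(num_leaves(child) for child in spec)


def compose(outer: Any, inner: Any) -> Any:
    """Replace every leaf position of `outer` with `inner`.

    Nodes of `inner` are shared between positions, which is harmless here
    because toy specs are never mutated.
    """
    if outer is LEAF:
        return inner
    if type(outer) is dict:
        return {key: compose(child, inner) for key, child in outer.items()}
    return type(outer)(compose(child, inner) for child in outer)


outer = (LEAF, [LEAF, LEAF])
inner = {"a": LEAF, "b": LEAF}
composed = compose(outer, inner)
print(num_leaves(outer), num_leaves(inner), num_leaves(composed))  # 3 2 6
```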
- entries(self: optree.PyTreeSpec) -> list
Return a list of one-level entries to the children.
- entry(self: optree.PyTreeSpec, index: int) -> object
Return the entry at the given index.
- flatten_up_to(self: optree.PyTreeSpec, full_tree: object) -> list
Flatten the subtrees in full_tree up to the structure of this treespec and return a list of subtrees.
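The idea behind flatten_up_to can be sketched with the same kind of toy model: walk the spec and the full tree in lockstep, and wherever the spec has a leaf, emit the entire subtree found at that position. The LEAF sentinel and the function below are hypothetical stand-ins, not optree code; dict children are visited in the spec's insertion order, which may differ from optree's own ordering.

```python
from typing import Any, List

# Hypothetical sentinel marking a leaf position in a toy spec (not an optree object).
LEAF = object()


def flatten_up_to(spec: Any, full_tree: Any) -> List[Any]:
    """Collect the subtrees of `full_tree` found at the leaf positions of `spec`."""
    if spec is LEAF:
        # Whatever sits here, however deep, is returned as a single subtree.
        return [full_tree]
    if type(spec) is not type(full_tree):
        raise ValueError(
            f"expected {type(spec).__name__}, got {type(full_tree).__name__}"
        )
    if type(spec) is dict:
        if spec.keys() != full_tree.keys():
            raise ValueError("dictionary keys do not match")
        pairs = [(spec[key], full_tree[key]) for key in spec]
    else:
        if len(spec) != len(full_tree):
            raise ValueError(f"expected {len(spec)} children, got {len(full_tree)}")
        pairs = list(zip(spec, full_tree))
    subtrees: List[Any] = []
    for sub_spec, sub_tree in pairs:
        subtrees.extend(flatten_up_to(sub_spec, sub_tree))
    return subtrees


spec = (LEAF, {"x": LEAF})
print(flatten_up_to(spec, ([1, 2], {"x": (3, 4)})))  # [[1, 2], (3, 4)]
```

A structural mismatch (wrong container type, keys, or arity) raises ValueError instead of silently truncating.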
- is_leaf(self: optree.PyTreeSpec, strict: bool = True) -> bool
Test whether the current node is a leaf.
- is_prefix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, strict: bool = False) -> bool
Test whether this treespec is a prefix of the given treespec.
- is_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, strict: bool = False) -> bool
Test whether this treespec is a suffix of the given treespec.
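The prefix and suffix relations (and the matching comparison operators listed above) can be illustrated with a toy model: spec A is a prefix of spec B when B can be obtained from A by expanding some of A's leaves into subtrees, and A is a suffix of B exactly when B is a prefix of A. This is a hypothetical stdlib sketch using a LEAF sentinel, not optree's implementation; here `strict` simply additionally requires the two specs to differ.

```python
from typing import Any

# Hypothetical sentinel marking a leaf position in a toy spec (not an optree object).
LEAF = object()


def _covers(a: Any, b: Any) -> bool:
    """True if `b` can be obtained from `a` by expanding some of a's leaves."""
    if a is LEAF:
        return True  # a leaf of `a` may stand for any subtree of `b`
    if type(a) is not type(b):
        return False
    if type(a) is dict:
        return a.keys() == b.keys() and all(_covers(a[k], b[k]) for k in a)
    return len(a) == len(b) and all(_covers(x, y) for x, y in zip(a, b))


def is_prefix(a: Any, b: Any, strict: bool = False) -> bool:
    """Toy prefix test; with strict=True the specs must also differ."""
    return _covers(a, b) and not (strict and a == b)


def is_suffix(a: Any, b: Any, strict: bool = False) -> bool:
    """`a` is a suffix of `b` exactly when `b` is a prefix of `a`."""
    return is_prefix(b, a, strict)


shallow = (LEAF, LEAF)
deep = ([LEAF, LEAF], LEAF)
print(is_prefix(shallow, deep), is_prefix(deep, shallow))  # True False
print(is_prefix(shallow, shallow), is_prefix(shallow, shallow, strict=True))  # True False
```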
- property kind
The kind of the current node.
- property namespace
The registry namespace used to resolve the custom pytree node types.
- property none_is_leaf
Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so it is recorded in the treespec rather than in the list of leaves.
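The effect of none_is_leaf on the leaf list can be seen with a toy flattener that handles only None, tuple, list, and dict. This is a hypothetical stdlib sketch, not optree code (optree's flattening functions take a none_is_leaf keyword for the same purpose): with the default False, None contributes no leaves, while True turns each None into a leaf.

```python
from typing import Any, List


def toy_leaves(tree: Any, none_is_leaf: bool = False) -> List[Any]:
    """Hypothetical toy flattening over None, tuple, list, and dict only."""
    if tree is None:
        # With none_is_leaf=False, None is an arity-0 node and yields no leaves.
        return [None] if none_is_leaf else []
    if type(tree) in (tuple, list):
        children = list(tree)
    elif type(tree) is dict:
        children = list(tree.values())
    else:
        return [tree]  # anything else is treated as an opaque leaf
    leaves: List[Any] = []
    for child in children:
        leaves.extend(toy_leaves(child, none_is_leaf))
    return leaves


tree = (1, None, [2, None])
print(toy_leaves(tree))  # [1, 2]
print(toy_leaves(tree, none_is_leaf=True))  # [1, None, 2, None]
```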
- property num_children
Number of children in the current node. Note that a leaf is also a node but has no children.
- property num_leaves
Number of leaves in the tree.
- property num_nodes
Number of nodes in the tree. Note that a leaf is also a node but has no children.
- paths(self: optree.PyTreeSpec) -> list[tuple]
Return a list of paths to the leaves of the treespec.
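As a rough illustration of what a list of leaf paths looks like, here is a hypothetical toy that records dict keys and sequence indices on the way down to each leaf. It is a stdlib sketch, not optree's algorithm, and it treats None as a leafless node (the none_is_leaf=False behaviour described above). The number of paths always equals the number of leaves.

```python
from typing import Any, List, Tuple


def toy_paths(tree: Any, prefix: Tuple[Any, ...] = ()) -> List[Tuple[Any, ...]]:
    """Hypothetical toy: one tuple of keys/indices per leaf, in traversal order."""
    if tree is None:
        return []  # None is a node with no leaves here
    if type(tree) is dict:
        items = list(tree.items())
    elif type(tree) in (tuple, list):
        items = list(enumerate(tree))
    else:
        return [prefix]  # a leaf: the path accumulated so far
    paths: List[Tuple[Any, ...]] = []
    for key, child in items:
        paths.extend(toy_paths(child, prefix + (key,)))
    return paths


print(toy_paths({"a": [1, 2], "b": 3}))  # [('a', 0), ('a', 1), ('b',)]
```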
- property type
The type of the current node, or None if the current node is a leaf.
- unflatten(self: optree.PyTreeSpec, leaves: Iterable) -> object
Reconstruct a pytree from the leaves.
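The flatten/unflatten round trip can be sketched with a toy model in which flattening returns the leaves plus a spec that keeps the containers and puts a hypothetical LEAF sentinel where each leaf was. Unflattening consumes the leaves in order and checks that exactly the right number were supplied. This is a stdlib illustration, not optree's implementation.

```python
from typing import Any, Iterable, List, Tuple

# Hypothetical sentinel marking a leaf position in a toy spec (not an optree object).
LEAF = object()
_MISSING = object()  # marks an exhausted iterator


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split a tree of tuples, lists, and dicts into (leaves, spec)."""
    if type(tree) is dict:
        keys = list(tree)
        children = [tree[key] for key in keys]
    elif type(tree) in (tuple, list):
        children = list(tree)
    else:
        return [tree], LEAF
    leaves: List[Any] = []
    child_specs: List[Any] = []
    for child in children:
        child_leaves, child_spec = flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    if type(tree) is dict:
        return leaves, dict(zip(keys, child_specs))
    return leaves, type(tree)(child_specs)


def unflatten(spec: Any, leaves: Iterable[Any]) -> Any:
    """Rebuild a tree from `spec`, consuming `leaves` in order."""
    it = iter(leaves)

    def build(node: Any) -> Any:
        if node is LEAF:
            try:
                return next(it)
            except StopIteration:
                raise ValueError("too few leaves for this spec") from None
        if type(node) is dict:
            return {key: build(child) for key, child in node.items()}
        return type(node)([build(child) for child in node])

    tree = build(spec)
    if next(it, _MISSING) is not _MISSING:
        raise ValueError("too many leaves for this spec")
    return tree


leaves, spec = flatten({"a": (1, 2), "b": [3]})
print(leaves)  # [1, 2, 3]
print(unflatten(spec, [x * 10 for x in leaves]))  # {'a': (10, 20), 'b': [30]}
```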
- walk(self: optree.PyTreeSpec, f_node: Callable, f_leaf: object, leaves: Iterable) -> object
Walk over the pytree structure, calling f_node(children, node_data) at nodes and f_leaf(leaf) at leaves.
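A walk is essentially a bottom-up fold over the structure. The toy below mirrors the calling convention described above (f_leaf on each leaf, f_node on the tuple of already-processed children plus some node data), but the node data used here, a (type, keys) pair, is a hypothetical choice and not necessarily what optree passes. The LEAF sentinel is likewise an assumption of this stdlib sketch.

```python
from typing import Any, Callable, Iterable

# Hypothetical sentinel marking a leaf position in a toy spec (not an optree object).
LEAF = object()


def walk(spec: Any, f_node: Callable[[tuple, Any], Any],
         f_leaf: Callable[[Any], Any], leaves: Iterable[Any]) -> Any:
    """Fold over a toy spec, feeding it `leaves` in order."""
    it = iter(leaves)

    def next_leaf() -> Any:
        try:
            return next(it)
        except StopIteration:
            raise ValueError("too few leaves for this spec") from None

    def visit(node: Any) -> Any:
        if node is LEAF:
            return f_leaf(next_leaf())
        if type(node) is dict:
            keys = tuple(node)
            children = tuple([visit(node[key]) for key in keys])
            return f_node(children, (dict, keys))  # hypothetical node data
        children = tuple([visit(child) for child in node])
        return f_node(children, (type(node), None))  # hypothetical node data

    return visit(spec)


spec = (LEAF, [LEAF, LEAF])
total = walk(spec, lambda children, data: sum(children), lambda x: x * 10, [1, 2, 3])
print(total)  # 60
nodes = walk(spec, lambda children, data: 1 + sum(children), lambda x: 1, [0, 0, 0])
print(nodes)  # 5: the tuple, the list, and three leaves
```

The second call counts every node, including leaves, which matches the num_nodes definition given earlier.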
- optree.PyTreeDef
alias of PyTreeSpec
- class optree.PyTreeKind(self: optree._C.PyTreeKind, value: int)
Bases: pybind11_object
The kind of a pytree node.
Members:
CUSTOM : A custom type.
LEAF : An opaque leaf node.
NONE : None.
TUPLE : A tuple.
LIST : A list.
DICT : A dict.
NAMEDTUPLE : A collections.namedtuple.
ORDEREDDICT : A collections.OrderedDict.
DEFAULTDICT : A collections.defaultdict.
DEQUE : A collections.deque.
STRUCTSEQUENCE : A PyStructSequence.
- CUSTOM = <PyTreeKind.CUSTOM: 0>
- DEFAULTDICT = <PyTreeKind.DEFAULTDICT: 8>
- DEQUE = <PyTreeKind.DEQUE: 9>
- DICT = <PyTreeKind.DICT: 5>
- LEAF = <PyTreeKind.LEAF: 1>
- LIST = <PyTreeKind.LIST: 4>
- NAMEDTUPLE = <PyTreeKind.NAMEDTUPLE: 6>
- NONE = <PyTreeKind.NONE: 2>
- ORDEREDDICT = <PyTreeKind.ORDEREDDICT: 7>
- STRUCTSEQUENCE = <PyTreeKind.STRUCTSEQUENCE: 10>
- TUPLE = <PyTreeKind.TUPLE: 3>
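To show how the kinds above partition ordinary Python objects, here is a hypothetical stdlib mirror of the enum values (Kind is not optree.PyTreeKind) together with a rough classifier. The classifier uses exact type matches for the built-in containers and simple attribute heuristics for namedtuple and PyStructSequence; it is an illustrative assumption, not optree's dispatch logic, and it never returns CUSTOM because that requires a registry of custom node types.

```python
import collections
import enum
import time


class Kind(enum.IntEnum):
    """Hypothetical mirror of the PyTreeKind values listed above."""
    CUSTOM = 0
    LEAF = 1
    NONE = 2
    TUPLE = 3
    LIST = 4
    DICT = 5
    NAMEDTUPLE = 6
    ORDEREDDICT = 7
    DEFAULTDICT = 8
    DEQUE = 9
    STRUCTSEQUENCE = 10


_EXACT = {
    tuple: Kind.TUPLE,
    list: Kind.LIST,
    dict: Kind.DICT,
    collections.OrderedDict: Kind.ORDEREDDICT,
    collections.defaultdict: Kind.DEFAULTDICT,
    collections.deque: Kind.DEQUE,
}


def classify(obj: object, none_is_leaf: bool = False) -> Kind:
    """Rough, hypothetical classification of an object into a Kind."""
    if obj is None:
        return Kind.LEAF if none_is_leaf else Kind.NONE
    cls = type(obj)
    if cls in _EXACT:
        return _EXACT[cls]
    if issubclass(cls, tuple) and hasattr(cls, "_fields"):
        return Kind.NAMEDTUPLE
    if issubclass(cls, tuple) and hasattr(cls, "n_sequence_fields"):
        return Kind.STRUCTSEQUENCE
    return Kind.LEAF


print(classify(collections.OrderedDict()).name)  # ORDEREDDICT
print(classify(time.localtime()).name)  # STRUCTSEQUENCE
```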
- __annotations__ = {}
- __delattr__(name, /)
Implement delattr(self, name).
- __ge__(value, /)
Return self>=value.
- __getattribute__(name, /)
Return getattr(self, name).
- __gt__(value, /)
Return self>value.
- __le__(value, /)
Return self<=value.
- __lt__(value, /)
Return self<value.
- __members__ = {'CUSTOM': <PyTreeKind.CUSTOM: 0>, 'DEFAULTDICT': <PyTreeKind.DEFAULTDICT: 8>, 'DEQUE': <PyTreeKind.DEQUE: 9>, 'DICT': <PyTreeKind.DICT: 5>, 'LEAF': <PyTreeKind.LEAF: 1>, 'LIST': <PyTreeKind.LIST: 4>, 'NAMEDTUPLE': <PyTreeKind.NAMEDTUPLE: 6>, 'NONE': <PyTreeKind.NONE: 2>, 'ORDEREDDICT': <PyTreeKind.ORDEREDDICT: 7>, 'STRUCTSEQUENCE': <PyTreeKind.STRUCTSEQUENCE: 10>, 'TUPLE': <PyTreeKind.TUPLE: 3>}
- __new__(**kwargs)
- __setattr__(name, value, /)
Implement setattr(self, name, value).
- property name
- property value
- class optree.PyTree
Bases: Generic[T]
Generic PyTree type.
>>> import torch
>>> from optree.typing import PyTree
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor, typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...], typing.List[ForwardRef('PyTree[torch.Tensor]')], typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')], typing.Deque[ForwardRef('PyTree[torch.Tensor]')], optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
Instantiation is prohibited.
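The doctest output shows that PyTree[X] expands to a recursive Union of X and containers of the same alias. The stdlib sketch below writes an analogous alias by hand for int leaves, without optree or torch, and uses it to annotate a small recursive function. IntTree and tree_sum are hypothetical names, and the CustomTreeNode member of the Union is omitted.

```python
from collections import deque
from typing import Any, Deque, Dict, List, Tuple, Union, get_args

# Hypothetical hand-written alias with the same shape as the Union shown above
# (CustomTreeNode member omitted); the string "IntTree" becomes a ForwardRef.
IntTree = Union[
    int,
    Tuple["IntTree", ...],
    List["IntTree"],
    Dict[Any, "IntTree"],
    Deque["IntTree"],
]


def tree_sum(tree: IntTree) -> int:
    """Add up every int leaf in a nested tuple/list/dict/deque structure."""
    if isinstance(tree, int):
        return tree
    if isinstance(tree, dict):
        return sum(tree_sum(value) for value in tree.values())
    return sum(tree_sum(child) for child in tree)


print(len(get_args(IntTree)))  # 5 alternatives in the Union
print(tree_sum((1, [2, 3], {"a": 4}, deque([5]))))  # 15
```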
- classmethod __class_getitem__(cls, item)
Instantiate a PyTree type with the given type.
- Return type: TypeAlias
- __annotations__ = {}
- __orig_bases__ = (typing.Generic[~T],)
- __parameters__ = (~T,)
- optree.PyTreeTypeVar(name: str, param: type) -> TypeAlias
Type variable for PyTree.
>>> import torch
>>> from optree.typing import PyTreeTypeVar
>>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
>>> TensorTree
typing.Union[torch.Tensor, typing.Tuple[ForwardRef('TensorTree'), ...], typing.List[ForwardRef('TensorTree')], typing.Dict[typing.Any, ForwardRef('TensorTree')], typing.Deque[ForwardRef('TensorTree')], optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
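Unlike PyTree[X], the doctest shows PyTreeTypeVar using the caller-supplied name inside each ForwardRef. A hypothetical stdlib helper, make_tree_alias (not part of optree), can build a Union of the same shape from a name and a leaf type; the CustomTreeNode member is again omitted. The printed representation varies across Python versions, so no expected output is given for it.

```python
from typing import Any, Deque, Dict, ForwardRef, List, Tuple, Union, get_args


def make_tree_alias(name: str, param: type) -> Any:
    """Hypothetical helper: a recursive Union whose containers refer back to `name`."""
    ref = ForwardRef(name)
    return Union[param, Tuple[ref, ...], List[ref], Dict[Any, ref], Deque[ref]]


StrTree = make_tree_alias("StrTree", str)
print(StrTree)


def count_strings(tree: StrTree) -> int:
    """Count the str leaves of a nested tuple/list/dict/deque structure."""
    if isinstance(tree, str):
        return 1  # checked first so strings are never iterated character by character
    if isinstance(tree, dict):
        return sum(count_strings(value) for value in tree.values())
    return sum(count_strings(child) for child in tree)


print(count_strings(("a", ["b", {"k": "c"}])))  # 3
```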
- class optree.CustomTreeNode(*args, **kwargs)
Bases: Protocol[T]
The abstract base class for custom pytree nodes.
- classmethod tree_unflatten(metadata, children)
Unflatten the children and auxiliary data into the custom pytree node.
- Return type:
- __abstractmethods__ = frozenset({})
- __annotations__ = {}
- __init__(*args, **kwargs)
- __non_callable_proto_members__ = {}
- __orig_bases__ = (typing_extensions.Protocol[~T],)
- __parameters__ = (~T,)
- __protocol_attrs__ = {'tree_flatten', 'tree_unflatten'}
- classmethod __subclasshook__(other)
Abstract classes can override this to customize issubclass().
This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
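Because CustomTreeNode is a Protocol whose members are tree_flatten and tree_unflatten (see __protocol_attrs__ above), a class satisfies it structurally, without inheriting from it. The sketch below defines a hypothetical local protocol, TreeNodeLike, with the same two members and a Pair class that implements them, using the tree_unflatten(metadata, children) argument order documented above. It does not import optree, and the exact return shape of tree_flatten shown here (children, metadata) is a simplifying assumption.

```python
from dataclasses import dataclass
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class TreeNodeLike(Protocol):
    """Hypothetical local protocol with the same two members as CustomTreeNode."""

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Any]: ...

    @classmethod
    def tree_unflatten(cls, metadata: Any, children: Tuple[Any, ...]) -> Any: ...


@dataclass
class Pair:
    left: Any
    right: Any
    label: str = ""

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Any]:
        # Children are traversed further; the label is kept as auxiliary metadata.
        return (self.left, self.right), self.label

    @classmethod
    def tree_unflatten(cls, metadata: Any, children: Tuple[Any, ...]) -> "Pair":
        return cls(*children, label=metadata)


p = Pair(1, [2, 3], "x")
children, metadata = p.tree_flatten()
print(Pair.tree_unflatten(metadata, children) == p)  # True
print(isinstance(p, TreeNodeLike))  # True: structural match, no inheritance
```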
- optree.is_namedtuple(obj: object) -> bool
Return whether the object is an instance of namedtuple or a subclass of namedtuple.
- optree.is_namedtuple_instance(obj: object) -> bool
Return whether the object is an instance of namedtuple.
- optree.is_namedtuple_class(cls: object) -> bool
Return whether the class is a subclass of namedtuple.
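The split between the class test, the instance test, and the combined test can be approximated with a duck-typing heuristic: a namedtuple class is a tuple subclass with a _fields tuple of strings. The looks_like_* functions below are hypothetical approximations written with the standard library only; optree's own checks may be stricter. Note that a PyStructSequence such as time.struct_time is a tuple subclass without _fields, so it is not reported as a namedtuple.

```python
import collections
import time
import typing


def looks_like_namedtuple_class(cls: object) -> bool:
    """Heuristic: a tuple subclass with a `_fields` tuple of strings."""
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and isinstance(getattr(cls, "_fields", None), tuple)
        and all(isinstance(field, str) for field in cls._fields)
    )


def looks_like_namedtuple_instance(obj: object) -> bool:
    return looks_like_namedtuple_class(type(obj))


def looks_like_namedtuple(obj: object) -> bool:
    return looks_like_namedtuple_class(obj) or looks_like_namedtuple_instance(obj)


Point = collections.namedtuple("Point", "x y")


class Named(typing.NamedTuple):
    a: int


print(looks_like_namedtuple(Point), looks_like_namedtuple(Named(1)))  # True True
print(looks_like_namedtuple(time.localtime()))  # False
```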
- optree.is_structseq(obj: object) -> bool
Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.
- optree.is_structseq_instance(obj: object) -> bool
Return whether the object is an instance of PyStructSequence.
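PyStructSequence types (for example time.struct_time) are tuple subclasses created in C that expose the integer counters n_fields, n_sequence_fields, and n_unnamed_fields. The hypothetical looks_like_* helpers below use that as a stdlib-only heuristic to distinguish structseq classes and instances from plain tuples and namedtuples; they approximate, and are not, optree's checks.

```python
import collections
import time


def looks_like_structseq_class(cls: object) -> bool:
    """Heuristic: a tuple subclass exposing integer structseq field counters."""
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and all(
            isinstance(getattr(cls, attr, None), int)
            for attr in ("n_fields", "n_sequence_fields", "n_unnamed_fields")
        )
    )


def looks_like_structseq_instance(obj: object) -> bool:
    return looks_like_structseq_class(type(obj))


def looks_like_structseq(obj: object) -> bool:
    return looks_like_structseq_class(obj) or looks_like_structseq_instance(obj)


now = time.localtime()
print(looks_like_structseq_class(time.struct_time))  # True
print(looks_like_structseq_instance(now))  # True
print(looks_like_structseq(collections.namedtuple("P", "x")))  # False
print(now.tm_year == now[0])  # True: fields are reachable by name and by index
```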