Typing Support

PyTreeSpec

Represents the structure of a pytree.

PyTreeDef

alias of PyTreeSpec

PyTreeKind(self, value)

The kind of a pytree node.

PyTree()

Generic PyTree type.

PyTreeTypeVar(name, param)

Type variable for PyTree.

CustomTreeNode(*args, **kwargs)

The abstract base class for custom pytree nodes.

is_namedtuple(obj)

Return whether the object is an instance of namedtuple or a subclass of namedtuple.

is_namedtuple_instance(obj)

Return whether the object is an instance of namedtuple.

is_namedtuple_class(cls)

Return whether the class is a subclass of namedtuple.

namedtuple_fields(obj)

Return the field names of a namedtuple.

is_structseq(obj)

Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.

is_structseq_instance(obj)

Return whether the object is an instance of PyStructSequence.

is_structseq_class(cls)

Return whether the object is a class of PyStructSequence.

structseq_fields(obj)

Return the field names of a PyStructSequence.

class optree.PyTreeSpec

Bases: pybind11_object

Represents the structure of a pytree.

__annotations__ = {}
__delattr__(name, /)

Implement delattr(self, name).

__eq__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool

Test for equality to another object.

__ge__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool

Test whether this treespec is a suffix of another treespec.

__getattribute__(name, /)

Return getattr(self, name).

__getstate__(self: optree.PyTreeSpec) -> object
__gt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool

Test whether this treespec is a strict suffix of another treespec.

__hash__(self: optree.PyTreeSpec) -> int

Return the hash of the treespec.

__init__(*args, **kwargs)
__le__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool

Test whether this treespec is a prefix of another treespec.

__len__(self: optree.PyTreeSpec) -> int

Number of leaves in the tree.

__lt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool

Test whether this treespec is a strict prefix of another treespec.

__ne__(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> bool

Test for inequality to another object.
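The prefix relation behind `<=`, `<`, `>=` and `>` can be illustrated without optree. The sketch below is a toy model, not optree's implementation. It works on plain nested tuples, lists and dicts (everything else is a leaf), visits dict children in sorted-key order, and ignores None handling, namedtuples and custom nodes. In this model, structure `a` is a prefix of `b` when every internal node of `a` matches `b` in type and arity (or keys), so each leaf of `a` covers a whole subtree of `b`. A strict prefix must also differ from `b`. The function name `is_prefix` is hypothetical here.

```python
from typing import Any


def _is_node(x: Any) -> bool:
    # Toy model: only tuple, list and dict are internal nodes; anything else is a leaf.
    return isinstance(x, (tuple, list, dict))


def _children(x: Any) -> list:
    # Dict children are visited in sorted-key order.
    return [x[k] for k in sorted(x)] if isinstance(x, dict) else list(x)


def _same_node(a: Any, b: Any) -> bool:
    # Two internal nodes match if they have the same type and the same keys or arity.
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        return sorted(a) == sorted(b)
    return len(a) == len(b)


def is_prefix(a: Any, b: Any, strict: bool = False) -> bool:
    """Return True if the structure of `a` is a prefix of the structure of `b`."""
    def walk(x: Any, y: Any) -> bool:
        if not _is_node(x):
            return True  # a leaf of `a` may stand in for any subtree of `b`
        if not _is_node(y) or not _same_node(x, y):
            return False
        return all(walk(cx, cy) for cx, cy in zip(_children(x), _children(y)))

    if not walk(a, b):
        return False
    # Equal structures are prefixes of each other; a strict prefix must not be equal.
    return not (strict and walk(b, a))
```

In terms of the operators above, `is_prefix(a, b)` plays the role of `a <= b` and `is_prefix(a, b, strict=True)` that of `a < b`. For example, `(1, 2)` is a strict prefix of `((1, 2), 3)`, but it is not a strict prefix of `(3, 4)` because the two structures are equal.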

__new__(**kwargs)
__setattr__(name, value, /)

Implement setattr(self, name, value).

__setstate__(self: optree.PyTreeSpec, state: object) -> None

Serialization support for PyTreeSpec.

accessors(self: optree.PyTreeSpec) -> list[object]

Return a list of accessors to the leaves in the treespec.

broadcast_to_common_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec) -> optree.PyTreeSpec

Broadcast to the common suffix of this treespec and the other treespec.

child(self: optree.PyTreeSpec, index: int) -> optree.PyTreeSpec

Return the treespec for the child at the given index.

children(self: optree.PyTreeSpec) -> list[optree.PyTreeSpec]

Return a list of treespecs for the children.

compose(self: optree.PyTreeSpec, inner_treespec: optree.PyTreeSpec) -> optree.PyTreeSpec

Compose two treespecs by placing the inner treespec as a subtree at each leaf of this treespec.
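To make `compose` concrete, the toy sketch below treats a structure as a nested tuple, list or dict whose non-container values are leaf placeholders. It replaces every leaf of the outer structure with a fresh copy of the inner one. This is an illustration under those assumptions, not optree code, and `compose` and `num_leaves` are hypothetical helpers. By construction, the composed structure has as many leaves as the product of the two leaf counts.

```python
import copy
from typing import Any


def _is_node(x: Any) -> bool:
    # Toy model: only tuple, list and dict are internal nodes.
    return isinstance(x, (tuple, list, dict))


def compose(outer: Any, inner: Any) -> Any:
    """Return `outer` with every leaf replaced by a fresh copy of `inner`."""
    if not _is_node(outer):
        return copy.deepcopy(inner)
    if isinstance(outer, dict):
        return {k: compose(v, inner) for k, v in outer.items()}
    return type(outer)(compose(c, inner) for c in outer)


def num_leaves(x: Any) -> int:
    """Count the leaves of a toy structure."""
    if not _is_node(x):
        return 1
    values = x.values() if isinstance(x, dict) else x
    return sum(num_leaves(c) for c in values)
```

For instance, composing a 2-leaf tuple with a 3-leaf list yields a tuple of two lists with 6 leaves in total.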

entries(self: optree.PyTreeSpec) -> list

Return a list of one-level entries to the children.

entry(self: optree.PyTreeSpec, index: int) -> object

Return the entry at the given index.

flatten_up_to(self: optree.PyTreeSpec, full_tree: object) -> list

Flatten the subtrees in full_tree up to the structure of this treespec and return a list of subtrees.
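A toy version of `flatten_up_to` shows what "up to the structure" means. It is written against plain nested tuples, lists and dicts rather than a real treespec. Traversal stops at each leaf of the prefix, and the corresponding subtree of `full_tree` is returned whole. In this sketch, mismatched node types, arities or dict keys raise `ValueError`. The exception type optree itself raises is not assumed here.

```python
from typing import Any, List


def _is_node(x: Any) -> bool:
    # Toy model: only tuple, list and dict are internal nodes.
    return isinstance(x, (tuple, list, dict))


def flatten_up_to(prefix: Any, full_tree: Any) -> List[Any]:
    """Collect the subtrees of `full_tree` that sit at the leaves of `prefix`."""
    out: List[Any] = []

    def walk(p: Any, t: Any) -> None:
        if not _is_node(p):
            out.append(t)  # the whole subtree of `full_tree` becomes one entry
            return
        if type(p) is not type(t):
            raise ValueError(f"expected {type(p).__name__}, got {type(t).__name__}")
        if isinstance(p, dict):
            if sorted(p) != sorted(t):
                raise ValueError("dict keys do not match")
            for k in sorted(p):
                walk(p[k], t[k])
        else:
            if len(p) != len(t):
                raise ValueError("arity does not match")
            for cp, ct in zip(p, t):
                walk(cp, ct)

    walk(prefix, full_tree)
    return out
```

With a two-leaf tuple as the prefix, `flatten_up_to((0, 0), ([1, 2], {'x': 3}))` returns `[[1, 2], {'x': 3}]`. The list and dict stay intact because they sit below the prefix's leaves.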

is_leaf(self: optree.PyTreeSpec, strict: bool = True) -> bool

Test whether the current node is a leaf.

is_prefix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, strict: bool = False) -> bool

Test whether this treespec is a prefix of the given treespec.

is_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, strict: bool = False) -> bool

Test whether this treespec is a suffix of the given treespec.

property kind

The kind of the current node.

property namespace

The registry namespace used to resolve the custom pytree node types.

property none_is_leaf

Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so None is stored in the treespec rather than in the list of leaves.

property num_children

Number of children in the current node. Note that a leaf is also a node but has no children.

property num_leaves

Number of leaves in the tree.

property num_nodes

Number of nodes in the tree. Note that a leaf is also a node but has no children.
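The counting rules above, including the effect of `none_is_leaf`, can be checked with a small stdlib model. The helper `count` is hypothetical and not part of optree. In this model, tuples, lists and dicts are internal nodes, dict children follow sorted keys, and None is an arity-0 internal node unless `none_is_leaf=True`.

```python
from typing import Any, List, Optional, Tuple


def _children(x: Any, none_is_leaf: bool) -> Optional[List[Any]]:
    """Return the children of an internal node, or None if `x` is a leaf."""
    if x is None:
        # By default None is an internal node with arity 0, so it holds no leaves.
        return None if none_is_leaf else []
    if isinstance(x, dict):
        return [x[k] for k in sorted(x)]
    if isinstance(x, (tuple, list)):
        return list(x)
    return None


def count(x: Any, none_is_leaf: bool = False) -> Tuple[int, int]:
    """Return (num_leaves, num_nodes); every leaf also counts as one node."""
    kids = _children(x, none_is_leaf)
    if kids is None:
        return 1, 1
    leaves, nodes = 0, 1  # this internal node
    for child in kids:
        child_leaves, child_nodes = count(child, none_is_leaf)
        leaves += child_leaves
        nodes += child_nodes
    return leaves, nodes
```

Here `count((1, None))` returns `(1, 3)` and `count((1, None), none_is_leaf=True)` returns `(2, 3)`. The node count is the same in both cases, but None moves between the structure and the leaves.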

paths(self: optree.PyTreeSpec) -> list[tuple]

Return a list of paths to the leaves of the treespec.
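The sketch below computes leaf paths for plain nested containers. It uses integer indices as the entries for tuples and lists, and sorted keys as the entries for dicts. This only approximates the idea of `paths()` for built-in containers. The function is hypothetical, and the entry types optree uses for other node kinds are not modeled.

```python
from typing import Any, List, Tuple


def paths(tree: Any, prefix: Tuple = ()) -> List[Tuple]:
    """Return one path (a tuple of entries) per leaf, in flattening order."""
    if isinstance(tree, dict):
        items = [(k, tree[k]) for k in sorted(tree)]  # dict keys are the entries
    elif isinstance(tree, (tuple, list)):
        items = list(enumerate(tree))  # integer indices are the entries
    else:
        return [prefix]  # a leaf: the path accumulated so far
    result: List[Tuple] = []
    for entry, child in items:
        result.extend(paths(child, prefix + (entry,)))
    return result
```

For `{'a': 1, 'b': (2, 3)}` this yields `[('a',), ('b', 0), ('b', 1)]`, one path per leaf in the same order the leaves are flattened.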

property type

The type of the current node, or None if the current node is a leaf.

unflatten(self: optree.PyTreeSpec, leaves: Iterable) -> object

Reconstruct a pytree from the leaves.
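A flatten/unflatten round trip is the core contract of a treespec. The stdlib sketch below imitates it with a hypothetical tuple-based structure description instead of `PyTreeSpec`. `flatten` records node types and sorted dict keys while collecting the leaves. `unflatten` then refills the leaf slots in order and rejects too few or too many leaves.

```python
from typing import Any, Iterable, Iterator, List, Tuple

# Toy structure: ("leaf",) for a leaf, or (node_type, keys, child_specs) for a node.
Spec = Tuple[Any, ...]


def flatten(tree: Any) -> Tuple[List[Any], Spec]:
    """Split `tree` into its leaves and a structure description."""
    leaves: List[Any] = []

    def build(x: Any) -> Spec:
        if isinstance(x, dict):
            keys = sorted(x)  # children are visited in sorted-key order
            return (dict, keys, [build(x[k]) for k in keys])
        if isinstance(x, (tuple, list)):
            return (type(x), None, [build(c) for c in x])
        leaves.append(x)
        return ("leaf",)

    spec = build(tree)
    return leaves, spec


def unflatten(spec: Spec, leaves: Iterable[Any]) -> Any:
    """Rebuild a tree of shape `spec`, filling its leaf slots from `leaves` in order."""
    it: Iterator[Any] = iter(leaves)
    missing = object()

    def build(s: Spec) -> Any:
        if s[0] == "leaf":
            leaf = next(it, missing)
            if leaf is missing:
                raise ValueError("too few leaves")
            return leaf
        node_type, keys, child_specs = s
        children = [build(c) for c in child_specs]
        return dict(zip(keys, children)) if node_type is dict else node_type(children)

    tree = build(spec)
    if next(it, missing) is not missing:
        raise ValueError("too many leaves")
    return tree
```

Flattening `{'b': (2, 3), 'a': 1}` gives the leaves `[1, 2, 3]`. Unflattening the same structure with `[10, 20, 30]` gives `{'a': 10, 'b': (20, 30)}`.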

walk(self: optree.PyTreeSpec, f_node: Callable, f_leaf: object, leaves: Iterable) -> object

Walk over the pytree structure, calling f_node(children, node_data) at internal nodes and f_leaf(leaf) at leaves.
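`walk` is a bottom-up fold. Each leaf is mapped by `f_leaf`, and each internal node combines the already-computed results of its children with its node data. The toy version below follows that shape on plain containers. Several choices are assumptions of this sketch rather than documented optree behavior: sequences use `None` as node data, dicts use their sorted key list, and `f_leaf=None` acts as the identity.

```python
from typing import Any, Callable, Iterable, Optional


def walk(structure: Any,
         f_node: Callable[[tuple, Any], Any],
         f_leaf: Optional[Callable[[Any], Any]],
         leaves: Iterable[Any]) -> Any:
    """Fold bottom-up over `structure`, taking leaf values from `leaves` in order."""
    it = iter(leaves)

    def visit(x: Any) -> Any:
        if isinstance(x, dict):
            keys = sorted(x)
            # In this sketch, the node data for a dict is its sorted key list.
            return f_node(tuple([visit(x[k]) for k in keys]), keys)
        if isinstance(x, (tuple, list)):
            return f_node(tuple([visit(c) for c in x]), None)
        leaf = next(it)  # the placeholder value in `structure` itself is ignored
        return leaf if f_leaf is None else f_leaf(leaf)

    return visit(structure)
```

For example, passing `lambda children, _: sum(children)` as `f_node` sums the supplied leaves. Using `lambda children, _: 1 + max(children)` with `f_leaf=lambda _: 0` computes the nesting depth.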

optree.PyTreeDef

alias of PyTreeSpec

class optree.PyTreeKind(self: optree._C.PyTreeKind, value: int)

Bases: pybind11_object

The kind of a pytree node.

Members:

CUSTOM : A custom type.

LEAF : An opaque leaf node.

NONE : None.

TUPLE : A tuple.

LIST : A list.

DICT : A dict.

NAMEDTUPLE : A collections.namedtuple.

ORDEREDDICT : A collections.OrderedDict.

DEFAULTDICT : A collections.defaultdict.

DEQUE : A collections.deque.

STRUCTSEQUENCE : A PyStructSequence.

CUSTOM = <PyTreeKind.CUSTOM: 0>
DEFAULTDICT = <PyTreeKind.DEFAULTDICT: 8>
DEQUE = <PyTreeKind.DEQUE: 9>
DICT = <PyTreeKind.DICT: 5>
LEAF = <PyTreeKind.LEAF: 1>
LIST = <PyTreeKind.LIST: 4>
NAMEDTUPLE = <PyTreeKind.NAMEDTUPLE: 6>
NONE = <PyTreeKind.NONE: 2>
ORDEREDDICT = <PyTreeKind.ORDEREDDICT: 7>
STRUCTSEQUENCE = <PyTreeKind.STRUCTSEQUENCE: 10>
TUPLE = <PyTreeKind.TUPLE: 3>
__annotations__ = {}
__delattr__(name, /)

Implement delattr(self, name).

__eq__(self: object, other: object) -> bool
__ge__(value, /)

Return self>=value.

__getattribute__(name, /)

Return getattr(self, name).

__getstate__(self: object) -> int
__gt__(value, /)

Return self>value.

__hash__(self: object) -> int
__index__(self: optree._C.PyTreeKind) -> int
__init__(self: optree._C.PyTreeKind, value: int) -> None
__int__(self: optree._C.PyTreeKind) -> int
__le__(value, /)

Return self<=value.

__lt__(value, /)

Return self<value.

__members__ = {'CUSTOM': <PyTreeKind.CUSTOM: 0>, 'DEFAULTDICT': <PyTreeKind.DEFAULTDICT: 8>, 'DEQUE': <PyTreeKind.DEQUE: 9>, 'DICT': <PyTreeKind.DICT: 5>, 'LEAF': <PyTreeKind.LEAF: 1>, 'LIST': <PyTreeKind.LIST: 4>, 'NAMEDTUPLE': <PyTreeKind.NAMEDTUPLE: 6>, 'NONE': <PyTreeKind.NONE: 2>, 'ORDEREDDICT': <PyTreeKind.ORDEREDDICT: 7>, 'STRUCTSEQUENCE': <PyTreeKind.STRUCTSEQUENCE: 10>, 'TUPLE': <PyTreeKind.TUPLE: 3>}
__ne__(self: object, other: object) -> bool
__new__(**kwargs)
__setattr__(name, value, /)

Implement setattr(self, name, value).

__setstate__(self: optree._C.PyTreeKind, state: int) -> None
property name
property value
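For built-in containers, the mapping from an object to a `PyTreeKind` member can be approximated as below. The `kind_name` helper is hypothetical. It matches exact types, so in this sketch a `dict` subclass falls through to `LEAF`. Namedtuples and struct sequences are detected via class attributes, and the registry that produces `CUSTOM` is not modeled. Treat it as a reading aid for the member list above, not as a description of optree's lookup order.

```python
import collections
from typing import Any


def kind_name(obj: Any, none_is_leaf: bool = False) -> str:
    """Return the PyTreeKind member name that a built-in object would plausibly map to."""
    if obj is None:
        return "LEAF" if none_is_leaf else "NONE"
    exact = {
        tuple: "TUPLE",
        list: "LIST",
        dict: "DICT",
        collections.OrderedDict: "ORDEREDDICT",
        collections.defaultdict: "DEFAULTDICT",
        collections.deque: "DEQUE",
    }
    cls = type(obj)
    if cls in exact:  # exact type match only, so subclasses are not caught here
        return exact[cls]
    if isinstance(obj, tuple) and hasattr(cls, "_fields"):
        return "NAMEDTUPLE"
    if isinstance(obj, tuple) and hasattr(cls, "n_sequence_fields"):
        return "STRUCTSEQUENCE"
    return "LEAF"  # registered custom node types (CUSTOM) are not modeled here
```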
class optree.PyTree

Bases: Generic[T]

Generic PyTree type.

>>> import torch
>>> from optree.typing import PyTree
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree  
typing.Union[torch.Tensor,
             typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             typing.List[ForwardRef('PyTree[torch.Tensor]')],
             typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             typing.Deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]


classmethod __class_getitem__(cls, item)

Instantiate a PyTree type with the given type.

Return type:

TypeAlias

static __new__(cls)

Prohibit instantiation.

Return type:

NoReturn

classmethod __init_subclass__(*args, **kwargs)

Prohibit subclassing.

Return type:

NoReturn

__copy__()

Immutable copy.

Return type:

PyTree

__deepcopy__(memo)

Immutable copy.

Return type:

PyTree

__annotations__ = {}
__orig_bases__ = (typing.Generic[~T],)
__parameters__ = (~T,)
optree.PyTreeTypeVar(name: str, param: type) -> TypeAlias

Type variable for PyTree.

>>> import torch
>>> from optree.typing import PyTreeTypeVar
>>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
>>> TensorTree  
typing.Union[torch.Tensor,
             typing.Tuple[ForwardRef('TensorTree'), ...],
             typing.List[ForwardRef('TensorTree')],
             typing.Dict[typing.Any, ForwardRef('TensorTree')],
             typing.Deque[ForwardRef('TensorTree')],
             optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
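The expansion above is an ordinary recursive type alias. Below is a stdlib-only analogue with `float` leaves that omits the `CustomTreeNode` arm. `FloatTree` and `tree_sum` are hypothetical names. Type checkers that support recursive aliases resolve the string forward references; at runtime the annotation is not enforced.

```python
from collections import deque
from typing import Any, Deque, Dict, List, Tuple, Union

# Hand-written analogue of the expanded PyTreeTypeVar alias, with float leaves.
FloatTree = Union[
    float,
    Tuple["FloatTree", ...],
    List["FloatTree"],
    Dict[Any, "FloatTree"],
    Deque["FloatTree"],
]


def tree_sum(tree: FloatTree) -> float:
    """Sum all leaves of a FloatTree."""
    if isinstance(tree, dict):
        return sum(tree_sum(v) for v in tree.values())
    if isinstance(tree, (tuple, list, deque)):
        return sum(tree_sum(c) for c in tree)
    return float(tree)
```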
class optree.CustomTreeNode(*args, **kwargs)

Bases: Protocol[T]

The abstract base class for custom pytree nodes.

tree_flatten()

Flatten the custom pytree node into its children and auxiliary data (metadata). The returned tuple may carry an optional third element with the path entries for the children.

Return type:

tuple[Iterable[T], Optional[_MetaData]] | tuple[Iterable[T], Optional[_MetaData], Optional[Iterable[Any]]], where _MetaData is a TypeVar bound to Hashable

classmethod tree_unflatten(metadata, children)

Unflatten the children and auxiliary data into the custom pytree node.

Return type:

CustomTreeNode[T]

__abstractmethods__ = frozenset({})
__annotations__ = {}
__init__(*args, **kwargs)
__non_callable_proto_members__ = {}
__orig_bases__ = (typing_extensions.Protocol[~T],)
__parameters__ = (~T,)
__protocol_attrs__ = {'tree_flatten', 'tree_unflatten'}
classmethod __subclasshook__(other)

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
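A class satisfies the protocol by providing `tree_flatten` and the classmethod `tree_unflatten(metadata, children)`. The hypothetical `Labeled` node below keeps two children and uses a hashable label as its metadata. It also uses the optional three-element return form to name its children. Registering such a class with optree is a separate step and is not shown here.

```python
from typing import Any, Iterable, Tuple


class Labeled:
    """A hypothetical node with two children (`x`, `y`) and a hashable label as metadata."""

    def __init__(self, x: Any, y: Any, label: str) -> None:
        self.x, self.y, self.label = x, y, label

    def tree_flatten(self) -> Tuple[Tuple[Any, Any], str, Tuple[str, str]]:
        # (children, metadata, entries): the third element names each child.
        return (self.x, self.y), self.label, ("x", "y")

    @classmethod
    def tree_unflatten(cls, metadata: str, children: Iterable[Any]) -> "Labeled":
        # Rebuild the node from its metadata and children, in that argument order.
        x, y = children
        return cls(x, y, metadata)
```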

optree.is_namedtuple(obj: object) -> bool

Return whether the object is an instance of namedtuple or a subclass of namedtuple.

optree.is_namedtuple_instance(obj: object) -> bool

Return whether the object is an instance of namedtuple.

optree.is_namedtuple_class(cls: object) -> bool

Return whether the class is a subclass of namedtuple.

optree.namedtuple_fields(obj: object) -> tuple

Return the field names of a namedtuple.
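The namedtuple predicates can be approximated in pure Python as below. This is a heuristic sketch, not optree's implementation. It treats a class as a namedtuple class if it subclasses `tuple`, has a `_fields` tuple of strings, and provides `_make`. As a result, it also accepts `typing.NamedTuple` classes.

```python
from collections import namedtuple
from typing import Any, Tuple


def is_namedtuple_class(cls: Any) -> bool:
    # Heuristic: a tuple subclass with a `_fields` tuple of strings and a `_make` method.
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and isinstance(getattr(cls, "_fields", None), tuple)
        and all(isinstance(f, str) for f in cls._fields)
        and callable(getattr(cls, "_make", None))
    )


def is_namedtuple_instance(obj: Any) -> bool:
    return is_namedtuple_class(type(obj))


def is_namedtuple(obj: Any) -> bool:
    # Accept either a namedtuple class or an instance of one.
    return is_namedtuple_class(obj) or is_namedtuple_instance(obj)


def namedtuple_fields(obj: Any) -> Tuple[str, ...]:
    cls = obj if isinstance(obj, type) else type(obj)
    if not is_namedtuple_class(cls):
        raise TypeError(f"expected a namedtuple, got {obj!r}")
    return cls._fields
```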

optree.is_structseq(obj: object) -> bool

Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.

optree.is_structseq_instance(obj: object) -> bool

Return whether the object is an instance of PyStructSequence.

optree.is_structseq_class(cls: object) -> bool

Return whether the object is a class of PyStructSequence.

optree.structseq_fields(obj: object) -> tuple

Return the field names of a PyStructSequence.
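Struct sequences, such as `time.struct_time` or `os.stat_result`, are C-level tuple subclasses. They expose `n_sequence_fields`, `n_fields` and `n_unnamed_fields` as class attributes. The heuristic below relies on those attributes. It is an approximation for illustration, not optree's check, and it does not cover `structseq_fields`.

```python
from typing import Any


def is_structseq_class(cls: Any) -> bool:
    # Heuristic: a tuple subclass exposing the struct-sequence field counters as ints.
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and all(
            isinstance(getattr(cls, name, None), int)
            for name in ("n_sequence_fields", "n_fields", "n_unnamed_fields")
        )
    )


def is_structseq_instance(obj: Any) -> bool:
    return is_structseq_class(type(obj))


def is_structseq(obj: Any) -> bool:
    # Accept either a struct sequence class or an instance of one.
    return is_structseq_class(obj) or is_structseq_instance(obj)
```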