Typing Support

PyTreeSpec

Represents the structure of a pytree.

PyTreeDef

alias of PyTreeSpec

PyTreeKind(*values)

The kind of a pytree node.

PyTree()

Generic PyTree type.

PyTreeTypeVar(name, param)

Type variable for PyTree.

CustomTreeNode(*args, **kwargs)

The abstract base class for custom pytree nodes.

is_namedtuple(obj, /)

Return whether the object is an instance of namedtuple or a subclass of namedtuple.

is_namedtuple_instance(obj, /)

Return whether the object is an instance of namedtuple.

is_namedtuple_class(cls, /)

Return whether the class is a subclass of namedtuple.

namedtuple_fields(obj, /)

Return the field names of a namedtuple.

is_structseq(obj, /)

Return whether the object is a PyStructSequence instance or a PyStructSequence class.

is_structseq_instance(obj, /)

Return whether the object is an instance of PyStructSequence.

is_structseq_class(cls, /)

Return whether the class is a PyStructSequence class.

structseq_fields(obj, /)

Return the field names of a PyStructSequence.

class optree.PyTreeSpec

Bases: pybind11_object

Represents the structure of a pytree.

__eq__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> bool

Test for equality to another object.

__ge__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> bool

Test whether this treespec is a suffix of the other treespec.

__gt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> bool

Test whether this treespec is a strict suffix of the other treespec.

__hash__(self: optree.PyTreeSpec, /) -> int

Return the hash of the treespec.

__init__(*args, **kwargs)

__le__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> bool

Test whether this treespec is a prefix of the other treespec.

__len__(self: optree.PyTreeSpec, /) -> int

Number of leaves in the tree.

__lt__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> bool

Test whether this treespec is a strict prefix of the other treespec.

__ne__(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> bool

Test for inequality to another object.

accessors(self: optree.PyTreeSpec, /) -> list[object]

Return a list of accessors to the leaves in the treespec.

broadcast_to_common_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /) -> optree.PyTreeSpec

Broadcast to the common suffix of this treespec and the other treespec.

child(self: optree.PyTreeSpec, index: SupportsInt | SupportsIndex, /) -> optree.PyTreeSpec

Return the treespec for the child at the given index.

children(self: optree.PyTreeSpec, /) -> list[optree.PyTreeSpec]

Return a list of treespecs for the children.

compose(self: optree.PyTreeSpec, inner: optree.PyTreeSpec, /) -> optree.PyTreeSpec

Compose two treespecs. Constructs the inner treespec as a subtree at each leaf node.

entries(self: optree.PyTreeSpec, /) -> list

Return a list of one-level entries to the children.

entry(self: optree.PyTreeSpec, index: SupportsInt | SupportsIndex, /) -> object

Return the entry at the given index.

flatten_up_to(self: optree.PyTreeSpec, tree: object, /) -> list

Flatten the subtrees in tree up to the structure of this treespec and return a list of subtrees.

is_leaf(self: optree.PyTreeSpec, /, *, strict: bool = True) -> bool

Test whether the treespec represents a leaf.

is_one_level(self: optree.PyTreeSpec, /) -> bool

Test whether the treespec represents a one-level tree.

is_prefix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /, *, strict: bool = False) -> bool

Test whether this treespec is a prefix of the given treespec.

is_suffix(self: optree.PyTreeSpec, other: optree.PyTreeSpec, /, *, strict: bool = False) -> bool

Test whether this treespec is a suffix of the given treespec.

property kind

The kind of the root node.

property namespace

The registry namespace used to resolve the custom pytree node types.

property none_is_leaf

Whether to treat None as a leaf. If false, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list.

property num_children

Number of children of the root node. Note that a leaf is also a node but has no children.

property num_leaves

Number of leaves in the tree.

property num_nodes

Number of nodes in the tree. Note that a leaf is also a node but has no children.

one_level(self: optree.PyTreeSpec, /) -> optree.PyTreeSpec | None

Return the one-level structure of the root node. Return None if the root node represents a leaf.

paths(self: optree.PyTreeSpec, /) -> list[tuple]

Return a list of paths to the leaves of the treespec.

transform(self: optree.PyTreeSpec, /, f_node: collections.abc.Callable | None = None, f_leaf: collections.abc.Callable | None = None) -> optree.PyTreeSpec

Transform the pytree structure by applying f_node(nodespec) at nodes and f_leaf(leafspec) at leaves.

traverse(self: optree.PyTreeSpec, leaves: collections.abc.Iterable, /, f_node: collections.abc.Callable | None = None, f_leaf: collections.abc.Callable | None = None) -> object

Walk over the pytree structure, calling f_leaf(leaf) at leaves, and f_node(node) at reconstructed non-leaf nodes.

property type

The type of the root node. Return None if the root node is a leaf.

unflatten(self: optree.PyTreeSpec, leaves: collections.abc.Iterable, /) -> object

Reconstruct a pytree from the leaves.

walk(self: optree.PyTreeSpec, leaves: collections.abc.Iterable, /, f_node: collections.abc.Callable | None = None, f_leaf: collections.abc.Callable | None = None) -> object

Walk over the pytree structure, calling f_leaf(leaf) at leaves, and f_node(node_type, node_data, children) at non-leaf nodes.

optree.PyTreeDef

alias of PyTreeSpec

class optree.PyTreeKind(*values)

Bases: IntEnum

The kind of a pytree node.

__format__(format_spec, /)

Convert to a string according to format_spec.

__new__(value)

CUSTOM = 0
LEAF = 1
NONE = 2
TUPLE = 3
LIST = 4
DICT = 5
NAMEDTUPLE = 6
ORDEREDDICT = 7
DEFAULTDICT = 8
DEQUE = 9
STRUCTSEQUENCE = 10
NUM_KINDS = 11

class optree.PyTree

Bases: Generic[T]

Generic PyTree type.

>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]

Prohibit instantiation.

classmethod __class_getitem__(cls, item)

Instantiate a PyTree type with the given type.

Return type:

TypeAliasType

static __new__(cls, /)

Prohibit instantiation.

Return type:

Never

classmethod __init_subclass__(*args, **kwargs)

Prohibit subclassing.

Return type:

Never

__getitem__(key, /)

Emulate collection-like behavior.

Return type:

PyTree[TypeVar(T)]

__getattr__(name, /)

Emulate dataclass-like behavior.

Return type:

PyTree[TypeVar(T)]

__contains__(key, /)

Emulate collection-like behavior.

Return type:

bool

__len__()

Emulate collection-like behavior.

Return type:

int

__iter__()

Emulate collection-like behavior.

Return type:

Iterator[PyTree[TypeVar(T)] | Any]

index(key, /)

Emulate sequence-like behavior.

Return type:

int

count(key, /)

Emulate sequence-like behavior.

Return type:

int

get(key, default=None, /)

Emulate mapping-like behavior.

Return type:

PyTree[TypeVar(T)] | TypeVar(S) | None

keys()

Emulate mapping-like behavior.

Return type:

KeysView[Any]

values()

Emulate mapping-like behavior.

Return type:

ValuesView[PyTree[TypeVar(T)]]

items()

Emulate mapping-like behavior.

Return type:

ItemsView[Any, PyTree[TypeVar(T)]]

optree.PyTreeTypeVar(name: str, param: type | TypeAliasType) -> TypeAliasType

Type variable for PyTree.

>>> import torch
>>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('TensorTree'), ...],
             list[ForwardRef('TensorTree')],
             dict[typing.Any, ForwardRef('TensorTree')],
             collections.deque[ForwardRef('TensorTree')],
             optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]

class optree.CustomTreeNode(*args, **kwargs)

Bases: Protocol[T]

The abstract base class for custom pytree nodes.

__tree_flatten__()

Flatten the custom pytree node into children and metadata.

Return type:

tuple[Iterable[TypeVar(T)], Hashable | None] | tuple[Iterable[TypeVar(T)], Hashable | None, Iterable[Any] | None]

classmethod __tree_unflatten__(metadata, children, /)

Unflatten the children and metadata into the custom pytree node.

Return type:

Self

__init__(*args, **kwargs)

optree.is_namedtuple(obj, /)

Return whether the object is an instance of namedtuple or a subclass of namedtuple.

Return type:

bool

optree.is_namedtuple_instance(obj, /)

Return whether the object is an instance of namedtuple.

Return type:

bool

optree.is_namedtuple_class(cls, /)

Return whether the class is a subclass of namedtuple.

Return type:

bool

optree.namedtuple_fields(obj, /)

Return the field names of a namedtuple.

Return type:

tuple[str, ...]

optree.is_structseq(obj, /)

Return whether the object is a PyStructSequence instance or a PyStructSequence class.

Return type:

bool

optree.is_structseq_instance(obj, /)

Return whether the object is an instance of PyStructSequence.

Return type:

bool

optree.is_structseq_class(cls, /)

Return whether the class is a PyStructSequence class.

Return type:

bool

optree.structseq_fields(obj, /)

Return the field names of a PyStructSequence.

Return type:

tuple[str, ...]