Source code for optree.accessors

# Copyright 2022-2026 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Access support for pytrees."""

from __future__ import annotations

import dataclasses
import sys
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload
from typing_extensions import Self  # Python 3.11+

import optree._C as _C
from optree._C import PyTreeKind


if TYPE_CHECKING:
    import builtins

    from optree.typing import NamedTuple, StructSequence


__all__ = [
    'PyTreeEntry',
    'GetItemEntry',
    'GetAttrEntry',
    'FlattenedEntry',
    'AutoEntry',
    'SequenceEntry',
    'MappingEntry',
    'NamedTupleEntry',
    'StructSequenceEntry',
    'DataclassEntry',
    'PyTreeAccessor',
]


SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {}  # Python 3.10+


@dataclasses.dataclass(init=True, repr=False, eq=False, frozen=True, **SLOTS)
class PyTreeEntry:
    """Base class for path entries."""

    entry: Any
    type: builtins.type
    kind: PyTreeKind

[docs] def __post_init__(self, /) -> None: """Post-initialize the path entry.""" if self.kind == PyTreeKind.LEAF: raise ValueError('Cannot create a leaf path entry.') if self.kind == PyTreeKind.NONE: raise ValueError('Cannot create a path entry for None.')
[docs] def __call__(self, obj: Any, /) -> Any: """Get the child object.""" try: return obj[self.entry] # should be overridden except TypeError as ex: raise TypeError( f'{self.__class__!r} cannot access through {obj!r} via entry {self.entry!r}', ) from ex
[docs] def __add__(self, other: object, /) -> PyTreeAccessor: """Join the path entry with another path entry or accessor.""" if isinstance(other, PyTreeEntry): return PyTreeAccessor((self, other)) if isinstance(other, PyTreeAccessor): return PyTreeAccessor((self, *other)) return NotImplemented
[docs] def __eq__(self, other: object, /) -> bool: """Check if the path entries are equal.""" return isinstance(other, PyTreeEntry) and ( ( self.entry, self.type, self.kind, self.__class__.__call__.__code__.co_code, self.__class__.codify.__code__.co_code, ) == ( other.entry, other.type, other.kind, other.__class__.__call__.__code__.co_code, other.__class__.codify.__code__.co_code, ) )
[docs] def __hash__(self, /) -> int: """Get the hash of the path entry.""" return hash( ( self.entry, self.type, self.kind, self.__class__.__call__.__code__.co_code, self.__class__.codify.__code__.co_code, ), )
[docs] def __repr__(self, /) -> str: """Get the representation of the path entry.""" return f'{self.__class__.__name__}(entry={self.entry!r}, type={self.type!r})'
[docs] def codify(self, /, node: str = '') -> str: """Generate code for accessing the path entry.""" return f'{node}[<flat index {self.entry!r}>]' # should be overridden
del SLOTS _T = TypeVar('_T') _T_co = TypeVar('_T_co', covariant=True) _KT_co = TypeVar('_KT_co', covariant=True) _VT_co = TypeVar('_VT_co', covariant=True) class AutoEntry(PyTreeEntry): """A generic path entry class that determines the entry type on creation automatically.""" __slots__: ClassVar[tuple[()]] = ()
[docs] def __new__( # type: ignore[misc] cls, /, entry: Any, type: builtins.type, # pylint: disable=redefined-builtin kind: PyTreeKind, ) -> PyTreeEntry: """Create a new path entry.""" # pylint: disable-next=import-outside-toplevel from optree.typing import is_namedtuple_class, is_structseq_class if cls is not AutoEntry: # Use the subclass type if the type is explicitly specified return super().__new__(cls) if kind != PyTreeKind.CUSTOM: raise ValueError(f'Cannot create an automatic path entry for PyTreeKind {kind!r}.') # Dispatch the path entry type based on the node type path_entry_type: builtins.type[PyTreeEntry] if is_structseq_class(type): path_entry_type = StructSequenceEntry elif is_namedtuple_class(type): path_entry_type = NamedTupleEntry elif dataclasses.is_dataclass(type): path_entry_type = DataclassEntry elif issubclass(type, Mapping): path_entry_type = MappingEntry elif issubclass(type, Sequence): path_entry_type = SequenceEntry else: path_entry_type = FlattenedEntry if not issubclass(path_entry_type, AutoEntry): # The __init__() method will not be called if the returned instance is not a subtype of # AutoEntry. We should return an initialized instance. Return a fully-initialized # instance of the dispatched type. return path_entry_type(entry, type, kind) # The __init__() method will be called if the returned instance is a subtype of AutoEntry. # We should return an uninitialized instance. The __init__() method will initialize it. # But we will never reach here because the dispatched type is never a subtype of AutoEntry. raise NotImplementedError('Unreachable code.')
class GetItemEntry(PyTreeEntry): """A generic path entry class for nodes that access their children by :meth:`__getitem__`.""" __slots__: ClassVar[tuple[()]] = ()
[docs] def __call__(self, obj: Any, /) -> Any: """Get the child object.""" return obj[self.entry]
[docs] def codify(self, /, node: str = '') -> str: """Generate code for accessing the path entry.""" return f'{node}[{self.entry!r}]'
class GetAttrEntry(PyTreeEntry): """A generic path entry class for nodes that access their children by :meth:`__getattr__`.""" __slots__: ClassVar[tuple[()]] = () entry: str @property def name(self, /) -> str: """Get the attribute name.""" return self.entry
[docs] def __call__(self, obj: Any, /) -> Any: """Get the child object.""" for attr in self.name.split('.'): obj = getattr(obj, attr) return obj
[docs] def codify(self, /, node: str = '') -> str: """Generate code for accessing the path entry.""" return f'{node}.{self.name}'
class FlattenedEntry(PyTreeEntry): # pylint: disable=too-few-public-methods """A fallback path entry class for flattened objects.""" __slots__: ClassVar[tuple[()]] = () class SequenceEntry(GetItemEntry, Generic[_T_co]): """A path entry class for sequences.""" __slots__: ClassVar[tuple[()]] = () entry: int type: builtins.type[Sequence[_T_co]] @property def index(self, /) -> int: """Get the index.""" return self.entry
[docs] def __call__(self, obj: Sequence[_T_co], /) -> _T_co: """Get the child object.""" return obj[self.index]
[docs] def __repr__(self, /) -> str: """Get the representation of the path entry.""" return f'{self.__class__.__name__}(index={self.index!r}, type={self.type!r})'
class MappingEntry(GetItemEntry, Generic[_KT_co, _VT_co]): """A path entry class for mappings.""" __slots__: ClassVar[tuple[()]] = () entry: _KT_co type: builtins.type[Mapping[_KT_co, _VT_co]] @property def key(self, /) -> _KT_co: """Get the key.""" return self.entry
[docs] def __call__(self, obj: Mapping[_KT_co, _VT_co], /) -> _VT_co: """Get the child object.""" return obj[self.key]
[docs] def __repr__(self, /) -> str: """Get the representation of the path entry.""" return f'{self.__class__.__name__}(key={self.key!r}, type={self.type!r})'
class NamedTupleEntry(SequenceEntry[_T]): """A path entry class for namedtuple objects.""" __slots__: ClassVar[tuple[()]] = () entry: int type: builtins.type[NamedTuple[_T]] # type: ignore[type-arg] kind: Literal[PyTreeKind.NAMEDTUPLE] @property def fields(self, /) -> tuple[str, ...]: """Get the field names.""" from optree.typing import namedtuple_fields # pylint: disable=import-outside-toplevel return namedtuple_fields(self.type) @property def field(self, /) -> str: """Get the field name.""" return self.fields[self.entry]
[docs] def __repr__(self, /) -> str: """Get the representation of the path entry.""" return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'
[docs] def codify(self, /, node: str = '') -> str: """Generate code for accessing the path entry.""" return f'{node}.{self.field}'
class StructSequenceEntry(SequenceEntry[_T]): """A path entry class for PyStructSequence objects.""" __slots__: ClassVar[tuple[()]] = () entry: int type: builtins.type[StructSequence[_T]] kind: Literal[PyTreeKind.STRUCTSEQUENCE] @property def fields(self, /) -> tuple[str, ...]: """Get the field names.""" from optree.typing import structseq_fields # pylint: disable=import-outside-toplevel return structseq_fields(self.type) @property def field(self, /) -> str: """Get the field name.""" return self.fields[self.entry]
[docs] def __repr__(self, /) -> str: """Get the representation of the path entry.""" return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'
[docs] def codify(self, /, node: str = '') -> str: """Generate code for accessing the path entry.""" return f'{node}.{self.field}'
class DataclassEntry(GetAttrEntry): """A path entry class for dataclasses.""" __slots__: ClassVar[tuple[()]] = () entry: str | int # type: ignore[assignment] @property def fields(self, /) -> tuple[str, ...]: """Get all field names.""" return tuple(f.name for f in dataclasses.fields(self.type)) @property def init_fields(self, /) -> tuple[str, ...]: """Get the init field names.""" return tuple(f.name for f in dataclasses.fields(self.type) if f.init) @property def field(self, /) -> str: """Get the field name.""" if isinstance(self.entry, int): return self.init_fields[self.entry] return self.entry @property def name(self, /) -> str: """Get the attribute name.""" return self.field
[docs] def __repr__(self, /) -> str: """Get the representation of the path entry.""" return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'
class PyTreeAccessor(tuple[PyTreeEntry, ...]): """A path class for PyTrees.""" __slots__: ClassVar[tuple[()]] = () @property def path(self, /) -> tuple[Any, ...]: """Get the path of the accessor.""" return tuple(e.entry for e in self)
[docs] def __new__(cls, /, path: Iterable[PyTreeEntry] = ()) -> Self: """Create a new accessor instance.""" if not isinstance(path, (list, tuple)): path = tuple(path) if not all(isinstance(p, PyTreeEntry) for p in path): raise TypeError(f'Expected a path of PyTreeEntry, got {path!r}.') return super().__new__(cls, path)
[docs] def __call__(self, obj: Any, /) -> Any: """Get the child object.""" for entry in self: obj = entry(obj) return obj
@overload # type: ignore[override] def __getitem__(self, index: int, /) -> PyTreeEntry: ... @overload def __getitem__(self, index: slice, /) -> Self: ...
[docs] def __getitem__(self, index: int | slice, /) -> PyTreeEntry | Self: """Get the child path entry or an accessor for a subpath.""" if isinstance(index, slice): return self.__class__(super().__getitem__(index)) return super().__getitem__(index)
[docs] def __add__(self, other: object, /) -> Self: """Join the accessor with another path entry or accessor.""" if isinstance(other, PyTreeEntry): return self.__class__((*self, other)) if isinstance(other, PyTreeAccessor): return self.__class__((*self, *other)) return NotImplemented
[docs] def __mul__(self, value: int, /) -> Self: # type: ignore[override] """Repeat the accessor.""" return self.__class__(super().__mul__(value))
[docs] def __rmul__(self, value: int, /) -> Self: # type: ignore[override] """Repeat the accessor.""" return self.__class__(super().__rmul__(value))
[docs] def __eq__(self, other: object, /) -> bool: """Check if the accessors are equal.""" return isinstance(other, PyTreeAccessor) and super().__eq__(other)
[docs] def __hash__(self, /) -> int: """Get the hash of the accessor.""" return super().__hash__()
[docs] def __repr__(self, /) -> str: """Get the representation of the accessor.""" return f'{self.__class__.__name__}({self.codify()}, {super().__repr__()})'
[docs] def codify(self, /, root: str = '*') -> str: """Generate code for accessing the path.""" string = root for entry in self: string = entry.codify(string) return string
# These classes are used internally in the C++ side for accessor APIs _name, _cls = '', object for _name in __all__: _cls = globals()[_name] if not isinstance(_cls, type): # pragma: no cover raise TypeError(f'Expected a class, got {_cls!r}.') _cls.__module__ = 'optree' setattr(_C, _name, _cls) del _name, _cls