Integration with dataclasses
PyTree integration with dataclasses.
This module implements PyTree integration with dataclasses by redefining the field(),
dataclass(), and make_dataclass() functions. The register_node() function allows
registering existing dataclasses.dataclass()-decorated classes as pytree nodes. Other APIs
are re-exported from the original dataclasses module.
The PyTree integration allows dataclasses to be flattened and unflattened recursively. The fields
are stored in a special attribute named __optree_dataclass_fields__ in the dataclass.
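The recursive flatten/unflatten idea can be made concrete with a minimal sketch that uses only the standard dataclasses module. It is not optree's implementation, and it makes several assumptions. The helpers is_child, flatten_sketch and unflatten_sketch are hypothetical. A field counts as a child when metadata.get('pytree_node', True) is true. Init metadata fields travel in the structure rather than as leaves. Non-init fields such as norm are rebuilt by __post_init__() on reconstruction.

```python
import dataclasses
import math
from typing import Any


def is_child(f: dataclasses.Field) -> bool:
    # Assumed rule: a field is a child node unless its metadata says otherwise.
    return f.metadata.get('pytree_node', True)


def flatten_sketch(obj: Any) -> tuple[list[Any], Any]:
    """Hypothetical recursive flatten returning (leaves, structure)."""
    if not dataclasses.is_dataclass(obj) or isinstance(obj, type):
        return [obj], None  # anything that is not a dataclass instance is a leaf
    leaves: list[Any] = []
    children: dict[str, tuple[int, Any]] = {}
    metadata: dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if is_child(f):
            sub_leaves, sub_spec = flatten_sketch(value)  # recurse into nested dataclasses
            leaves.extend(sub_leaves)
            children[f.name] = (len(sub_leaves), sub_spec)
        elif f.init:
            metadata[f.name] = value  # metadata is kept in the structure, not as a leaf
    return leaves, (type(obj), children, metadata)


def unflatten_sketch(spec: Any, leaves: list[Any]) -> Any:
    """Hypothetical inverse of flatten_sketch."""
    if spec is None:
        (leaf,) = leaves
        return leaf
    cls, children, metadata = spec
    kwargs = dict(metadata)
    start = 0
    for name, (count, sub_spec) in children.items():
        kwargs[name] = unflatten_sketch(sub_spec, leaves[start:start + count])
        start += count
    return cls(**kwargs)  # non-init fields are recomputed by __post_init__


@dataclasses.dataclass
class Point:
    x: float
    y: float
    z: float = 0.0
    norm: float = dataclasses.field(init=False, metadata={'pytree_node': False})

    def __post_init__(self) -> None:
        self.norm = math.hypot(self.x, self.y, self.z)


@dataclasses.dataclass
class Segment:
    start: Point
    end: Point
    label: str = dataclasses.field(default='', metadata={'pytree_node': False})


seg = Segment(Point(2.0, 6.0, 3.0), Point(0.0, 0.0), label='ab')
leaves, spec = flatten_sketch(seg)
print(leaves)  # [2.0, 6.0, 3.0, 0.0, 0.0, 0.0]
assert unflatten_sketch(spec, leaves) == seg
```

Storing the leaf count next to each child's sub-structure lets the unflatten step slice the flat leaf list without re-walking the original object.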
>>> import math
>>> import optree
>>> @optree.dataclasses.dataclass(namespace='my_module')
... class Point:
...     x: float
...     y: float
...     z: float = 0.0
...     norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
...     def __post_init__(self) -> None:
...         self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point) # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec
(
[
PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
],
[2.0, 6.0, 3.0],
PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True
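The doctest shows that a Point is a single leaf unless the matching namespace is passed. One way to picture this is a registry keyed by (namespace, class). The sketch below is hypothetical. REGISTRY, register_sketch and leaves_sketch are illustrative names, not optree API. For brevity, every init field is treated as a child. Falling back to global ('') registrations is also an assumption of the sketch.

```python
import dataclasses
from typing import Any, Callable

# Hypothetical registry keyed by (namespace, class); '' plays the role of the global namespace.
REGISTRY: dict[tuple[str, type], Callable[[Any], list[Any]]] = {}


def register_sketch(cls: type, *, namespace: str) -> type:
    """Register a dataclass so that its init fields become children in `namespace`."""
    names = [f.name for f in dataclasses.fields(cls) if f.init]
    REGISTRY[(namespace, cls)] = lambda obj: [getattr(obj, n) for n in names]
    return cls


def leaves_sketch(tree: Any, namespace: str = '') -> list[Any]:
    """Collect leaves, consulting the given namespace first and then global entries."""
    flatten = REGISTRY.get((namespace, type(tree))) or REGISTRY.get(('', type(tree)))
    if flatten is None:
        return [tree]  # not registered for this namespace: the whole object is a leaf
    leaves: list[Any] = []
    for child in flatten(tree):
        leaves.extend(leaves_sketch(child, namespace))
    return leaves


@dataclasses.dataclass
class Point:
    x: float
    y: float


register_sketch(Point, namespace='my_module')
p = Point(2.0, 6.0)
print(leaves_sketch(p))                         # [Point(x=2.0, y=6.0)]
print(leaves_sketch(p, namespace='my_module'))  # [2.0, 6.0]
```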
Function | Summary
field() | Field factory for dataclass().
dataclass() | Dataclass decorator with PyTree integration.
make_dataclass() | Make a new dynamically created dataclass with PyTree integration.
- optree.dataclasses.field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, doc=None, pytree_node=None)
Field factory for dataclass().
This factory function is used to define the fields in a dataclass. It is similar to the field factory dataclasses.field(), but with an additional pytree_node parameter. If pytree_node is True (the default behavior), the field will be considered a child node in the PyTree structure, which can be recursively flattened and unflattened. Otherwise, the field will be treated as PyTree metadata.
Setting pytree_node in the field factory is equivalent to setting the key 'pytree_node' in metadata in the original field factory. The pytree_node value can be accessed using field.metadata['pytree_node']. If pytree_node is None, the value metadata.get('pytree_node', True) will be used.
Note
If a field is considered a child node, it must be included in the argument list of the __init__() method, i.e., pass init=True in the field factory.
- Parameters:
pytree_node (bool or None, optional) – Whether the field is a PyTree node.
**kwargs (optional) – Optional keyword arguments passed to dataclasses.field().
- Returns:
The field defined using the provided arguments with field.metadata['pytree_node'] set.
- Return type:
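The metadata equivalence and the init=True requirement described for field() can be sketched as a thin wrapper around dataclasses.field(). This is a hypothetical illustration, not optree's source. The name field_sketch is invented. Raising TypeError inside the factory, rather than when the class is decorated, is a choice made only for this sketch.

```python
from __future__ import annotations

import dataclasses
from typing import Any


def field_sketch(*, pytree_node: bool | None = None,
                 metadata: dict[str, Any] | None = None, **kwargs: Any) -> Any:
    """Hypothetical field factory that stores `pytree_node` in the field metadata."""
    metadata = dict(metadata or {})  # copy so the caller's dict is not mutated
    if pytree_node is None:
        pytree_node = metadata.get('pytree_node', True)  # documented fallback
    metadata['pytree_node'] = pytree_node
    if pytree_node and not kwargs.get('init', True):
        # A child must be an __init__() argument so it can be passed back on unflatten.
        raise TypeError('a PyTree child field requires init=True')
    return dataclasses.field(metadata=metadata, **kwargs)


@dataclasses.dataclass
class Config:
    weight: float = field_sketch(default=1.0)  # child node (default behavior)
    name: str = field_sketch(default='w', pytree_node=False)  # metadata
    tag: str = field_sketch(default='', metadata={'pytree_node': False})  # same, via metadata


flags = {f.name: f.metadata['pytree_node'] for f in dataclasses.fields(Config)}
print(flags)  # {'weight': True, 'name': False, 'tag': False}
```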
- optree.dataclasses.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, namespace)
Dataclass decorator with PyTree integration.
- Parameters:
cls (type or None, optional) – The class to decorate. If None, return a decorator.
namespace (str) – The registry namespace used for the PyTree registration.
**kwargs (optional) – Optional keyword arguments passed to dataclasses.dataclass().
- Returns:
The decorated class with PyTree integration or decorator function.
- Return type:
type or callable
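The documented calling convention follows the usual pattern for a decorator that works both with and without arguments. In that pattern, cls is positional-only and may be None, namespace is keyword-only and required, and other options are forwarded to dataclasses.dataclass(). The sketch below is hypothetical. dataclass_sketch and REGISTERED are invented names, and recording the namespace in a dict only stands in for the actual PyTree registration.

```python
from __future__ import annotations

import dataclasses
from typing import Any

REGISTERED: dict[type, str] = {}  # hypothetical stand-in for the PyTree registry


def dataclass_sketch(cls: type | None = None, /, *, namespace: str, **kwargs: Any) -> Any:
    """Hypothetical decorator usable as @dataclass_sketch(namespace=...) or called on a class."""
    def decorator(c: type) -> type:
        c = dataclasses.dataclass(c, **kwargs)  # e.g. frozen=True is forwarded unchanged
        REGISTERED[c] = namespace  # record where the class would be registered
        return c

    if cls is None:
        return decorator  # cls omitted: return the decorator itself
    return decorator(cls)  # cls given: decorate it immediately


@dataclass_sketch(namespace='my_module', frozen=True)
class Vec:
    x: float
    y: float


class Pair:
    a: int
    b: int


Pair = dataclass_sketch(Pair, namespace='other')
```

Because namespace has no default, omitting it raises TypeError at call time.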
- optree.dataclasses.make_dataclass(cls_name, fields, *, bases=(), ns=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, module=None, decorator=<function dataclass>, namespace)
Make a new dynamically created dataclass with PyTree integration.
The dataclass name will be cls_name. fields is an iterable of either (name), (name, type), or (name, type, Field) objects. If type is omitted, use the string typing.Any. Field objects are created by the equivalent of calling field(name, type [, Field-info]).
The namespace parameter is the PyTree registration namespace, which should be a string. The namespace parameter of the original dataclasses.make_dataclass() function is renamed to ns to avoid conflicts.
The remaining parameters are passed to dataclasses.make_dataclass(). See dataclasses.make_dataclass() for more information.
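The ns/namespace renaming and the three accepted field forms can be illustrated with a hypothetical wrapper around dataclasses.make_dataclass(). make_dataclass_sketch and NAMESPACES are invented names. Only the standard-library call is real, and the dict entry merely stands in for the PyTree registration.

```python
import dataclasses
from typing import Any, Optional

NAMESPACES: dict = {}  # hypothetical stand-in for the PyTree registry


def make_dataclass_sketch(cls_name: str, fields: Any, *, ns: Optional[dict] = None,
                          namespace: str, **kwargs: Any) -> type:
    """Hypothetical wrapper: `ns` is forwarded as the stdlib `namespace` argument."""
    cls = dataclasses.make_dataclass(cls_name, fields, namespace=ns, **kwargs)
    NAMESPACES[cls] = namespace  # the keyword `namespace` is reserved for the PyTree side
    return cls


Point3 = make_dataclass_sketch(
    'Point3',
    [
        'x',                                           # (name): type falls back to typing.Any
        ('y', float),                                  # (name, type)
        ('z', float, dataclasses.field(default=0.0)),  # (name, type, Field)
    ],
    ns={'l1_norm': lambda self: abs(self.x) + abs(self.y) + abs(self.z)},
    namespace='my_module',
)
p = Point3(1.0, -2.0)
print(p, p.l1_norm())  # Point3(x=1.0, y=-2.0, z=0.0) 3.0
```

Entries in ns end up in the class body, so the lambda becomes an ordinary method.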
- Parameters:
cls_name (str) – The name of the dataclass.
fields (Iterable[str | tuple[str, Any] | tuple[str, Any, Any]]) – An iterable of either (name), (name, type), or (name, type, Field) objects.
namespace (str) – The registry namespace used for the PyTree registration.
ns (dict or None, optional) – The namespace used in dynamic type creation. See dataclasses.make_dataclass() and the builtin type() function for more information.
**kwargs (optional) – Optional keyword arguments passed to dataclasses.make_dataclass().
- Returns:
The dynamically created dataclass with PyTree integration.
- Return type: