Integration with dataclasses
PyTree integration with dataclasses.
This module implements PyTree integration with dataclasses by redefining the field(), dataclass(), and make_dataclass() functions. Other APIs are re-exported from the original dataclasses module.
The PyTree integration allows dataclasses to be flattened and unflattened recursively. The fields are stored in a special attribute named __optree_dataclass_fields__ in the dataclass.
>>> import math
>>> import optree
>>> @optree.dataclasses.dataclass(namespace='my_module')
... class Point:
...     x: float
...     y: float
...     z: float = 0.0
...     norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
...     def __post_init__(self) -> None:
...         self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point) # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec
(
    [
        PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
        PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
        PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
    ],
    [2.0, 6.0, 3.0],
    PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True
| field() | Field factory for dataclass(). |
| dataclass() | Dataclass decorator with PyTree integration. |
| make_dataclass() | Make a new dynamically created dataclass with PyTree integration. |
- optree.dataclasses.field(*, default=<dataclasses._MISSING_TYPE object>, default_factory=<dataclasses._MISSING_TYPE object>, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=<dataclasses._MISSING_TYPE object>, doc=None, pytree_node=None)
Field factory for dataclass().

This factory function is used to define the fields in a dataclass. It is similar to the field factory dataclasses.field(), but with an additional pytree_node parameter. If pytree_node is True (the default behavior), the field will be considered a child node in the PyTree structure, which can be recursively flattened and unflattened. Otherwise, the field will be considered PyTree metadata.

Setting pytree_node in the field factory is equivalent to setting a key 'pytree_node' in metadata in the original field factory. The pytree_node value can be accessed using field.metadata['pytree_node']. If pytree_node is None, the value metadata.get('pytree_node', True) will be used.

Note
If a field is considered a child node, it must be included in the argument list of the __init__() method, i.e., it must be declared with init=True in the field factory.

- Parameters:
pytree_node (bool or None, optional) – Whether the field is a PyTree node.
**kwargs (optional) – Optional keyword arguments passed to dataclasses.field().
- Returns:
The field defined using the provided arguments with field.metadata['pytree_node'] set.
- Return type:
dataclasses.Field
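
The fallback rule described above can be illustrated with the standard library alone. The following is a minimal sketch, not optree's implementation: the helper name is_pytree_node and the class Example are hypothetical, and the sketch only reads the 'pytree_node' metadata key the way the documentation describes.

```python
import dataclasses


@dataclasses.dataclass
class Example:  # hypothetical class used only for illustration
    a: int  # no metadata: falls back to the default of True
    b: int = dataclasses.field(default=0, metadata={'pytree_node': False})


def is_pytree_node(field: dataclasses.Field) -> bool:
    """Hypothetical helper: apply the documented metadata.get('pytree_node', True) rule."""
    return field.metadata.get('pytree_node', True)


# Map each field name to whether it would count as a child node.
flags = {f.name: is_pytree_node(f) for f in dataclasses.fields(Example)}
```

Here the field a would be treated as a child node and b as metadata, because only b carries an explicit 'pytree_node': False entry.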
- optree.dataclasses.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, namespace)
Dataclass decorator with PyTree integration.
- Parameters:
cls (type or None, optional) – The class to decorate. If None, return a decorator.
namespace (str) – The registry namespace used for the PyTree registration.
**kwargs (optional) – Optional keyword arguments passed to dataclasses.dataclass().
- Returns:
The decorated class with PyTree integration or decorator function.
- Return type:
type or callable
- optree.dataclasses.make_dataclass(cls_name, fields, *, bases=(), ns=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, module=None, decorator=<function dataclass>, namespace)
Make a new dynamically created dataclass with PyTree integration.
The dataclass name will be cls_name. fields is an iterable of either (name), (name, type), or (name, type, Field) objects. If type is omitted, use the string 'typing.Any'. Field objects are created by the equivalent of calling field(name, type [, Field-info]).

The namespace parameter is the PyTree registration namespace, which should be a string. The namespace parameter of the original dataclasses.make_dataclass() function is renamed to ns to avoid conflicts.

The remaining parameters are passed to dataclasses.make_dataclass(). See dataclasses.make_dataclass() for more information.

- Parameters:
cls_name (str) – The name of the dataclass.
fields (Iterable[str | tuple[str, Any] | tuple[str, Any, Any]]) – An iterable of either (name), (name, type), or (name, type, Field) objects.
namespace (str) – The registry namespace used for the PyTree registration.
ns (dict or None, optional) – The namespace used in dynamic type creation. See dataclasses.make_dataclass() and the builtin type() function for more information.
**kwargs (optional) – Optional keyword arguments passed to dataclasses.make_dataclass().
- Returns:
The dynamically created dataclass with PyTree integration.
- Return type:
type