Integration with dataclasses

PyTree integration with dataclasses.

This module implements PyTree integration with dataclasses by redefining the field(), dataclass(), and make_dataclass() functions. Other APIs are re-exported from the original dataclasses module.

The PyTree integration allows dataclasses to be flattened and unflattened recursively. The fields are stored in a special attribute named __optree_dataclass_fields__ in the dataclass.

>>> import math
>>> import optree
...
>>> @optree.dataclasses.dataclass(namespace='my_module')
... class Point:
...     x: float
...     y: float
...     z: float = 0.0
...     norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
...     def __post_init__(self) -> None:
...         self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point)  # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec
(
    [
        PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
        PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
        PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
    ],
    [2.0, 6.0, 3.0],
    PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True

field(*[, default, default_factory, init, ...])

Field factory for dataclass().

dataclass([cls, init, repr, eq, order, ...])

Dataclass decorator with PyTree integration.

make_dataclass(cls_name, fields, *[, bases, ...])

Make a new dynamically created dataclass with PyTree integration.

optree.dataclasses.field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, doc=None, pytree_node=None)

Field factory for dataclass().

This factory function defines the fields of a dataclass. It behaves like dataclasses.field(), with one additional parameter, pytree_node. If pytree_node is True (the effective default), the field is treated as a child node in the PyTree structure and is flattened and unflattened recursively. Otherwise, the field is treated as PyTree metadata.

Setting pytree_node in the field factory is equivalent to setting a key 'pytree_node' in metadata in the original field factory. The pytree_node value can be accessed using field.metadata['pytree_node']. If pytree_node is None, the value metadata.get('pytree_node', True) will be used.

Note

If a field is considered a child node, it must be included in the argument list of the __init__() method, i.e., it must be declared with init=True in the field factory.

Parameters:
  • pytree_node (bool or None, optional) – Whether the field is a PyTree node.

  • **kwargs (optional) – Optional keyword arguments passed to dataclasses.field().

Returns:

The field defined using the provided arguments with field.metadata['pytree_node'] set.

Return type:

dataclasses.Field

optree.dataclasses.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, namespace)

Dataclass decorator with PyTree integration.

Parameters:
  • cls (type or None, optional) – The class to decorate. If None, return a decorator.

  • namespace (str) – The registry namespace used for the PyTree registration.

  • **kwargs (optional) – Optional keyword arguments passed to dataclasses.dataclass().

Returns:

The decorated class with PyTree integration or decorator function.

Return type:

type or callable

optree.dataclasses.make_dataclass(cls_name, fields, *, bases=(), ns=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, module=None, decorator=dataclass, namespace)

Make a new dynamically created dataclass with PyTree integration.

The dataclass name will be cls_name. fields is an iterable of either (name), (name, type), or (name, type, Field) objects. If type is omitted, the string 'typing.Any' is used. Field objects are created by the equivalent of calling field(name, type [, Field-info]).

The namespace parameter is the PyTree registration namespace, which should be a string. The parameter named namespace in the original dataclasses.make_dataclass() function is renamed to ns here to avoid the name conflict.

The remaining parameters are passed to dataclasses.make_dataclass(). See dataclasses.make_dataclass() for more information.

Parameters:
  • cls_name (str) – The name of the dataclass.

  • fields (Iterable[str | tuple[str, Any] | tuple[str, Any, Any]]) – An iterable of either (name), (name, type), or (name, type, Field) objects.

  • namespace (str) – The registry namespace used for the PyTree registration.

  • ns (dict or None, optional) – The namespace used in dynamic type creation. See dataclasses.make_dataclass() and the builtin type() function for more information.

  • **kwargs (optional) – Optional keyword arguments passed to dataclasses.make_dataclass().

Returns:

The dynamically created dataclass with PyTree integration.

Return type:

type