PyTree Node Registration

register_pytree_node(cls, /, flatten_func, ...)

Extend the set of types that are considered internal nodes in pytrees.

register_pytree_node_class([cls, ...])

Extend the set of types that are considered internal nodes in pytrees.

unregister_pytree_node(cls, /, *, namespace)

Remove a type from the pytree node registry.

optree.register_pytree_node(cls, /, flatten_func, unflatten_func, *, path_entry_type=AutoEntry, namespace)

Extend the set of types that are considered internal nodes in pytrees.

See also register_pytree_node_class() and unregister_pytree_node().

The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces also allow the same class to be registered with different behaviors for different use cases.

Warning

For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.

Parameters:
  • cls (type) – A Python type to treat as an internal pytree node.

  • flatten_func (callable) – A function used during flattening. It takes an instance of cls and returns either a triple or a pair: (1) an iterable of the children to be flattened recursively, (2) some hashable metadata to be stored in the treespec and passed to unflatten_func, and (3) (optional) an iterable of the tree path entries for the corresponding children. If the entries are omitted or None, range(len(children)) is used.

  • unflatten_func (callable) – A function taking two arguments: the metadata that was returned by flatten_func and stored in the treespec, and the unflattened children. The function should return an instance of cls.

  • path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)

  • namespace (str) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.

Return type:

type[Collection[TypeVar(T)]]

Returns:

The same type as the input cls.

Raises:
  • TypeError – If the input type is not a class.

  • TypeError – If the path entry class is not a subclass of PyTreeEntry.

  • TypeError – If the namespace is not a string.

  • ValueError – If the namespace is an empty string.

  • ValueError – If the type is already registered in the registry.

Added in version 0.12.0: The path_entry_type argument, which specifies the path entry type used in PyTreeSpec.accessors() and tree_flatten_with_accessor(). If not provided, AutoEntry is used.

Examples

>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='set',
... )
<class 'set'>
>>> # Register a custom type into a namespace with accessor support
>>> import types
>>> # This can be whatever your container type is.
>>> class MyContainer(types.SimpleNamespace):
...     pass
>>> # (Optional) Define a custom path entry type for accessor support.
>>> # Here we showcase how to define one. In practice, you can use the built-in ``GetAttrEntry``.
>>> class MyContainerEntry(PyTreeEntry):
...     def __call__(self, obj):
...         return getattr(obj, self.entry)
...     def codify(self, node=''):
...         return f'{node}.{self.entry}'
>>> register_pytree_node(
...     MyContainer,
...     flatten_func=lambda ct: (
...         list(vars(ct).values()),
...         list(vars(ct).keys()),
...         list(vars(ct).keys()),
...     ),
...     unflatten_func=lambda keys, values: MyContainer(**dict(zip(keys, values))),
...     path_entry_type=MyContainerEntry,
...     namespace='mycontainer',
... )
<class '...MyContainer'>
>>> tree = {'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree)  # `MyContainer`s are leaf nodes
([MyContainer(lr=0.01, momentum=0.9), 1000], PyTreeSpec({'config': *, 'steps': *}))
>>> # Flatten with the namespace
>>> leaves, treespec = tree_flatten(tree, namespace='mycontainer')
>>> leaves, treespec
([0.01, 0.9, 1000], PyTreeSpec({'config': CustomTreeNode(MyContainer[['lr', 'momentum']], [*, *]), 'steps': *}, namespace='mycontainer'))
>>> # Custom ``entries`` are defined as attribute names
>>> tree_paths(tree, namespace='mycontainer')
[('config', 'lr'), ('config', 'momentum'), ('steps',)]
>>> # Custom path entry type defines the pytree access behavior
>>> accessors = tree_accessors(tree, namespace='mycontainer')
>>> accessors[0].codify()
"*['config'].lr"
>>> accessors[0](tree)
0.01
>>> # Unflatten back to a copy of the original object
>>> tree_unflatten(treespec, leaves)
{'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}

optree.register_pytree_node_class(cls=None, /, *, path_entry_type=None, namespace=None)

Extend the set of types that are considered internal nodes in pytrees.

See also register_pytree_node() and unregister_pytree_node().

The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces also allow the same class to be registered with different behaviors for different use cases.

Warning

For safety reasons, a namespace must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type.

Parameters:
  • cls (type, optional) – A Python type to treat as an internal pytree node.

  • path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: the class variable TREE_PATH_ENTRY_TYPE if cls defines it, otherwise AutoEntry)

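  • namespace (str, optional) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type. When the function is used as a decorator, the namespace may instead be passed as the only positional argument, e.g. @register_pytree_node_class('mylist').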

Return type:

CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType], where CustomTreeNodeType is TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])

Returns:

The same type as the input cls if cls is given. Otherwise, a decorator function that registers the class as a pytree node.

Raises:
  • TypeError – If the path entry class is not a subclass of PyTreeEntry.

  • TypeError – If the namespace is not a string.

  • TypeError – If the class does not define the required method pairs.

  • ValueError – If the namespace is an empty string.

  • ValueError – If the type is already registered in the registry.

Added in version 0.12.0: The TREE_PATH_ENTRY_TYPE class variable, which specifies the path entry type used in PyTreeSpec.accessors() and tree_flatten_with_accessor(). If not provided, AutoEntry is used.

Added in version 0.18.0: Previously, this function looked for methods named tree_flatten and tree_unflatten for the given class. Since version 0.18.0, it prefers methods named __tree_flatten__ and __tree_unflatten__ instead. The old method names are still supported for backward compatibility, but it is recommended to use the new method names. The method resolution follows this priority:

  1. If both __tree_flatten__ and __tree_unflatten__ are defined, use them directly.

  2. If both tree_flatten and tree_unflatten are defined, wrap them as dunder methods.

  3. If neither complete pair is available, raise a TypeError suggesting the new method names.
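The resolution order can be modelled in plain Python. The helper resolve_tree_methods and the example classes are hypothetical: the helper only mirrors the three documented rules and is not optree's implementation. Unlike optree, it returns the legacy methods as-is rather than wrapping them.

```python
def resolve_tree_methods(cls):
    """Hypothetical helper mirroring the documented priority (not optree's code)."""
    # Rule 1: the new dunder pair wins when both methods are defined.
    if hasattr(cls, '__tree_flatten__') and hasattr(cls, '__tree_unflatten__'):
        return cls.__tree_flatten__, cls.__tree_unflatten__
    # Rule 2: fall back to the complete legacy pair
    # (optree wraps these as dunder methods; this sketch just returns them).
    if hasattr(cls, 'tree_flatten') and hasattr(cls, 'tree_unflatten'):
        return cls.tree_flatten, cls.tree_unflatten
    # Rule 3: no complete pair, so point the user at the new names.
    raise TypeError(
        f'{cls.__name__} must define __tree_flatten__() and __tree_unflatten__().'
    )


class NewStyle:
    def __tree_flatten__(self):
        return (), None

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        return cls()


class LegacyStyle:
    def tree_flatten(self):
        return (), None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls()


class BothStyles(NewStyle, LegacyStyle):
    """Defines both pairs; rule 1 takes precedence."""


class Incomplete:
    def __tree_flatten__(self):  # no __tree_unflatten__ and no legacy pair
        return (), None


# Name of the flatten method each class resolves to.
chosen = {
    c.__name__: resolve_tree_methods(c)[0].__name__
    for c in (NewStyle, LegacyStyle, BothStyles)
}

try:
    resolve_tree_methods(Incomplete)
except TypeError as exc:
    incomplete_error = exc
```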

This function is a thin wrapper around register_pytree_node() that provides a class-oriented interface:

from collections import UserList

from optree import GetAttrEntry, SequenceEntry, register_pytree_node_class

@register_pytree_node_class(namespace='foo')
class Special:
    # Children are accessed by attribute name, e.g. obj.x
    TREE_PATH_ENTRY_TYPE = GetAttrEntry

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __tree_flatten__(self):
        # (children, metadata, entries)
        return ((self.x, self.y), None, ('x', 'y'))

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        return cls(*children)

@register_pytree_node_class('mylist')  # namespace passed positionally
class MyList(UserList):
    # Children are accessed by index, e.g. obj[0]
    TREE_PATH_ENTRY_TYPE = SequenceEntry

    def __tree_flatten__(self):
        # Entries are None, so range(len(children)) is used
        return self.data, None, None

    @classmethod
    def __tree_unflatten__(cls, metadata, children):
        # UserList takes a single iterable, so do not unpack the children
        return cls(children)

# Legacy style (still supported but not recommended)
@register_pytree_node_class(namespace='legacy')
class LegacyStyleMyList(UserList):
    def tree_flatten(self):
        # Implementation automatically wrapped as __tree_flatten__
        return self.data, None, None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # Implementation automatically wrapped as __tree_unflatten__
        return cls(children)

optree.unregister_pytree_node(cls, /, *, namespace)

Remove a type from the pytree node registry.

See also register_pytree_node() and register_pytree_node_class().

This function is the inverse of register_pytree_node().

Parameters:
  • cls (type) – A Python type to remove from the pytree node registry.

  • namespace (str) – The namespace of the pytree node registry to remove the type from.

Return type:

PyTreeNodeRegistryEntry

Returns:

The removed registry entry.

Raises:
  • TypeError – If the input type is not a class.

  • TypeError – If the namespace is not a string.

  • ValueError – If the namespace is an empty string.

  • ValueError – If the type is a built-in type that cannot be unregistered.

  • ValueError – If the type is not found in the registry.

Examples

>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='temp',
... )
<class 'set'>
>>> # Unregister the Python type
>>> unregister_pytree_node(set, namespace='temp')