PyTree Node Registry

register_pytree_node(cls, flatten_func, ...)

Extend the set of types that are considered internal nodes in pytrees.

register_pytree_node_class([cls, ...])

Extend the set of types that are considered internal nodes in pytrees.

unregister_pytree_node(cls, *, namespace)

Remove a type from the pytree node registry.

optree.register_pytree_node(cls, flatten_func, unflatten_func, *, path_entry_type=<class 'optree.accessor.AutoEntry'>, namespace)

Extend the set of types that are considered internal nodes in pytrees.

See also register_pytree_node_class() and unregister_pytree_node().

The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces also allow the same class to be registered with different behaviors in different namespaces for different use cases.

Warning

For safety reasons, a namespace must be specified when registering a custom type. The namespace isolates the flattening and unflattening behavior of a pytree node type, which prevents accidental collisions between different libraries that may register the same type.

Parameters:
  • cls (type) – A Python type to treat as an internal pytree node.

  • flatten_func (callable) – A function to be used during flattening. It takes an instance of cls and returns a triple, or optionally a pair, consisting of (1) an iterable of the children to be flattened recursively, (2) some hashable auxiliary data to be stored in the treespec and passed to unflatten_func, and (3) (optional) an iterable of the tree path entries for the corresponding children. If the entries are not provided or are None, range(len(children)) is used.

  • unflatten_func (callable) – A function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treespec, and the unflattened children. The function should return an instance of cls.

  • path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)

  • namespace (str) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
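The flatten_func contract above allows either a pair or a triple. The following sketch shows how such a result could be normalized, including the range(len(children)) default for path entries. normalize_flatten_output is a hypothetical helper, not part of optree, and the length check is an assumption added for this sketch.

```python
from typing import Any, List, Tuple


def normalize_flatten_output(output: tuple) -> Tuple[List[Any], Any, List[Any]]:
    """Split a flatten_func result into (children, aux_data, entries).

    Hypothetical helper: accepts a pair (children, aux_data) or a triple
    (children, aux_data, entries); missing or None entries default to
    range(len(children)), as described for flatten_func.
    """
    if len(output) == 2:
        children, aux_data = output
        entries = None
    elif len(output) == 3:
        children, aux_data, entries = output
    else:
        raise ValueError(f'expected a pair or a triple, got {len(output)} items')

    children = list(children)  # materialize the iterable of children
    if entries is None:
        entries = list(range(len(children)))  # default path entries: 0, 1, 2, ...
    else:
        entries = list(entries)
        # Sanity check added for this sketch; not taken from the documentation.
        if len(entries) != len(children):
            raise ValueError('entries and children must have the same length')
    return children, aux_data, entries


# A pair: entries default to the child indices.
print(normalize_flatten_output(([3, 1], None)))  # ([3, 1], None, [0, 1])
# A triple with explicit entries, like attribute names for an object with x and y.
print(normalize_flatten_output(((1, 2), None, ('x', 'y'))))  # ([1, 2], None, ['x', 'y'])
```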

Return type:

Type[CustomTreeNode[TypeVar(T)]]

Returns:

The same type as the input cls.

Raises:
  • TypeError – If the input type is not a class.

  • TypeError – If the path entry class is not a subclass of PyTreeEntry.

  • TypeError – If the namespace is not a string.

  • ValueError – If the namespace is an empty string.

  • ValueError – If the type is already registered in the registry.
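The namespace isolation and the checks listed under Raises can be illustrated with a toy registry keyed by (namespace, type). This is a minimal sketch under that assumption, not optree's implementation. toy_register, RegistryEntry and REGISTRY are hypothetical names, and the order in which the checks run is this sketch's own choice.

```python
from typing import Any, Callable, Dict, NamedTuple, Tuple


class RegistryEntry(NamedTuple):
    # Hypothetical record of one registration (not optree's internal type).
    flatten_func: Callable[[Any], tuple]
    unflatten_func: Callable[[Any, list], Any]


# Toy registry keyed by (namespace, type): the same type can be registered
# once per namespace without the registrations colliding.
REGISTRY: Dict[Tuple[str, type], RegistryEntry] = {}


def toy_register(cls, flatten_func, unflatten_func, *, namespace):
    """Register cls under namespace, mimicking the checks listed under Raises."""
    if not isinstance(cls, type):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')
    key = (namespace, cls)
    if key in REGISTRY:
        raise ValueError(f'{cls!r} is already registered in namespace {namespace!r}.')
    REGISTRY[key] = RegistryEntry(flatten_func, unflatten_func)
    return cls  # return the input type, as the Returns section describes


# The same type with two different behaviors in two namespaces.
toy_register(set, lambda s: (sorted(s), None), lambda _, c: set(c), namespace='ascending')
toy_register(set, lambda s: (sorted(s, reverse=True), None), lambda _, c: set(c), namespace='descending')
print(REGISTRY[('ascending', set)].flatten_func({3, 1, 2}))   # ([1, 2, 3], None)
print(REGISTRY[('descending', set)].flatten_func({3, 1, 2}))  # ([3, 2, 1], None)
```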

Examples

>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='set',
... )
<class 'set'>
>>> # Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
...     torch.Tensor,
...     flatten_func=lambda tensor: (
...         (tensor.cpu().detach().numpy(),),
...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
...     ),
...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
...     namespace='torch2numpy',
... )
<class 'torch.Tensor'>
>>> 
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>>> # Flatten with the namespace
>>> tree_flatten(tree, namespace='torch2numpy')
(
    [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
    PyTreeSpec(
        {
            'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
            'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
        },
        namespace='torch2numpy'
    )
)
>>> # Register the same type with a different namespace for different behaviors
>>> def tensor2flatparam(tensor):
...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
...
>>> def flatparam2tensor(metadata, children):
...     return children[0].reshape(metadata)
...
>>> register_pytree_node(
...     torch.Tensor,
...     flatten_func=tensor2flatparam,
...     unflatten_func=flatparam2tensor,
...     namespace='tensor2flatparam',
... )
<class 'torch.Tensor'>
>>> # Flatten with the new namespace
>>> tree_flatten(tree, namespace='tensor2flatparam')
(
    [
        Parameter containing: tensor([0., 0.], requires_grad=True),
        Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
    ],
    PyTreeSpec(
        {
            'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
            'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
        },
        namespace='tensor2flatparam'
    )
)
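The examples above show only the result of flattening. The sketch below is a toy, pure-Python recursive flattener that illustrates how flatten_func and unflatten_func take part in a full flatten/unflatten round trip. It reuses the set functions from the first example. It is not optree's implementation: toy_flatten, toy_unflatten, the REGISTRY layout, the nested-tuple spec format, and the dict/list/tuple handling are all assumptions made for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical toy registry: (namespace, type) -> (flatten_func, unflatten_func).
REGISTRY: Dict[Tuple[str, type], Tuple[Callable, Callable]] = {
    ('set', set): (
        lambda s: (sorted(s), None, None),  # same functions as the `set` example
        lambda _, children: set(children),
    ),
}

# A spec node is (kind, aux, child_specs); kind is 'leaf', 'dict', 'list', 'tuple' or 'custom'.
Spec = Tuple[str, Any, list]


def toy_flatten(tree: Any, namespace: str = '') -> Tuple[List[Any], Spec]:
    """Flatten tree into (leaves, spec), consulting REGISTRY only for `namespace`."""
    leaves: List[Any] = []

    def visit(node: Any) -> Spec:
        if type(node) is dict:
            keys = sorted(node)  # visit dict children in sorted-key order
            return ('dict', keys, [visit(node[k]) for k in keys])
        if type(node) in (list, tuple):
            return (type(node).__name__, None, [visit(child) for child in node])
        entry = REGISTRY.get((namespace, type(node)))
        if entry is not None:
            flatten_func, unflatten_func = entry
            children, aux_data = flatten_func(node)[:2]  # drop the optional path entries
            return ('custom', (unflatten_func, aux_data), [visit(c) for c in children])
        leaves.append(node)  # unregistered types are leaves
        return ('leaf', None, [])

    return leaves, visit(tree)


def toy_unflatten(spec: Spec, leaves: List[Any]) -> Any:
    """Rebuild a tree from spec, consuming leaves left to right."""
    remaining = iter(leaves)

    def build(node_spec: Spec) -> Any:
        kind, aux, child_specs = node_spec
        if kind == 'leaf':
            return next(remaining)
        children = [build(c) for c in child_specs]
        if kind == 'dict':
            return dict(zip(aux, children))
        if kind == 'list':
            return children
        if kind == 'tuple':
            return tuple(children)
        unflatten_func, aux_data = aux
        return unflatten_func(aux_data, children)  # hand aux data and children back

    return build(spec)


tree = {'b': {3, 1, 2}, 'a': [4, (5, 6)]}
print(toy_flatten(tree)[0])  # [4, 5, 6, {1, 2, 3}]: without the namespace, the set is a leaf
leaves, spec = toy_flatten(tree, namespace='set')
print(leaves)  # [4, 5, 6, 1, 2, 3]
print(toy_unflatten(spec, [x * 10 for x in leaves]) == {'a': [40, (50, 60)], 'b': {10, 20, 30}})  # True
```

As in the torch.Tensor examples, the same tree flattens differently depending on which namespace is consulted.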
optree.register_pytree_node_class(cls=None, *, path_entry_type=None, namespace=None)

Extend the set of types that are considered internal nodes in pytrees.

See also register_pytree_node() and unregister_pytree_node().

The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces also allow the same class to be registered with different behaviors in different namespaces for different use cases.

Warning

For safety reasons, a namespace must be specified when registering a custom type. The namespace isolates the flattening and unflattening behavior of a pytree node type, which prevents accidental collisions between different libraries that may register the same type.

Parameters:
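  • cls (type, optional) – A Python type to treat as an internal pytree node. If omitted, a decorator is returned. As in the MyList example below, a string passed in this position is used as the namespace.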

  • path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)

  • namespace (str, optional) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.

Return type:

Union[Type[CustomTreeNode[TypeVar(T)]], Callable[[Type[CustomTreeNode[TypeVar(T)]]], Type[CustomTreeNode[TypeVar(T)]]]]

Returns:

The same type as the input cls if the argument is present. Otherwise, a decorator function that registers the class as a pytree node.
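The dual return behavior (the class itself, or a decorator) follows a common Python pattern for decorators with optional arguments. The sketch below is a minimal, hypothetical illustration of that dispatch. toy_register_class and REGISTERED are invented names, and the handling of a string first argument is modeled on the @register_pytree_node_class('mylist') usage shown in this section. Rejecting a namespace given both positionally and by keyword is this sketch's own choice.

```python
REGISTERED = []  # hypothetical log of (namespace, class) registrations


def toy_register_class(cls=None, *, namespace=None):
    """Register cls now, or return a decorator if cls is omitted.

    A string first argument is taken as the namespace, mirroring the
    `@register_pytree_node_class('mylist')` usage.
    """
    if isinstance(cls, str):
        if namespace is not None:
            # This sketch's own choice: refuse two conflicting namespaces.
            raise ValueError('namespace given both positionally and by keyword')
        cls, namespace = None, cls
    if namespace is None:
        raise ValueError('a namespace must be specified')

    def decorator(klass: type) -> type:
        REGISTERED.append((namespace, klass))
        return klass  # return the class unchanged so the decorator is transparent

    if cls is None:
        return decorator   # decorator form
    return decorator(cls)  # direct-call form


@toy_register_class(namespace='foo')
class Special:
    pass


@toy_register_class('mylist')
class MyList(list):
    pass


class Plain:
    pass


assert toy_register_class(Plain, namespace='bar') is Plain
print([(ns, klass.__name__) for ns, klass in REGISTERED])
# [('foo', 'Special'), ('mylist', 'MyList'), ('bar', 'Plain')]
```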

Raises:
  • TypeError – If the path entry class is not a subclass of PyTreeEntry.

  • TypeError – If the namespace is not a string.

  • ValueError – If the namespace is an empty string.

  • ValueError – If the type is already registered in the registry.

This function is a thin wrapper around register_pytree_node() that provides a class-oriented interface:

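from collections import UserList

from optree import register_pytree_node_class
from optree.accessor import GetAttrEntry, SequenceEntry


@register_pytree_node_class(namespace='foo')
class Special:
    TREE_PATH_ENTRY_TYPE = GetAttrEntry

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # (children, metadata, path entries): the entries name the attributes
        return ((self.x, self.y), None, ('x', 'y'))

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(*children)


# A string passed as the first argument is used as the namespace
@register_pytree_node_class('mylist')
class MyList(UserList):
    TREE_PATH_ENTRY_TYPE = SequenceEntry

    def tree_flatten(self):
        # The list items are the children; no metadata, default path entries
        return self.data, None, None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # UserList takes a single iterable, so pass the children as one argument
        return cls(children)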
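To show what "thin wrapper" plausibly means here, the sketch below adapts a class's tree_flatten method and tree_unflatten classmethod into the two callables that a functional registration takes. It is a toy model, not optree's code. toy_register_node, toy_register_node_class, FUNCTIONAL_REGISTRY and the stand-in AutoEntry class are hypothetical. The fallback to a TREE_PATH_ENTRY_TYPE class attribute when no path_entry_type is given is an assumption suggested by the examples above.

```python
from collections import UserList
from typing import Any, Dict, Tuple

# Hypothetical functional registry: (namespace, type) -> registration details.
FUNCTIONAL_REGISTRY: Dict[Tuple[str, type], Dict[str, Any]] = {}


class AutoEntry:
    """Stand-in for a default path-entry type (hypothetical, not optree's class)."""


def toy_register_node(cls, flatten_func, unflatten_func, *, path_entry_type=AutoEntry, namespace):
    """Functional API: store the two callables and the path-entry type."""
    FUNCTIONAL_REGISTRY[(namespace, cls)] = {
        'flatten_func': flatten_func,
        'unflatten_func': unflatten_func,
        'path_entry_type': path_entry_type,
    }
    return cls


def toy_register_node_class(cls, *, namespace):
    """Class-oriented API: adapt the class's methods to the functional API."""
    return toy_register_node(
        cls,
        flatten_func=lambda obj: obj.tree_flatten(),  # call the instance method
        unflatten_func=cls.tree_unflatten,  # classmethod taking (metadata, children)
        path_entry_type=getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry),
        namespace=namespace,
    )


class MyList(UserList):
    def tree_flatten(self):
        return self.data, None, None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(children)  # UserList takes a single iterable


toy_register_node_class(MyList, namespace='mylist')
entry = FUNCTIONAL_REGISTRY[('mylist', MyList)]
children, metadata, _ = entry['flatten_func'](MyList([1, 2, 3]))
rebuilt = entry['unflatten_func'](metadata, children)
print(type(rebuilt).__name__, rebuilt)  # MyList [1, 2, 3]
```

Because this MyList defines no TREE_PATH_ENTRY_TYPE, the sketch falls back to its stand-in default entry type.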
optree.unregister_pytree_node(cls, *, namespace)

Remove a type from the pytree node registry.

See also register_pytree_node() and register_pytree_node_class().

This function is the inverse of register_pytree_node().

Parameters:
  • cls (type) – A Python type to remove from the pytree node registry.

  • namespace (str) – The namespace of the pytree node registry to remove the type from.

Return type:

PyTreeNodeRegistryEntry

Returns:

The removed registry entry.

Raises:
  • TypeError – If the input type is not a class.

  • TypeError – If the namespace is not a string.

  • ValueError – If the namespace is an empty string.

  • ValueError – If the type is a built-in type that cannot be unregistered.

  • ValueError – If the type is not found in the registry.

Examples

>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='temp',
... )
<class 'set'>
>>> # Unregister the Python type
>>> unregister_pytree_node(set, namespace='temp')