PyTree Node Registry
register_pytree_node() | Extend the set of types that are considered internal nodes in pytrees.
register_pytree_node_class() | Extend the set of types that are considered internal nodes in pytrees.
unregister_pytree_node() | Remove a type from the pytree node registry.
- optree.register_pytree_node(cls, flatten_func, unflatten_func, *, path_entry_type=<class 'optree.accessor.AutoEntry'>, namespace)
Extend the set of types that are considered internal nodes in pytrees.
See also
register_pytree_node_class() and unregister_pytree_node().
The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces also make it possible to register the same class with different behaviors for different use cases.
Warning
For safety reasons, a namespace must be specified when registering a custom type. The namespace isolates the flattening and unflattening behavior of a pytree node type, which prevents accidental collisions between different libraries that may register the same type.
- Parameters:
cls (type) – A Python type to treat as an internal pytree node.
flatten_func (callable) – A function to be used during flattening, taking an instance of cls and returning a triple or, optionally, a pair: (1) an iterable of the children to be flattened recursively, (2) some hashable auxiliary data to be stored in the treespec and passed to unflatten_func, and (3) (optional) an iterable of the tree path entries for the corresponding children. If the entries are not provided or are None, range(len(children)) is used.
unflatten_func (callable) – A function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treespec, and the unflattened children. The function should return an instance of cls.
path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)
namespace (str) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
- Return type:
Type[CustomTreeNode[TypeVar(T)]]
- Returns:
The same type as the input cls.
- Raises:
TypeError – If the input type is not a class.
TypeError – If the path entry class is not a subclass of PyTreeEntry.
TypeError – If the namespace is not a string.
ValueError – If the namespace is an empty string.
ValueError – If the type is already registered in the registry.
Examples
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='set',
... )
<class 'set'>
>>> # Register a Python type into a namespace
>>> import torch
>>> register_pytree_node(
...     torch.Tensor,
...     flatten_func=lambda tensor: (
...         (tensor.cpu().detach().numpy(),),
...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
...     ),
...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
...     namespace='torch2numpy',
... )
<class 'torch.Tensor'>
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
>>> tree
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>>> # Flatten without specifying the namespace
>>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>>> # Flatten with the namespace
>>> tree_flatten(tree, namespace='torch2numpy')
(
    [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
    PyTreeSpec(
        {
            'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
            'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
        },
        namespace='torch2numpy'
    )
)
>>> # Register the same type with a different namespace for different behaviors
>>> def tensor2flatparam(tensor):
...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
...
>>> def flatparam2tensor(metadata, children):
...     return children[0].reshape(metadata)
...
>>> register_pytree_node(
...     torch.Tensor,
...     flatten_func=tensor2flatparam,
...     unflatten_func=flatparam2tensor,
...     namespace='tensor2flatparam',
... )
<class 'torch.Tensor'>
>>> # Flatten with the new namespace
>>> tree_flatten(tree, namespace='tensor2flatparam')
(
    [
        Parameter containing: tensor([0., 0.], requires_grad=True),
        Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
    ],
    PyTreeSpec(
        {
            'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
            'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
        },
        namespace='tensor2flatparam'
    )
)
- optree.register_pytree_node_class(cls=None, *, path_entry_type=None, namespace=None)
Extend the set of types that are considered internal nodes in pytrees.
See also
register_pytree_node() and unregister_pytree_node().
The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces also make it possible to register the same class with different behaviors for different use cases.
Warning
For safety reasons, a namespace must be specified when registering a custom type. The namespace isolates the flattening and unflattening behavior of a pytree node type, which prevents accidental collisions between different libraries that may register the same type.
- Parameters:
cls (type, optional) – A Python type to treat as an internal pytree node.
path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)
namespace (str, optional) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
- Return type:
Union[Type[CustomTreeNode[TypeVar(T)]], Callable[[Type[CustomTreeNode[TypeVar(T)]]], Type[CustomTreeNode[TypeVar(T)]]]]
- Returns:
The same type as the input cls if it is provided. Otherwise, a decorator function that registers the class as a pytree node.
- Raises:
TypeError – If the path entry class is not a subclass of PyTreeEntry.
TypeError – If the namespace is not a string.
ValueError – If the namespace is an empty string.
ValueError – If the type is already registered in the registry.
This function is a thin wrapper around register_pytree_node() and provides a class-oriented interface:
from collections import UserList

from optree import register_pytree_node_class
from optree.accessor import GetAttrEntry, SequenceEntry

@register_pytree_node_class(namespace='foo')
class Special:
    TREE_PATH_ENTRY_TYPE = GetAttrEntry

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Children, no auxiliary data, and attribute names as path entries
        return ((self.x, self.y), None, ('x', 'y'))

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(*children)

@register_pytree_node_class('mylist')
class MyList(UserList):
    TREE_PATH_ENTRY_TYPE = SequenceEntry

    def tree_flatten(self):
        return self.data, None, None

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # UserList takes a single iterable, so pass the children without unpacking
        return cls(children)
- optree.unregister_pytree_node(cls, *, namespace)
Remove a type from the pytree node registry.
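See also
register_pytree_node() and register_pytree_node_class().
This function is the inverse operation of register_pytree_node().
- Parameters:
cls (type) – A Python type to remove from the pytree node registry.
namespace (str) – The namespace of the pytree node registry to remove the type from.
- Return type: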
PyTreeNodeRegistryEntry
- Returns:
The removed registry entry.
- Raises:
TypeError – If the input type is not a class.
TypeError – If the namespace is not a string.
ValueError – If the namespace is an empty string.
ValueError – If the type is a built-in type that cannot be unregistered.
ValueError – If the type is not found in the registry.
Examples
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='temp',
... )
<class 'set'>
>>> # Unregister the Python type
>>> unregister_pytree_node(set, namespace='temp')