PyTree Node Registration
register_pytree_node() | Extend the set of types that are considered internal nodes in pytrees.
register_pytree_node_class() | Extend the set of types that are considered internal nodes in pytrees.
unregister_pytree_node() | Remove a type from the pytree node registry.
- optree.register_pytree_node(cls, /, flatten_func, unflatten_func, *, path_entry_type=<class 'optree.AutoEntry'>, namespace)
Extend the set of types that are considered internal nodes in pytrees.
See also: register_pytree_node_class() and unregister_pytree_node().

The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces can also be used to register the same class in different namespaces for different use cases.

Warning: For safety reasons, a namespace must be specified when registering a custom type. It isolates the flattening and unflattening behavior of a pytree node type, which prevents accidental collisions between different libraries that may register the same type.
- Parameters:
  - cls (type) – A Python type to treat as an internal pytree node.
  - flatten_func (callable) – A function used during flattening. It takes an instance of cls and returns a triple (or, optionally, a pair) containing (1) an iterable of the children to be flattened recursively, (2) some hashable metadata to be stored in the treespec and passed to unflatten_func, and (3) (optional) an iterable of the tree path entries for the corresponding children. If the entries are not provided or are None, range(len(children)) is used.
  - unflatten_func (callable) – A function taking two arguments: the metadata that was returned by flatten_func and stored in the treespec, and the unflattened children. It should return an instance of cls.
  - path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)
  - namespace (str) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
- Return type:
  type[Collection[TypeVar(T)]]
- Returns:
  The same type as the input cls.
- Raises:
  - TypeError – If the input type is not a class.
  - TypeError – If the path entry class is not a subclass of PyTreeEntry.
  - TypeError – If the namespace is not a string.
  - ValueError – If the namespace is an empty string.
  - ValueError – If the type is already registered in the registry.
Added in version 0.12.0: The path_entry_type argument, which specifies the path entry type used in PyTreeSpec.accessors() and tree_flatten_with_accessor(). If not provided, AutoEntry is used.

Examples
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='set',
... )
<class 'set'>

>>> # Register a custom type into a namespace with accessor support
>>> import types
>>> # This can be whatever your container type is.
>>> class MyContainer(types.SimpleNamespace):
...     pass
>>> # (Optional) Define a custom path entry type for accessor support.
>>> # Here we showcase how to define one. In practice, you can use the built-in ``GetAttrEntry``.
>>> class MyContainerEntry(PyTreeEntry):
...     def __call__(self, obj):
...         return getattr(obj, self.entry)
...     def codify(self, node=''):
...         return f'{node}.{self.entry}'
>>> register_pytree_node(
...     MyContainer,
...     flatten_func=lambda ct: (
...         list(vars(ct).values()),
...         list(vars(ct).keys()),
...         list(vars(ct).keys()),
...     ),
...     unflatten_func=lambda keys, values: MyContainer(**dict(zip(keys, values))),
...     path_entry_type=MyContainerEntry,
...     namespace='mycontainer',
... )
<class '...MyContainer'>

>>> tree = {'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}

>>> # Flatten without specifying the namespace
>>> tree_flatten(tree)  # `MyContainer`s are leaf nodes
([MyContainer(lr=0.01, momentum=0.9), 1000], PyTreeSpec({'config': *, 'steps': *}))

>>> # Flatten with the namespace
>>> leaves, treespec = tree_flatten(tree, namespace='mycontainer')
>>> leaves, treespec
([0.01, 0.9, 1000], PyTreeSpec({'config': CustomTreeNode(MyContainer[['lr', 'momentum']], [*, *]), 'steps': *}, namespace='mycontainer'))

>>> # Custom ``entries`` are defined as attribute names
>>> tree_paths(tree, namespace='mycontainer')
[('config', 'lr'), ('config', 'momentum'), ('steps',)]

>>> # Custom path entry type defines the pytree access behavior
>>> accessors = tree_accessors(tree, namespace='mycontainer')
>>> accessors[0].codify()
"*['config'].lr"
>>> accessors[0](tree)
0.01

>>> # Unflatten back to a copy of the original object
>>> tree_unflatten(treespec, leaves)
{'config': MyContainer(lr=0.01, momentum=0.9), 'steps': 1000}
- optree.register_pytree_node_class(cls=None, /, *, path_entry_type=None, namespace=None)
Extend the set of types that are considered internal nodes in pytrees.
See also: register_pytree_node() and unregister_pytree_node().

The namespace argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces can also be used to register the same class in different namespaces for different use cases.

Warning: For safety reasons, a namespace must be specified when registering a custom type. It isolates the flattening and unflattening behavior of a pytree node type, which prevents accidental collisions between different libraries that may register the same type.
- Parameters:
  - cls (type, optional) – A Python type to treat as an internal pytree node.
  - path_entry_type (type, optional) – The type of the path entry to be used in the treespec. (default: AutoEntry)
  - namespace (str, optional) – A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type.
- Return type:
  TypeVar(CustomTreeNodeType, bound=type[CustomTreeNode]) | Callable[[TypeVar(CustomTreeNodeType, bound=type[CustomTreeNode])], TypeVar(CustomTreeNodeType, bound=type[CustomTreeNode])]
- Returns:
  The same type as the input cls if that argument is provided. Otherwise, a decorator function that registers the class as a pytree node.
- Raises:
  - TypeError – If the path entry class is not a subclass of PyTreeEntry.
  - TypeError – If the namespace is not a string.
  - TypeError – If the class does not define the required method pairs.
  - ValueError – If the namespace is an empty string.
  - ValueError – If the type is already registered in the registry.
Added in version 0.12.0: The TREE_PATH_ENTRY_TYPE class variable, which specifies the path entry type used in PyTreeSpec.accessors() and tree_flatten_with_accessor(). If not provided, AutoEntry is used.

Added in version 0.18.0: Previously, this function looked for methods named tree_flatten and tree_unflatten on the given class. Since version 0.18.0, it prefers methods named __tree_flatten__ and __tree_unflatten__ instead. The old method names are still supported for backward compatibility, but the new names are recommended. Methods are resolved in this order:
  1. If both __tree_flatten__ and __tree_unflatten__ are defined, use them directly.
  2. If both tree_flatten and tree_unflatten are defined, wrap them as dunder methods.
  3. If neither complete pair is available, raise a TypeError suggesting the new method names.

This function is a thin wrapper around register_pytree_node() and provides a class-oriented interface:

    from collections import UserList

    from optree import GetAttrEntry, SequenceEntry, register_pytree_node_class

    @register_pytree_node_class(namespace='foo')
    class Special:
        TREE_PATH_ENTRY_TYPE = GetAttrEntry

        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __tree_flatten__(self):
            return ((self.x, self.y), None, ('x', 'y'))

        @classmethod
        def __tree_unflatten__(cls, metadata, children):
            return cls(*children)

    @register_pytree_node_class('mylist')
    class MyList(UserList):
        TREE_PATH_ENTRY_TYPE = SequenceEntry

        def __tree_flatten__(self):
            return self.data, None, None

        @classmethod
        def __tree_unflatten__(cls, metadata, children):
            # UserList expects a single iterable, so pass children without unpacking it
            return cls(children)

    # Legacy style (still supported but not recommended)
    @register_pytree_node_class(namespace='legacy')
    class LegacyStyleMyList(UserList):
        def tree_flatten(self):
            # Automatically wrapped as __tree_flatten__
            return self.data, None, None

        @classmethod
        def tree_unflatten(cls, metadata, children):
            # Automatically wrapped as __tree_unflatten__
            return cls(children)
- optree.unregister_pytree_node(cls, /, *, namespace)
Remove a type from the pytree node registry.
See also: register_pytree_node() and register_pytree_node_class().

This function is the inverse of register_pytree_node().
- Parameters:
  - cls (type) – A Python type to remove from the pytree node registry.
  - namespace (str) – The namespace of the pytree registry from which to remove the type.
- Return type:
  PyTreeNodeRegistryEntry
- Returns:
  The removed registry entry.
- Raises:
  - TypeError – If the input type is not a class.
  - TypeError – If the namespace is not a string.
  - ValueError – If the namespace is an empty string.
  - ValueError – If the type is a built-in type that cannot be unregistered.
  - ValueError – If the type is not found in the registry.
Examples
>>> # Register a Python type with lambda functions
>>> register_pytree_node(
...     set,
...     lambda s: (sorted(s), None, None),
...     lambda _, children: set(children),
...     namespace='temp',
... )
<class 'set'>

>>> # Unregister the Python type
>>> unregister_pytree_node(set, namespace='temp')