Integrations with Third-Party Libraries
Integration for attrs
- field(): Field factory for define().
- define(): Attrs class decorator with PyTree integration.
- frozen(): Frozen attrs class decorator with PyTree integration.
- Alias for …
- make_class(): Create a new attrs class and register it as a pytree node.
- register_node(): Register an existing attrs class as a pytree node.
- A path entry class for attrs classes.
- optree.integrations.attrs.field(**kwargs)
Field factory for define().
This factory function is used to define the fields in an attrs class. It is similar to attrs.field(), but with an additional pytree_node parameter. If pytree_node is True (default), the field will be considered a child node in the PyTree structure, which can be recursively flattened and unflattened. Otherwise, the field will be considered PyTree metadata.
Setting pytree_node in the field factory is equivalent to setting a key 'pytree_node' in metadata. The pytree_node value can be accessed using field.metadata['pytree_node']. If pytree_node is None, the value metadata.get('pytree_node', True) will be used.
Note
If a field is considered a child node, it must be included in the argument list of the __init__() method, i.e., pass init=True in the field factory.
- Parameters:
pytree_node (bool or None, optional) – Whether the field is a PyTree node.
**kwargs (optional) – Optional keyword arguments passed to attrs.field().
- Return type:
- Returns:
The field defined using the provided arguments with metadata['pytree_node'] set.
Added in version 0.20.0.
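The metadata rule above can be illustrated with a stdlib-only sketch. It uses dataclasses.field(metadata=...) as a stand-in for the attrs field machinery (an assumption; the real field() forwards to attrs.field()), and the helpers pytree_field() and split_fields() are hypothetical. The sketch only shows how metadata.get('pytree_node', True) sorts fields into children and metadata.

```python
import dataclasses
from typing import Any, List, Optional, Tuple


def pytree_field(*, pytree_node: Optional[bool] = None, **kwargs: Any) -> Any:
    """Hypothetical stand-in for a field factory with a pytree_node flag."""
    metadata = dict(kwargs.pop('metadata', None) or {})
    if pytree_node is None:
        # Fall back to an existing 'pytree_node' key, defaulting to True (child node).
        pytree_node = metadata.get('pytree_node', True)
    metadata['pytree_node'] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)


def split_fields(cls: type) -> Tuple[List[str], List[str]]:
    """Partition field names into (children, metadata) using the pytree_node flag."""
    children: List[str] = []
    meta: List[str] = []
    for f in dataclasses.fields(cls):
        if f.metadata.get('pytree_node', True):
            children.append(f.name)  # unset or True: a child node
        elif f.init:
            meta.append(f.name)  # init field marked False: metadata
    return children, meta


@dataclasses.dataclass
class Layer:
    weight: list
    bias: list
    name: str = pytree_field(pytree_node=False, default='dense')


print(split_fields(Layer))  # (['weight', 'bias'], ['name'])
```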
- optree.integrations.attrs.define(cls=None, /, *, namespace, **kwargs)
Attrs class decorator with PyTree integration.
This is a wrapper around attrs.define() that also registers the class as a pytree node.
- Parameters:
- Returns:
The decorated class with PyTree integration or decorator function.
- Return type:
type or callable
Added in version 0.20.0.
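The cls=None, / signature lets the same function be called directly on a class or used as a decorator factory with keyword arguments. A minimal sketch of that pattern follows; dataclasses.dataclass stands in for attrs.define() and the REGISTRY dictionary stands in for the pytree node registry. Both substitutions are assumptions for illustration, not optree's implementation.

```python
import dataclasses
from typing import Any, Dict, Optional, Tuple

# Hypothetical registry: (namespace, class) -> class name.
REGISTRY: Dict[Tuple[str, type], str] = {}


def define(cls: Optional[type] = None, /, *, namespace: str, **kwargs: Any) -> Any:
    """Sketch of a define-and-register decorator supporting both call forms."""

    def decorator(inner_cls: type) -> type:
        new_cls = dataclasses.dataclass(inner_cls, **kwargs)  # stand-in for attrs.define()
        REGISTRY[(namespace, new_cls)] = new_cls.__name__  # stand-in for registration
        return new_cls

    if cls is None:
        # Used as @define(namespace=...): return the actual decorator.
        return decorator
    # Used as define(SomeClass, namespace=...): decorate immediately.
    return decorator(cls)


@define(namespace='demo')
class Point:
    x: float
    y: float


class Size:
    w: int
    h: int


Size = define(Size, namespace='demo')
print(Point(1.0, 2.0), Size(3, 4))
```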
- optree.integrations.attrs.frozen(cls=None, /, *, namespace, **kwargs)
Frozen attrs class decorator with PyTree integration.
This is a convenience wrapper around define() with frozen=True.
- Parameters:
- Returns:
The decorated class with PyTree integration or decorator function.
- Return type:
type or callable
Added in version 0.20.0.
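Since frozen() is described as define() with frozen=True pre-set, its shape can be sketched with functools.partial. The sketch below substitutes dataclasses.dataclass for the attrs decorator and omits namespace registration (both assumptions). It shows that instances of a frozen class reject attribute assignment, so rebuilding an object from new child values produces a new instance rather than mutating the old one.

```python
import dataclasses
import functools

# Stand-in: a "frozen" decorator is the base decorator with frozen=True pre-bound.
frozen = functools.partial(dataclasses.dataclass, frozen=True)


@frozen
class Params:
    w: float
    b: float


p = Params(1.0, 0.5)
try:
    p.w = 2.0  # frozen instances reject attribute assignment
except dataclasses.FrozenInstanceError:
    print('assignment rejected')

# Supplying new child values creates a fresh instance instead.
q = dataclasses.replace(p, w=2.0)
print(p, q)
```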
- optree.integrations.attrs.make_class(name, attrs, /, *, namespace, **kwargs)
Create a new attrs class and register it as a pytree node.
This is a wrapper around attrs.make_class() that also registers the class as a pytree node.
- Parameters:
- Returns:
A new attrs class registered as a pytree node.
- Return type:
Added in version 0.20.0.
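attrs.make_class() builds a class at runtime from a name and an attribute specification. The standard library's dataclasses.make_dataclass() is analogous, and the sketch below uses it as a stand-in to create a class whose field metadata carries the pytree_node flag. The make_class helper here and its REGISTRY dictionary are hypothetical.

```python
import dataclasses
from typing import Any, Dict, Iterable, Tuple

# Hypothetical registry: (namespace, class) -> names of child fields.
REGISTRY: Dict[Tuple[str, type], Tuple[str, ...]] = {}


def make_class(name: str, fields_spec: Iterable[Any], /, *, namespace: str) -> type:
    """Sketch: create a class at runtime, then record which fields are children."""
    cls = dataclasses.make_dataclass(name, fields_spec)
    children = tuple(
        f.name for f in dataclasses.fields(cls) if f.metadata.get('pytree_node', True)
    )
    REGISTRY[(namespace, cls)] = children
    return cls


Linear = make_class(
    'Linear',
    [
        ('weight', list),
        ('bias', list),
        # Marked as metadata, so it is excluded from the children.
        ('activation', str, dataclasses.field(default='relu', metadata={'pytree_node': False})),
    ],
    namespace='demo',
)
print(Linear([1.0], [0.0]))
print(REGISTRY[('demo', Linear)])  # ('weight', 'bias')
```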
- optree.integrations.attrs.register_node(cls=None, /, *, namespace=None)
Register an existing attrs class as a pytree node.
This function takes an existing attrs.define()-decorated class and registers it as a pytree node. It can be used as a direct function call or as a decorator.
Fields with metadata['pytree_node'] set to True (or not set, defaulting to True) are treated as children, while init fields with metadata['pytree_node'] set to False are treated as metadata.
Usage:
import attrs
from optree.integrations.attrs import register_node

# Direct function call
register_node(Point, namespace='my-namespace')

# As a decorator
@register_node(namespace='my-namespace')
@attrs.define
class Point:
    x: float
    y: float
- Parameters:
- Returns:
The same class, now registered as a pytree node, or a decorator function.
- Return type:
type or callable
Added in version 0.20.0.
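The children/metadata split described above determines how an instance is taken apart and rebuilt. The following stdlib-only sketch derives a flatten/unflatten pair from a class's field metadata. The helper make_flatten_funcs is hypothetical, dataclasses stand in for attrs classes, and only init fields are considered because both children and metadata are passed back through __init__().

```python
import dataclasses
from typing import Any, Callable, Tuple


def make_flatten_funcs(cls: type) -> Tuple[Callable[..., Any], Callable[..., Any]]:
    """Hypothetical: build (flatten, unflatten) functions from pytree_node metadata."""
    init_fields = [f for f in dataclasses.fields(cls) if f.init]
    child_names = [f.name for f in init_fields if f.metadata.get('pytree_node', True)]
    meta_names = [f.name for f in init_fields if not f.metadata.get('pytree_node', True)]

    def flatten(obj: Any) -> Tuple[tuple, tuple]:
        children = tuple(getattr(obj, n) for n in child_names)
        metadata = tuple(getattr(obj, n) for n in meta_names)
        return children, metadata

    def unflatten(metadata: tuple, children: tuple) -> Any:
        # Rebuild through __init__(), using children and metadata as keyword arguments.
        kwargs = dict(zip(child_names, children))
        kwargs.update(zip(meta_names, metadata))
        return cls(**kwargs)

    return flatten, unflatten


@dataclasses.dataclass
class Point:
    x: float
    y: float
    label: str = dataclasses.field(default='p', metadata={'pytree_node': False})


flatten, unflatten = make_flatten_funcs(Point)
children, metadata = flatten(Point(1.0, 2.0, 'origin'))
print(children, metadata)  # (1.0, 2.0) ('origin',)
print(unflatten(metadata, tuple(c * 10 for c in children)))
```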
Integration for JAX
- tree_ravel(): Ravel (flatten) a pytree of arrays down to a 1D array.
- optree.integrations.jax.tree_ravel(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Ravel (flatten) a pytree of arrays down to a 1D array.
>>> import jax.numpy as jnp
>>> from optree.integrations.jax import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': jnp.arange(0, 6, dtype=jnp.float32).reshape((2, 3)),
...         'bias': jnp.arange(6, 8, dtype=jnp.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': jnp.arange(8, 10, dtype=jnp.float32).reshape((1, 2)),
...         'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,)),
...     },
... }
>>> tree
{
    'layer1': {
        'weight': Array([[0., 1., 2.], [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': Array([[0., 1., 2.], [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}
- Parameters:
tree (pytree) – A pytree of arrays and scalars to ravel.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A pair (array, unravel_func) where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D array of the same length back to a pytree of the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then, as a convention, a 1D empty array of the default dtype is returned in the first component of the output.
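The leaf order in the example above (each bias before its weight, layer1 before layer2) is consistent with visiting dict keys in sorted order. The pure-Python sketch below illustrates the ravel/unravel idea on nested dicts whose leaves are flat lists of numbers. The function ravel_sketch is hypothetical and deliberately simplified: it assumes 1D list leaves and a dict root, and it ignores shapes, dtype promotion, is_leaf, none_is_leaf and namespace. Unlike the documented function, its unravel returns dicts in sorted key order.

```python
from typing import Any, Callable, Dict, List, Tuple


def ravel_sketch(
    tree: Dict[str, Any],
) -> Tuple[List[float], Callable[[List[float]], Dict[str, Any]]]:
    """Concatenate list leaves of a nested dict, visiting keys in sorted order."""
    paths: List[Tuple[str, ...]] = []
    sizes: List[int] = []
    flat: List[float] = []

    def visit(node: Any, path: Tuple[str, ...]) -> None:
        if isinstance(node, dict):
            for key in sorted(node):
                visit(node[key], path + (key,))
        else:  # a leaf: a flat list of numbers
            paths.append(path)
            sizes.append(len(node))
            flat.extend(float(v) for v in node)

    visit(tree, ())

    def unravel(vector: List[float]) -> Dict[str, Any]:
        if len(vector) != len(flat):
            raise ValueError(f'expected length {len(flat)}, got {len(vector)}')
        out: Dict[str, Any] = {}
        offset = 0
        for path, size in zip(paths, sizes):
            node = out
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = list(vector[offset:offset + size])  # slice back each leaf
            offset += size
        return out

    return flat, unravel


params = {
    'layer1': {'weight': [0, 1, 2, 3, 4, 5], 'bias': [6, 7]},
    'layer2': {'weight': [8, 9], 'bias': [10]},
}
flat, unravel = ravel_sketch(params)
print(flat)  # [6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 8.0, 9.0]
```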
Integration for NumPy
- tree_ravel(): Ravel (flatten) a pytree of arrays down to a 1D array.
- optree.integrations.numpy.tree_ravel(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Ravel (flatten) a pytree of arrays down to a 1D array.
>>> import numpy as np
>>> from optree.integrations.numpy import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
...         'bias': np.arange(6, 8, dtype=np.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': np.arange(8, 10, dtype=np.float32).reshape((1, 2)),
...         'bias': np.arange(10, 11, dtype=np.float32).reshape((1,)),
...     },
... }
>>> tree
{
    'layer1': {
        'weight': array([[0., 1., 2.], [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=float32)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': array([[0., 1., 2.], [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
- Parameters:
tree (pytree) – A pytree of arrays and scalars to ravel.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A pair (array, unravel_func) where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D array of the same length back to a pytree of the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then, as a convention, a 1D empty array of the default dtype is returned in the first component of the output.
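The none_is_leaf parameter only changes where None ends up. A minimal pure-Python sketch of a flatten function over nested lists, tuples and dicts (the helper flatten_sketch is hypothetical, not optree's implementation) shows both outcomes: with none_is_leaf=False, None stays in the structure as an arity-0 node and contributes no leaf; with none_is_leaf=True, it is collected as a leaf. A '*' placeholder marks positions whose leaf was removed.

```python
from typing import Any, List, Tuple


def flatten_sketch(tree: Any, none_is_leaf: bool = False) -> Tuple[List[Any], Any]:
    """Return (leaves, structure); '*' marks where a leaf was taken out."""
    leaves: List[Any] = []

    def visit(node: Any) -> Any:
        if node is None and not none_is_leaf:
            return None  # arity-0 node: kept in the structure, no leaf emitted
        if isinstance(node, list):
            return [visit(child) for child in node]
        if isinstance(node, tuple):
            return tuple(visit(child) for child in node)
        if isinstance(node, dict):
            return {key: visit(node[key]) for key in sorted(node)}
        leaves.append(node)  # anything else (including None when none_is_leaf) is a leaf
        return '*'

    structure = visit(tree)
    return leaves, structure


tree = {'a': 1, 'b': None, 'c': (2, None)}
print(flatten_sketch(tree))  # ([1, 2], {'a': '*', 'b': None, 'c': ('*', None)})
print(flatten_sketch(tree, none_is_leaf=True))  # ([1, None, 2, None], {'a': '*', 'b': '*', 'c': ('*', '*')})
```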
Integration for PyTorch
- tree_ravel(): Ravel (flatten) a pytree of tensors down to a 1D tensor.
- optree.integrations.torch.tree_ravel(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')
Ravel (flatten) a pytree of tensors down to a 1D tensor.
>>> import torch
>>> from optree.integrations.torch import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': torch.arange(0, 6, dtype=torch.float64).reshape((2, 3)),
...         'bias': torch.arange(6, 8, dtype=torch.float64).reshape((2,)),
...     },
...     'layer2': {
...         'weight': torch.arange(8, 10, dtype=torch.float64).reshape((1, 2)),
...         'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,)),
...     },
... }
>>> tree
{
    'layer1': {
        'weight': tensor([[0., 1., 2.], [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6., 7., 0., 1., 2., 3., 4., 5., 10., 8., 9.], dtype=torch.float64)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': tensor([[0., 1., 2.], [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}
- Parameters:
tree (pytree) – A pytree of tensors to ravel.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
- Returns:
A pair (tensor, unravel_func) where the first element is a 1D tensor representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D tensor of the same length back to a pytree of the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then, as a convention, a 1D empty tensor of the default dtype is returned in the first component of the output.
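The is_leaf callback is consulted before a node is expanded, and returning True turns the whole subtree into a single leaf. The pure-Python helper below, leaves_sketch, is hypothetical and only sketches that traversal rule: dict keys are visited in sorted order, lists and tuples are expanded, and None handling is omitted (None is simply treated as a leaf here).

```python
from typing import Any, Callable, List, Optional


def leaves_sketch(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Any]:
    """Collect leaves depth-first, stopping wherever is_leaf(node) returns True."""
    leaves: List[Any] = []

    def visit(node: Any) -> None:
        if is_leaf is not None and is_leaf(node):
            leaves.append(node)  # the whole subtree is kept as one leaf
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)
        elif isinstance(node, dict):
            for key in sorted(node):
                visit(node[key])
        else:
            leaves.append(node)

    visit(tree)
    return leaves


tree = {'w': [1, 2], 'b': (3, 4)}
print(leaves_sketch(tree))  # [3, 4, 1, 2]
print(leaves_sketch(tree, is_leaf=lambda x: isinstance(x, tuple)))  # [(3, 4), 1, 2]
```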