Integrations with Third-Party Libraries

Integration for attrs

field(**kwargs)

Field factory for define().

define([cls])

Attrs class decorator with PyTree integration.

frozen([cls])

Frozen attrs class decorator with PyTree integration.

mutable([cls])

Alias for define().

make_class(name, attrs, /, *, namespace, ...)

Create a new attrs class and register it as a pytree node.

register_node([cls, namespace])

Register an existing attrs class as a pytree node.

AttrsEntry(entry, type, kind)

A path entry class for attrs classes.

optree.integrations.attrs.field(**kwargs)

Field factory for define().

This factory function is used to define the fields of an attrs class. It is similar to attrs.field(), but takes an additional pytree_node parameter. If pytree_node is True (the default behavior), the field is treated as a child node in the PyTree structure and is recursively flattened and unflattened. Otherwise, the field is treated as PyTree metadata.

Setting pytree_node in the field factory is equivalent to setting a key 'pytree_node' in metadata. The pytree_node value can be accessed using field.metadata['pytree_node']. If pytree_node is None, the value metadata.get('pytree_node', True) will be used.

Note

If a field is considered a child node, it must be included in the argument list of the __init__() method, i.e., init=True must be passed to the field factory (the attrs default).

Parameters:
  • pytree_node (bool or None, optional) – Whether the field is a PyTree node.

  • **kwargs (optional) – Optional keyword arguments passed to attrs.field().

Return type:

Any

Returns:

The field defined using the provided arguments with metadata['pytree_node'] set.

Added in version 0.20.0.

optree.integrations.attrs.define(cls=None, /, *, namespace, **kwargs)

Attrs class decorator with PyTree integration.

This is a wrapper around attrs.define() that also registers the class as a pytree node.

Parameters:
  • cls (type or None, optional) – The class to decorate. If None, return a decorator.

  • namespace (str) – The registry namespace used for the PyTree registration.

  • **kwargs (optional) – Optional keyword arguments passed to attrs.define().

Returns:

The decorated class with PyTree integration, or a decorator function if cls is None.

Return type:

type or callable

Added in version 0.20.0.

optree.integrations.attrs.frozen(cls=None, /, *, namespace, **kwargs)

Frozen attrs class decorator with PyTree integration.

This is a convenience wrapper around define() with frozen=True.

Parameters:
  • cls (type or None, optional) – The class to decorate. If None, return a decorator.

  • namespace (str) – The registry namespace used for the PyTree registration.

  • **kwargs (optional) – Optional keyword arguments passed to attrs.define().

Returns:

The decorated class with PyTree integration, or a decorator function if cls is None.

Return type:

type or callable

Added in version 0.20.0.

optree.integrations.attrs.mutable

Alias for define().

optree.integrations.attrs.make_class(name, attrs, /, *, namespace, **kwargs)

Create a new attrs class and register it as a pytree node.

This is a wrapper around attrs.make_class() that also registers the class as a pytree node.

Parameters:
  • name (str) – The name for the new class.

  • attrs (Any) – A list of attribute names, or a dictionary mapping attribute names to attrs.field() calls.

  • namespace (str) – The registry namespace used for the PyTree registration.

  • **kwargs (optional) – Optional keyword arguments passed to attrs.make_class().

Returns:

A new attrs class registered as a pytree node.

Return type:

type

Added in version 0.20.0.

optree.integrations.attrs.register_node(cls=None, /, *, namespace=None)

Register an existing attrs class as a pytree node.

This function takes an existing attrs.define()-decorated class and registers it as a pytree node. It can be used as a direct function call or as a decorator.

Fields with metadata['pytree_node'] set to True (or not set, defaulting to True) are treated as children, while init fields with metadata['pytree_node'] set to False are treated as metadata.

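Usage:

import attrs
from optree.integrations.attrs import register_node

# Direct function call on an existing attrs class
@attrs.define
class Point:
    x: float
    y: float

register_node(Point, namespace='my-namespace')

# As a decorator (applied on top of attrs.define)
@register_node(namespace='my-namespace')
@attrs.define
class Vector:
    x: float
    y: float

Parameters: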
  • cls (type, optional) – An existing attrs-decorated class. If None, return a decorator.

  • namespace (str) – The registry namespace used for the PyTree registration.

Returns:

The same class, now registered as a pytree node, or a decorator function.

Return type:

type or callable

Added in version 0.20.0.

class optree.integrations.attrs.AttrsEntry(entry, type, kind)

Bases: GetAttrEntry

A path entry class for attrs classes.


Integration for JAX

tree_ravel(tree, /[, is_leaf, none_is_leaf, ...])

Ravel (flatten) a pytree of arrays down to a 1D array.

optree.integrations.jax.tree_ravel(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')

Ravel (flatten) a pytree of arrays down to a 1D array.

>>> import jax.numpy as jnp
>>> from optree.integrations.jax import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': jnp.arange(0, 6, dtype=jnp.float32).reshape((2, 3)),
...         'bias': jnp.arange(6, 8, dtype=jnp.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': jnp.arange(8, 10, dtype=jnp.float32).reshape((1, 2)),
...         'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,)),
...     },
... }
>>> tree
{
    'layer1': {
        'weight': Array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=float32)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': Array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}

Parameters:
  • tree (pytree) – A pytree of arrays and scalars to ravel.

  • is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)

Return type:

tuple[Array, Callable[[Array], ArrayTree]]

Returns:

A pair (array, unravel_func) where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D array of the same length back to a pytree of the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention the first component of the output is an empty 1D array of the default dtype.


Integration for NumPy

tree_ravel(tree, /[, is_leaf, none_is_leaf, ...])

Ravel (flatten) a pytree of arrays down to a 1D array.

optree.integrations.numpy.tree_ravel(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')

Ravel (flatten) a pytree of arrays down to a 1D array.

>>> import numpy as np
>>> from optree.integrations.numpy import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
...         'bias': np.arange(6, 8, dtype=np.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': np.arange(8, 10, dtype=np.float32).reshape((1, 2)),
...         'bias': np.arange(10, 11, dtype=np.float32).reshape((1,)),
...     },
... }
>>> tree
{
    'layer1': {
        'weight': array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=float32)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}

Parameters:
  • tree (pytree) – A pytree of arrays and scalars to ravel.

  • is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)

Return type:

tuple[ndarray, Callable[[ndarray], ArrayTree]]

Returns:

A pair (array, unravel_func) where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D array of the same length back to a pytree of the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention the first component of the output is an empty 1D array of the default dtype.


Integration for PyTorch

tree_ravel(tree, /[, is_leaf, none_is_leaf, ...])

Ravel (flatten) a pytree of tensors down to a 1D tensor.

optree.integrations.torch.tree_ravel(tree, /, is_leaf=None, *, none_is_leaf=False, namespace='')

Ravel (flatten) a pytree of tensors down to a 1D tensor.

>>> import torch
>>> from optree.integrations.torch import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': torch.arange(0, 6, dtype=torch.float64).reshape((2, 3)),
...         'bias': torch.arange(6, 8, dtype=torch.float64).reshape((2,)),
...     },
...     'layer2': {
...         'weight': torch.arange(8, 10, dtype=torch.float64).reshape((1, 2)),
...         'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,)),
...     },
... }
>>> tree
{
    'layer1': {
        'weight': tensor([[0., 1., 2.],
                          [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=torch.float64)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': tensor([[0., 1., 2.],
                          [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}

Parameters:
  • tree (pytree) – A pytree of tensors to ravel.

  • is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)

Return type:

tuple[Tensor, Callable[[Tensor], TensorTree]]

Returns:

A pair (tensor, unravel_func) where the first element is a 1D tensor representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D tensor of the same length back to a pytree of the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention the first component of the output is an empty 1D tensor of the default dtype.