Integration with Third-Party Libraries

Integration for JAX

optree.integration.jax.tree_ravel(tree, is_leaf=None, *, none_is_leaf=False, namespace='')

Ravel (flatten) a pytree of arrays down to a 1D array.

>>> import jax.numpy as jnp
>>> from optree.integration.jax import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': jnp.arange(0, 6, dtype=jnp.float32).reshape((2, 3)),
...         'bias': jnp.arange(6, 8, dtype=jnp.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': jnp.arange(8, 10, dtype=jnp.float32).reshape((1, 2)),
...         'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,))
...     },
... }
>>> tree  
{
    'layer1': {
        'weight': Array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=float32)
>>> unravel_func(flat)  
{
    'layer1': {
        'weight': Array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}

Parameters:
  • tree (pytree) – A pytree of arrays and scalars to ravel.

  • is_leaf (callable, optional) – An optional function called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, while False lets the flattening traverse into the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so it is recorded in the treespec rather than in the leaves list and remains None in the result pytree. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
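To make the is_leaf contract concrete, here is a minimal pure-Python sketch of leaf collection for nested dicts, lists, and tuples. The helper `leaves_sketch` is hypothetical and is not optree's implementation; it assumes dict entries are visited in sorted key order (consistent with the example output above) and drops None, as with the default none_is_leaf=False.

```python
from typing import Any, Callable, List, Optional


def leaves_sketch(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> List[Any]:
    """Hypothetical helper: collect leaves of nested dicts/lists/tuples, honoring is_leaf."""
    if is_leaf is not None and is_leaf(tree):
        # is_leaf returned True: stop here and keep the whole subtree as one leaf.
        return [tree]
    if tree is None:
        # With none_is_leaf=False, None is an arity-0 node and contributes no leaves.
        return []
    if isinstance(tree, dict):
        # Assumption: dict entries are visited in sorted key order.
        children = [tree[key] for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        # Anything else is treated as a leaf.
        return [tree]
    leaves: List[Any] = []
    for child in children:
        leaves.extend(leaves_sketch(child, is_leaf))
    return leaves


tree = {'a': [1, 2], 'b': (3, None)}
print(leaves_sketch(tree))  # [1, 2, 3]
# Treat every list as an opaque leaf: traversal stops at the list itself.
print(leaves_sketch(tree, is_leaf=lambda x: isinstance(x, list)))  # [[1, 2], 3]
```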

Return type:

tuple[Array, Callable[[Array], ArrayTree]], where ArrayTree is the recursive alias Union[Array, Tuple[ArrayTree, ...], List[ArrayTree], Dict[Any, ArrayTree], Deque[ArrayTree], CustomTreeNode[ArrayTree]]

Returns:

A pair (array, unravel_func). The first element is a 1D array containing the flattened and concatenated leaf values, with a dtype obtained by promoting the dtypes of the leaf values. The second element is a callable that unflattens a 1D array of the same length back into a pytree with the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention the first element is an empty 1D array of the default dtype.
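The returned pair can be pictured as "concatenate, then split": record every leaf's shape and dtype, concatenate the raveled leaves into one vector, and let the inverse split that vector at the cumulative leaf sizes. The sketch below assumes the leaves have already been flattened into a list and uses NumPy in place of jax.numpy so it runs without JAX; the name `ravel_leaves_sketch` is hypothetical. Casting each chunk back to its original dtype and rejecting vectors of the wrong length are design choices of this sketch, not documented behavior of tree_ravel.

```python
import numpy as np


def ravel_leaves_sketch(leaves):
    """Hypothetical helper: ravel a flat list of leaves into one 1D array.

    Returns (flat, unravel), where unravel(flat) rebuilds the list of leaves.
    """
    arrays = [np.asarray(leaf) for leaf in leaves]
    shapes = [a.shape for a in arrays]
    dtypes = [a.dtype for a in arrays]
    sizes = [a.size for a in arrays]
    if arrays:
        # Promote the leaf dtypes, then concatenate the raveled leaves.
        dtype = np.result_type(*arrays)
        flat = np.concatenate([a.ravel() for a in arrays]).astype(dtype)
    else:
        # No leaves: an empty 1D array of NumPy's default dtype.
        flat = np.zeros((0,))

    def unravel(vector):
        vector = np.asarray(vector)
        if vector.shape != flat.shape:
            raise ValueError(f'expected shape {flat.shape}, got {vector.shape}')
        # Split at the cumulative leaf sizes, then restore shapes and dtypes.
        chunks = np.split(vector, np.cumsum(sizes)[:-1])
        return [c.reshape(s).astype(d) for c, s, d in zip(chunks, shapes, dtypes)]

    return flat, unravel


leaves = [
    np.array([6.0, 7.0], dtype=np.float32),
    np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
]
flat, unravel = ravel_leaves_sketch(leaves)
print(flat.shape, flat.dtype)            # (8,) float32
print([r.shape for r in unravel(flat)])  # [(2,), (2, 3)]
```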


Integration for NumPy

optree.integration.numpy.tree_ravel(tree, is_leaf=None, *, none_is_leaf=False, namespace='')

Ravel (flatten) a pytree of arrays down to a 1D array.

>>> import numpy as np
>>> from optree.integration.numpy import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
...         'bias': np.arange(6, 8, dtype=np.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': np.arange(8, 10, dtype=np.float32).reshape((1, 2)),
...         'bias': np.arange(10, 11, dtype=np.float32).reshape((1,))
...     },
... }
>>> tree  
{
    'layer1': {
        'weight': array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=float32)
>>> unravel_func(flat)  
{
    'layer1': {
        'weight': array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
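The order of values in flat comes from visiting dict entries in sorted key order, so each layer's 'bias' precedes its 'weight'. The plain-NumPy sketch below rebuilds the same vector from the example tree without optree; `sorted_leaves` is a hypothetical helper that only understands nested dicts.

```python
import numpy as np

tree = {
    'layer1': {
        'weight': np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
        'bias': np.arange(6, 8, dtype=np.float32).reshape((2,)),
    },
    'layer2': {
        'weight': np.arange(8, 10, dtype=np.float32).reshape((1, 2)),
        'bias': np.arange(10, 11, dtype=np.float32).reshape((1,)),
    },
}


def sorted_leaves(node):
    """Hypothetical helper: depth-first traversal visiting dict keys in sorted order."""
    if isinstance(node, dict):
        return [leaf for key in sorted(node) for leaf in sorted_leaves(node[key])]
    return [node]


# Ravel each leaf in row-major order and concatenate them.
flat = np.concatenate([leaf.ravel() for leaf in sorted_leaves(tree)])
print(flat.tolist())  # [6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 8.0, 9.0]
```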

Parameters:
  • tree (pytree) – A pytree of arrays and scalars to ravel.

  • is_leaf (callable, optional) – An optional function called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, while False lets the flattening traverse into the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so it is recorded in the treespec rather than in the leaves list and remains None in the result pytree. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
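With the default none_is_leaf=False, a None entry contributes no leaves, so it never reaches the raveled array and is restored as None by the unravel function. The hypothetical helper below illustrates only this None policy on nested dicts, lists, and tuples (visiting dict keys in sorted order); it is a sketch, not optree's traversal code.

```python
from typing import Any, List


def leaves_with_none_policy(tree: Any, none_is_leaf: bool = False) -> List[Any]:
    """Hypothetical helper: collect leaves, handling None as none_is_leaf describes."""
    if tree is None:
        # none_is_leaf=False: None is an arity-0 node, so it yields no leaves.
        # none_is_leaf=True: None is itself a leaf.
        return [None] if none_is_leaf else []
    if isinstance(tree, dict):
        children = [tree[key] for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return [tree]
    return [leaf for child in children for leaf in leaves_with_none_policy(child, none_is_leaf)]


params = {'weight': 1.0, 'bias': None}
print(leaves_with_none_policy(params))                     # [1.0]
print(leaves_with_none_policy(params, none_is_leaf=True))  # [None, 1.0]
```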

Return type:

tuple[ndarray, Callable[[ndarray], ArrayTree]], where ArrayTree is the recursive alias Union[ndarray, Tuple[ArrayTree, ...], List[ArrayTree], Dict[Any, ArrayTree], Deque[ArrayTree], CustomTreeNode[ArrayTree]]

Returns:

A pair (array, unravel_func). The first element is a 1D array containing the flattened and concatenated leaf values, with a dtype obtained by promoting the dtypes of the leaf values. The second element is a callable that unflattens a 1D array of the same length back into a pytree with the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention the first element is an empty 1D array of the default dtype.
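"Promoting the dtypes of leaf values" can be illustrated with NumPy's own promotion function, np.result_type. That tree_ravel uses exactly this function, and that the default dtype for the empty case is NumPy's float64, are assumptions of this sketch rather than statements from the reference above.

```python
import numpy as np

# Leaves with different dtypes: float32 and int64 promote to float64 under NumPy's rules.
leaves = [np.ones(2, dtype=np.float32), np.arange(3, dtype=np.int64)]
promoted = np.result_type(*leaves)
print(promoted)  # float64

# Concatenate the raveled leaves and cast to the promoted dtype.
flat = np.concatenate([leaf.ravel() for leaf in leaves]).astype(promoted)
print(flat.dtype, flat.shape)  # float64 (5,)

# Empty-tree convention: a 1D array with zero elements and the default dtype.
empty = np.zeros((0,))
print(empty.dtype, empty.shape)  # float64 (0,)
```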


Integration for PyTorch

optree.integration.torch.tree_ravel(tree, is_leaf=None, *, none_is_leaf=False, namespace='')

Ravel (flatten) a pytree of tensors down to a 1D tensor.

>>> import torch
>>> from optree.integration.torch import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': torch.arange(0, 6, dtype=torch.float64).reshape((2, 3)),
...         'bias': torch.arange(6, 8, dtype=torch.float64).reshape((2,)),
...     },
...     'layer2': {
...         'weight': torch.arange(8, 10, dtype=torch.float64).reshape((1, 2)),
...         'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,))
...     },
... }
>>> tree  
{
    'layer1': {
        'weight': tensor([[0., 1., 2.],
                          [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=torch.float64)
>>> unravel_func(flat)  
{
    'layer1': {
        'weight': tensor([[0., 1., 2.],
                          [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}

Parameters:
  • tree (pytree) – A pytree of tensors to ravel.

  • is_leaf (callable, optional) – An optional function called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, while False lets the flattening traverse into the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0, so it is recorded in the treespec rather than in the leaves list and remains None in the result pytree. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)

Return type:

tuple[Tensor, Callable[[Tensor], TensorTree]], where TensorTree is the recursive alias Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]

Returns:

A pair (tensor, unravel_func). The first element is a 1D tensor containing the flattened and concatenated leaf values, with a dtype obtained by promoting the dtypes of the leaf values. The second element is a callable that unflattens a 1D tensor of the same length back into a pytree with the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention the first element is an empty 1D tensor of the default dtype.