Integration with Third-Party Libraries
Integration for JAX
- optree.integration.jax.tree_ravel(tree, is_leaf=None, *, none_is_leaf=False, namespace='')
Ravel (flatten) a pytree of arrays down to a 1D array.
>>> import jax.numpy as jnp
>>> from optree.integration.jax import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': jnp.arange(0, 6, dtype=jnp.float32).reshape((2, 3)),
...         'bias': jnp.arange(6, 8, dtype=jnp.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': jnp.arange(8, 10, dtype=jnp.float32).reshape((1, 2)),
...         'bias': jnp.arange(10, 11, dtype=jnp.float32).reshape((1,))
...     },
... }
>>> tree
{
    'layer1': {
        'weight': Array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=float32)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': Array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': Array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': Array([[8., 9.]], dtype=float32),
        'bias': Array([10.], dtype=float32)
    }
}
- Parameters:
tree (pytree) – A pytree of arrays and scalars to ravel.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
tuple[Array, Callable[[Array], ArrayTree]], where ArrayTree is the recursive alias Union[Array, Tuple[ArrayTree, ...], List[ArrayTree], Dict[Any, ArrayTree], Deque[ArrayTree], CustomTreeNode[ArrayTree]]
- Returns:
A pair (array, unravel_func), where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of the leaf values, and the second element is a callable for unflattening a 1D array of the same length back to a pytree with the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention a 1D empty array of the default dtype is returned as the first component of the output.
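The interaction of is_leaf and none_is_leaf, and the sorted-key order visible in the flat output above (each 'bias' comes before its 'weight'), can be illustrated with a small pure-Python flattener. The function flatten below is hypothetical: it handles only dict, list, and tuple containers and ignores custom node types and the namespace argument, so it is a sketch of the documented semantics, not optree's implementation.

```python
from typing import Any, Callable, List, Optional


def flatten(tree: Any,
            is_leaf: Optional[Callable[[Any], bool]] = None,
            none_is_leaf: bool = False) -> List[Any]:
    """Hypothetical mini-flattener for dict/list/tuple trees (not optree's code)."""
    # is_leaf returning True stops the traversal: the whole subtree is one leaf.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    # By default None is a non-leaf node with arity 0, so it yields no leaves.
    if tree is None:
        return [None] if none_is_leaf else []
    if isinstance(tree, dict):
        # Dict entries are visited in sorted-key order.
        return [leaf
                for key in sorted(tree)
                for leaf in flatten(tree[key], is_leaf, none_is_leaf)]
    if isinstance(tree, (list, tuple)):
        return [leaf
                for child in tree
                for leaf in flatten(child, is_leaf, none_is_leaf)]
    # Anything else is treated as a leaf.
    return [tree]
```

With the tree {'b': [1, None], 'a': (2, 3)}, this sketch returns [2, 3, 1] by default, [2, 3, 1, None] with none_is_leaf=True, and [(2, 3), 1] when is_leaf treats tuples as leaves.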
Integration for NumPy
- optree.integration.numpy.tree_ravel(tree, is_leaf=None, *, none_is_leaf=False, namespace='')
Ravel (flatten) a pytree of arrays down to a 1D array.
>>> import numpy as np
>>> from optree.integration.numpy import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': np.arange(0, 6, dtype=np.float32).reshape((2, 3)),
...         'bias': np.arange(6, 8, dtype=np.float32).reshape((2,)),
...     },
...     'layer2': {
...         'weight': np.arange(8, 10, dtype=np.float32).reshape((1, 2)),
...         'bias': np.arange(10, 11, dtype=np.float32).reshape((1,))
...     },
... }
>>> tree
{
    'layer1': {
        'weight': array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=float32)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': array([[0., 1., 2.],
                         [3., 4., 5.]], dtype=float32),
        'bias': array([6., 7.], dtype=float32)
    },
    'layer2': {
        'weight': array([[8., 9.]], dtype=float32),
        'bias': array([10.], dtype=float32)
    }
}
- Parameters:
tree (pytree) – A pytree of arrays and scalars to ravel.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
tuple[ndarray, Callable[[ndarray], ArrayTree]], where ArrayTree is the recursive alias Union[ndarray, Tuple[ArrayTree, ...], List[ArrayTree], Dict[Any, ArrayTree], Deque[ArrayTree], CustomTreeNode[ArrayTree]]
- Returns:
A pair (array, unravel_func), where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of the leaf values, and the second element is a callable for unflattening a 1D array of the same length back to a pytree with the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention a 1D empty array of the default dtype is returned as the first component of the output.
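Given the leaves in flattening order (for example, the first output of optree.tree_flatten), the first component of the result can be sketched in a few lines of NumPy. The helper concat_leaves is hypothetical, not optree's implementation; it assumes the promotion rule is np.result_type applied to the leaf dtypes, and that the "default dtype" for the empty case is NumPy's float64.

```python
from typing import Sequence

import numpy as np


def concat_leaves(leaves: Sequence[np.ndarray]) -> np.ndarray:
    """Hypothetical sketch of the first output of tree_ravel, given leaves in order."""
    if not leaves:
        # Convention for a pytree with no leaves: a 1D empty array of the default dtype.
        return np.zeros(0)
    # Scalars become 0-d arrays, which ravel to a single element.
    arrays = [np.asarray(leaf) for leaf in leaves]
    # Promote the leaf dtypes (not their values) to one common dtype.
    dtype = np.result_type(*(a.dtype for a in arrays))
    # Flatten each leaf to 1D, cast it to the common dtype, and join end to end.
    return np.concatenate([a.ravel().astype(dtype) for a in arrays])
```

Fed the four float32 leaves of the example in sorted-key order, this reproduces the flat array shown above. Mixing an int32 leaf with a float32 leaf yields a float64 result under NumPy's promotion rules.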
Integration for PyTorch
- optree.integration.torch.tree_ravel(tree, is_leaf=None, *, none_is_leaf=False, namespace='')
Ravel (flatten) a pytree of tensors down to a 1D tensor.
>>> import torch
>>> from optree.integration.torch import tree_ravel
>>> tree = {
...     'layer1': {
...         'weight': torch.arange(0, 6, dtype=torch.float64).reshape((2, 3)),
...         'bias': torch.arange(6, 8, dtype=torch.float64).reshape((2,)),
...     },
...     'layer2': {
...         'weight': torch.arange(8, 10, dtype=torch.float64).reshape((1, 2)),
...         'bias': torch.arange(10, 11, dtype=torch.float64).reshape((1,))
...     },
... }
>>> tree
{
    'layer1': {
        'weight': tensor([[0., 1., 2.],
                          [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([ 6.,  7.,  0.,  1.,  2.,  3.,  4.,  5., 10.,  8.,  9.], dtype=torch.float64)
>>> unravel_func(flat)
{
    'layer1': {
        'weight': tensor([[0., 1., 2.],
                          [3., 4., 5.]], dtype=torch.float64),
        'bias': tensor([6., 7.], dtype=torch.float64)
    },
    'layer2': {
        'weight': tensor([[8., 9.]], dtype=torch.float64),
        'bias': tensor([10.], dtype=torch.float64)
    }
}
- Parameters:
tree (pytree) – A pytree of tensors to ravel.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
tuple[Tensor, Callable[[Tensor], TensorTree]], where TensorTree is the recursive alias Union[Tensor, Tuple[TensorTree, ...], List[TensorTree], Dict[Any, TensorTree], Deque[TensorTree], CustomTreeNode[TensorTree]]
- Returns:
A pair (tensor, unravel_func), where the first element is a 1D tensor representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of the leaf values, and the second element is a callable for unflattening a 1D tensor of the same length back to a pytree with the same structure as the input tree. If the input pytree is empty (i.e., has no leaves), then by convention a 1D empty tensor of the default dtype is returned as the first component of the output.
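Mapping a flat vector back to the leaves only requires the element count of each leaf in flattening order, which is why the input to unravel_func must be a 1D tensor of the same length as the original. The sketch below is hypothetical and uses plain Python lists instead of tensors: make_unravel records the leaf shapes, computes start offsets, checks the input length, and slices out one chunk per leaf. A complete implementation would additionally reshape each chunk and rebuild the containers from the treespec.

```python
import math
from itertools import accumulate
from typing import Callable, List, Sequence, Tuple


def make_unravel(shapes: Sequence[Tuple[int, ...]]) -> Callable[[List[float]], List[List[float]]]:
    """Hypothetical helper: split a flat vector back into per-leaf chunks."""
    # Element count of each leaf; a scalar shape () has exactly one element.
    sizes = [math.prod(shape) for shape in shapes]
    # Leaf i occupies flat[offsets[i]:offsets[i + 1]].
    offsets = [0, *accumulate(sizes)]
    total = offsets[-1]

    def unravel(flat: List[float]) -> List[List[float]]:
        # The flat input must have exactly the original total length.
        if len(flat) != total:
            raise ValueError(f'expected a flat vector of length {total}, got {len(flat)}')
        return [flat[start:stop] for start, stop in zip(offsets, offsets[1:])]

    return unravel
```

For the example tree, the leaf shapes in flattening order are (2,), (2, 3), (1,), and (1, 2), so the flat vector is cut at offsets 2, 8, and 9.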