Integration with functools
| partial | A version of functools.partial() that works in pytrees. |
| reduce | Traverse a pytree and reduce the leaves in left-to-right depth-first order. |
- class optree.functools.partial(func: Callable[[...], Any], *args: T, **keywords: T)
Bases: partial, CustomTreeNode[T]
A version of functools.partial() that works in pytrees.
Use it for partial function evaluation in a way that is compatible with transformations, e.g., partial(func, *args, **kwargs).
(You need to explicitly opt in to this behavior because we did not want to give functools.partial() different semantics than normal function closures.)
For example, here is a basic usage of partial in a manner similar to functools.partial():
>>> import operator
>>> import torch
>>> from optree.functools import partial
>>> add_one = partial(operator.add, torch.ones(()))
>>> add_one(torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
        [4., 5.]])
Pytree compatibility means that the resulting partial function can be passed as an argument within tree-map functions, which is not possible with a standard functools.partial() function:
>>> from optree import tree_map
>>> def call_func_on_cuda(f, *args, **kwargs):
...     f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs))
...     return f(*args, **kwargs)
...
>>> tree_map(lambda t: t.cuda(), add_one)
optree.functools.partial(<built-in function add>, tensor(1., device='cuda:0'))
>>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
        [4., 5.]], device='cuda:0')
Passing zero arguments to partial effectively wraps the original function, making it a valid argument in tree-map functions:
>>> call_func_on_cuda(partial(torch.add), torch.tensor(1), torch.tensor(2))
tensor(3, device='cuda:0')
Had we passed operator.add() to call_func_on_cuda directly, it would have resulted in a TypeError or AttributeError, because a bare function is treated as a leaf and the mapped function would be applied to it.
Create a new partial instance.
- optree.functools.reduce(func, tree, initial=<MISSING>, *, is_leaf=None, none_is_leaf=False, namespace='')
Traverse a pytree and reduce the leaves in left-to-right depth-first order.
See also tree_leaves() and tree_sum().
>>> from optree import tree_reduce
>>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)})
6
>>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
6
>>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3})
3
>>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
None
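To make the traversal order and the handling of None concrete, here is a standard-library-only sketch of the same idea. It is not optree's implementation: sketch_leaves, sketch_reduce, and _MISSING are hypothetical names, only dicts, lists, and tuples are treated as containers, and visiting dictionary keys in sorted order is an assumption made for this sketch.

```python
from functools import reduce

_MISSING = object()  # hypothetical sentinel standing in for the <MISSING> default


def sketch_leaves(tree, none_is_leaf=False):
    """Yield leaves of nested dicts, lists, and tuples, depth-first, left to right."""
    if tree is None:
        if none_is_leaf:
            yield None  # None is reported as a leaf
        return  # otherwise None is a node with no children
    if isinstance(tree, dict):
        for key in sorted(tree):  # assumption: keys are visited in sorted order
            yield from sketch_leaves(tree[key], none_is_leaf)
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            yield from sketch_leaves(child, none_is_leaf)
    else:
        yield tree  # anything else is a leaf


def sketch_reduce(func, tree, initial=_MISSING, none_is_leaf=False):
    """Reduce the leaves with func, seeding with initial when it is given."""
    leaves = sketch_leaves(tree, none_is_leaf)
    if initial is _MISSING:
        return reduce(func, leaves)
    return reduce(func, leaves, initial)


tree = {'x': 1, 'y': (2, None), 'z': 3}
total = sketch_reduce(lambda x, y: x + y, tree)          # leaves: 1, 2, 3
last_truthy = sketch_reduce(lambda x, y: x and y, tree)  # leaves: 1, 2, 3
with_none = sketch_reduce(lambda x, y: x and y, tree, none_is_leaf=True)  # leaves: 1, 2, None, 3
```

With none_is_leaf=True the None inside the tuple takes part in the reduction, so the `and` chain short-circuits to None, mirroring the last doctest above.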
- Parameters:
func (callable) – A function that takes two arguments and returns a value of the same type.
tree (pytree) – A pytree to be traversed.
initial (object, optional) – An initial value to be used for the reduction. If not provided, the first leaf value is used as the initial value.
is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.
none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list and None will remain in the result pytree. (default: False)
namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
- Return type:
TypeVar(T)
- Returns:
The result of reducing the leaves of the pytree using func.