Integration with functools

partial(func, *args, **keywords)

A version of functools.partial() that works in pytrees.

reduce(func, tree[, initial, is_leaf, ...])

Traverse a pytree and reduce its leaves in left-to-right depth-first order.

class optree.functools.partial(func: Callable[..., Any], *args: T, **keywords: T)

Bases: partial, CustomTreeNode[T]

A version of functools.partial() that works in pytrees.

Use it, e.g., as partial(func, *args, **kwargs), for partial function evaluation in a way that is compatible with pytree transformations such as tree_map().

(You need to opt in to this behavior explicitly because we did not want to give functools.partial() different semantics from normal function closures.)
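To make the opt-in concrete, here is a minimal pure-Python sketch of the idea, not optree's actual implementation. The hypothetical `TreePartial` subclass of `functools.partial()` exposes its bound positional arguments and keyword values as children and keeps the wrapped function and keyword names as metadata, so a flatten/unflatten round trip can replace the arguments while leaving the function alone. The method names `tree_flatten` and `tree_unflatten` are modeled on the custom-node protocol behind `CustomTreeNode`, but the exact signatures used here are assumptions.

```python
import functools
import operator


class TreePartial(functools.partial):
    """Hypothetical pytree-aware partial (illustration only)."""

    def tree_flatten(self):
        # Children: bound positional args, then keyword values in sorted-key order.
        keys = tuple(sorted(self.keywords))
        children = tuple(self.args) + tuple(self.keywords[k] for k in keys)
        # Metadata: the non-data parts (wrapped function, arg count, keyword names).
        metadata = (self.func, len(self.args), keys)
        return children, metadata

    @classmethod
    def tree_unflatten(cls, metadata, children):
        func, num_args, keys = metadata
        args = children[:num_args]
        keywords = dict(zip(keys, children[num_args:]))
        return cls(func, *args, **keywords)


add_one = TreePartial(operator.add, 1)
children, metadata = add_one.tree_flatten()
# Transform the children only; the wrapped function travels in the metadata.
add_ten = TreePartial.tree_unflatten(metadata, [c * 10 for c in children])
print(children, add_ten(5))  # (1,) 15
```

Keeping `func` in the metadata is the design point: a tree function only ever sees the bound arguments as leaves, so the wrapped callable never has to be a valid leaf itself.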

For example, here is a basic usage of partial in a manner similar to functools.partial():

>>> import operator
>>> import torch
>>> from optree import tree_map
>>> from optree.functools import partial
>>> add_one = partial(operator.add, torch.ones(()))
>>> add_one(torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
        [4., 5.]])

Pytree compatibility means that the resulting partial function can be passed as an argument within tree-map functions, which is not possible with a standard functools.partial() function:

>>> def call_func_on_cuda(f, *args, **kwargs):
...     f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs))
...     return f(*args, **kwargs)
...
>>> tree_map(lambda t: t.cuda(), add_one)
optree.functools.partial(<built-in function add>, tensor(1., device='cuda:0'))
>>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
        [4., 5.]], device='cuda:0')

Passing zero arguments to partial effectively wraps the original function, making it a valid argument in tree-map functions:

>>> call_func_on_cuda(partial(torch.add), torch.tensor(1), torch.tensor(2))
tensor(3, device='cuda:0')

Had we passed the bare function (torch.add() or operator.add()) to call_func_on_cuda directly, tree_map would have treated it as a leaf and tried to call .cuda() on it, resulting in a TypeError or AttributeError.
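The same failure mode can be reproduced without a GPU in a minimal pure-Python sketch. The hypothetical `mini_tree_map` below treats tuples, lists, dicts, and partial objects as internal nodes and everything else as a leaf; for brevity it treats the standard `functools.partial` as a node, standing in for optree's opt-in `partial` (real optree does not do this for `functools.partial`). The hypothetical `scale` stands in for `lambda t: t.cuda()`: it works on data leaves but not on a function.

```python
import functools
import operator


def mini_tree_map(fn, tree):
    """Hypothetical tree_map: recurse into containers and partials, apply fn to leaves."""
    if isinstance(tree, functools.partial):
        # Map over the bound arguments; the wrapped function is structure, not a leaf.
        args = [mini_tree_map(fn, a) for a in tree.args]
        keywords = {k: mini_tree_map(fn, v) for k, v in tree.keywords.items()}
        return functools.partial(tree.func, *args, **keywords)
    if isinstance(tree, (tuple, list)):
        return type(tree)(mini_tree_map(fn, t) for t in tree)
    if isinstance(tree, dict):
        return {k: mini_tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)  # Everything else is a leaf.


def scale(t):
    # Stand-in for `t.cuda()`: only meaningful for data leaves.
    return t * 10


def call_func_scaled(f, *args, **kwargs):
    f, args, kwargs = mini_tree_map(scale, (f, args, kwargs))
    return f(*args, **kwargs)


print(call_func_scaled(functools.partial(operator.add), 1, 2))  # 30
print(call_func_scaled(functools.partial(operator.add, 1), 2))  # 30

try:
    call_func_scaled(operator.add, 1, 2)  # The bare function is treated as a leaf.
except TypeError as exc:
    print("TypeError:", exc)
```

The zero-argument partial contributes no leaves at all, so mapping over it only rebuilds the wrapper around the untouched function.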

Create a new partial instance.

optree.functools.reduce(func, tree, initial=<MISSING>, *, is_leaf=None, none_is_leaf=False, namespace='')

Traverse a pytree and reduce its leaves in left-to-right depth-first order.

See also tree_leaves() and tree_sum().

>>> from optree import tree_reduce
>>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)})
6
>>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
6
>>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3})
3
>>> print(tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True))
None

Parameters:
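The examples above follow from flattening first and folding second. Below is a minimal pure-Python sketch of that two-step process, not optree's implementation: the hypothetical helpers `mini_tree_leaves` and `mini_tree_reduce` handle only tuples, lists, dicts, and None, and assume dict entries are visited in sorted-key order. It shows why None drops out of the leaves by default and why the last example evaluates to None once `none_is_leaf=True` makes it a leaf (`2 and None` is None).

```python
import functools


def mini_tree_leaves(tree, none_is_leaf=False):
    """Hypothetical leaf collector: left-to-right, depth-first."""
    if tree is None:
        # By default None is an internal node with no children (arity 0).
        return [None] if none_is_leaf else []
    if isinstance(tree, (tuple, list)):
        children = list(tree)
    elif isinstance(tree, dict):
        # Assumption of this sketch: dict entries are visited in sorted-key order.
        children = [tree[key] for key in sorted(tree)]
    else:
        return [tree]  # Anything else is a leaf.
    leaves = []
    for child in children:  # Left to right; the recursion makes it depth-first.
        leaves.extend(mini_tree_leaves(child, none_is_leaf))
    return leaves


def mini_tree_reduce(func, tree, none_is_leaf=False):
    """Hypothetical reduce: fold the flattened leaves pairwise, left to right."""
    return functools.reduce(func, mini_tree_leaves(tree, none_is_leaf))


tree = {'x': 1, 'y': (2, None), 'z': 3}
print(mini_tree_leaves(tree))                     # [1, 2, 3]
print(mini_tree_leaves(tree, none_is_leaf=True))  # [1, 2, None, 3]
print(mini_tree_reduce(lambda x, y: x and y, tree, none_is_leaf=True))  # None
```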
  • func (callable) – A function that takes two arguments and returns a value of the same type.

  • tree (pytree) – A pytree to be traversed.

  • initial (object, optional) – An initial value to be used for the reduction. If not provided, the first leaf value is used as the initial value.

  • is_leaf (callable, optional) – An optionally specified function that will be called at each flattening step. It should return a boolean, with True stopping the traversal and the whole subtree being treated as a leaf, and False indicating the flattening should traverse the current object.

  • none_is_leaf (bool, optional) – Whether to treat None as a leaf. If False, None is a non-leaf node with arity 0. Thus None is contained in the treespec rather than in the leaves list, and it is never passed to func. (default: False)

  • namespace (str, optional) – The registry namespace used for custom pytree node types. (default: '', i.e., the global namespace)
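The `initial` rule can be sketched with the standard library alone, assuming (as the parameter description implies) that the collected leaves are folded with `functools.reduce()`. The leaf lists below are written out by hand in the order a left-to-right depth-first traversal with `none_is_leaf=False` would produce them; the trees named in the comments are illustrative.

```python
import functools
import operator

# Leaves written out by hand (none_is_leaf=False), not computed by optree.
leaves_full = [1, 2, 3]  # e.g. from {'x': 1, 'y': (2, None), 'z': 3}
leaves_empty = []        # e.g. from (None, {}): neither contributes a leaf

# Without `initial`, the first leaf seeds the fold: (1 + 2) + 3.
print(functools.reduce(operator.add, leaves_full))      # 6
# With `initial`, it seeds the fold instead: ((10 + 1) + 2) + 3.
print(functools.reduce(operator.add, leaves_full, 10))  # 16
# A tree without leaves yields `initial` unchanged...
print(functools.reduce(operator.add, leaves_empty, 0))  # 0
# ...and raises TypeError when no `initial` is given.
try:
    functools.reduce(operator.add, leaves_empty)
except TypeError as exc:
    print("TypeError:", exc)
```

Under this assumption, passing initial is the safer choice when a pytree may contain only None values or empty containers.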

Return type:

TypeVar(T)

Returns:

The result of reducing the leaves of the pytree using func.