Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
- id: clang-format
stages: [commit, push, manual]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.255
rev: v0.0.257
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
12 changes: 6 additions & 6 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ def tree_reduce(
def tree_reduce(
func: Callable[[T, T], T],
tree: PyTree[T],
initializer: T = __MISSING,
initial: T = __MISSING,
*,
is_leaf: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
Expand All @@ -907,7 +907,7 @@ def tree_reduce(
def tree_reduce(
func,
tree,
initializer=__MISSING,
initial=__MISSING,
*,
is_leaf=None,
none_is_leaf=False,
Expand All @@ -929,8 +929,8 @@ def tree_reduce(
Args:
func (callable): A function that takes two arguments and returns a value of the same type.
tree (pytree): A pytree to be traversed.
initializer (object, optional): An initial value to be used for the reduction. If not
provided, the first leaf value is used as the initial value.
initial (object, optional): An initial value to be used for the reduction. If not provided,
the first leaf value is used as the initial value.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
Expand All @@ -946,9 +946,9 @@ def tree_reduce(
The result of reducing the leaves of the pytree using ``func``.
""" # pylint: disable=line-too-long
leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
if initializer is __MISSING:
if initial is __MISSING:
return functools.reduce(func, leaves)
return functools.reduce(func, leaves, initializer)
return functools.reduce(func, leaves, initial)


def tree_sum(
Expand Down
2 changes: 1 addition & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class PyTree(Generic[T]): # pylint: disable=too-few-public-methods
"""

@_tp_cache
def __class_getitem__( # noqa: C901
def __class_getitem__(
cls,
item: T | tuple[T] | tuple[T, str | None],
) -> TypeAlias:
Expand Down