-
Notifications
You must be signed in to change notification settings - Fork 54
feat: add support for specifying a tuple of axis positions in expand_dims
#988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
In data-apis#354, a regression was introduced which reverted a change to the signature of `expand_dims`. Namely, the `axis` argument should not have been made optional and should not have had a default value. Ref: data-apis#331 Ref: data-apis#354
ev-br
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be very useful to add a comment from #760 (comment)
This behavior is semantically equivalent to calling expand_dims repeatedly with a single axis, only when the axes tuple is normalized to positive values using the final shape, is sorted, and contains no duplicates.
| If ``axis`` is a tuple, | ||
|
|
||
| - each entry of ``axis`` must resolve to a unique axis position. If an entry is a negative integer, the entry **must** resolve to a positive axis position according to the rules described above. | ||
| - if provided an invalid axis position, the function **must** raise an exception. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy raises AxisError, which derives from IndexError (which torch.unsqueeze raises) and ValueError (which jax.numpy raises). So short of adding AxisError with a prescribed inheritance hierarchy we cannot be more specific on what exception to raise.
|
@ev-br Added the desired note. I believe this is ready for another review. |
|
|
||
|
|
||
| def expand_dims(x: array, /, *, axis: int = 0) -> array: | ||
| def expand_dims(x: array, /, axis: int) -> array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SciPy doesn't look too badly impacted by the reversion, I think just https://github.com/scipy/scipy/blob/341152d40c3274c0e37068321cccfb08733e2707/scipy/signal/_filter_design.py#L87
|
Let's open an issue on merge of this to plan a deprecation over at https://data-apis.org/array-api-extra/generated/array_api_extra.expand_dims.html. I'm not sure exactly what strategy is appropriate, maybe good to discuss. |
| axis: Union[int, Tuple[int, ...]] | ||
| axis position(s) (zero-based). If ``axis`` is an integer, | ||
|
|
||
| - a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One idea: would it be clearer here to talk about valid indices in terms of the output dimensions? Then this would change to
- a valid axis position **must** reside on the semi-open interval ``[-M, M)` where
`M = x.ndim + 1` is the number of dimensions of the *output* array.
then the tuple version of this would be identical, except it would say M = ndim(x) + len(axis)
This PR:
expand_dims#760expand_dims. Namely, theaxisargument should not have been made optional and should not have had a default value. This regression had gone unnoticed until working on this PR and a patch has been backported to prior revisions of the standard.expand_dims, thus addressing RFC: add support for a tuple of axes inexpand_dims#760. The added guidance follows the steps outlined in RFC: add support for a tuple of axes inexpand_dims#760 (comment).Notes
array-api-compat.