Add DecomposeGruPass for ARM backend (#17137) #17137

Open
apullin wants to merge 1 commit into pytorch:main from apullin:export-D92058313

Conversation

@apullin
Contributor

@apullin apullin commented Feb 3, 2026

Summary:

Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
h_t = n_t + z_t * (h_{t-1} - n_t)
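
For readers checking the numerics, the four equations above can be sketched in plain Python for a single unbatched sample. This is only an illustration under stated assumptions, not the pass itself (which rewrites graph nodes and is not shown in this conversation): the names `matvec`, `gru_cell_step`, and `zero_params` are hypothetical, and the parameters are passed as a dict keyed by the symbols used in the equations.

```python
import math

def matvec(W, v):
    """Return W @ v for a list-of-rows matrix W; for a 1-D v this equals v @ W.T."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def gru_cell_step(x_t, h_prev, p):
    """One GRU timestep for a single unbatched sample (illustrative only).

    p holds W_ir, W_iz, W_in (H x I), W_hr, W_hz, W_hn (H x H)
    and the matching bias vectors b_* (length H).
    """
    def affine(w_name, b_name, v):
        # v @ W.T + b for one gate
        return [a + b for a, b in zip(matvec(p[w_name], v), p[b_name])]

    i_r, h_r = affine("W_ir", "b_ir", x_t), affine("W_hr", "b_hr", h_prev)
    i_z, h_z = affine("W_iz", "b_iz", x_t), affine("W_hz", "b_hz", h_prev)
    i_n, h_n = affine("W_in", "b_in", x_t), affine("W_hn", "b_hn", h_prev)

    r_t = [sigmoid(a + b) for a, b in zip(i_r, h_r)]
    z_t = [sigmoid(a + b) for a, b in zip(i_z, h_z)]
    # The reset gate scales only the hidden-side term of the candidate state.
    n_t = [math.tanh(a + r * b) for a, r, b in zip(i_n, r_t, h_n)]
    return [n + z * (h - n) for n, z, h in zip(n_t, z_t, h_prev)]

def zero_params(input_size, hidden_size):
    """Hypothetical helper: all-zero weights and biases for a sanity check."""
    p = {}
    for gate in "rzn":
        p["W_i" + gate] = [[0.0] * input_size for _ in range(hidden_size)]
        p["W_h" + gate] = [[0.0] * hidden_size for _ in range(hidden_size)]
        p["b_i" + gate] = [0.0] * hidden_size
        p["b_h" + gate] = [0.0] * hidden_size
    return p

# With all-zero parameters, z_t = 0.5 and n_t = 0, so h_next is half of h_prev.
h_next = gru_cell_step([1.0, 2.0, 3.0], [1.0, -2.0], zero_params(3, 2))
```

The last line of `gru_cell_step` is algebraically the same as `(1 - z_t) * n_t + z_t * h_{t-1}`, the form used in the `torch.nn.GRU` documentation.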

Features:

  • Multi-layer GRU support
  • Bidirectional GRU support
  • With/without bias
  • batch_first support
  • Batched gate computation (2 mm ops per timestep instead of 6)
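
The "2 mm ops per timestep instead of 6" point can be illustrated with a small sketch. It assumes the three per-gate matrices are stacked row-wise in r, z, n order (the layout `torch.nn.GRU` documents for `weight_ih_l[k]` and `weight_hh_l[k]`), so one product per side yields all three gate pre-activations, which are then sliced apart. The function names and the random test data are hypothetical, and the actual pass may organize this differently.

```python
import math
import random

def matvec(W, v):
    """Return W @ v for a list-of-rows matrix W (equals v @ W.T for 1-D v)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def combine(i_r, i_z, i_n, h_r, h_z, h_n, h_prev):
    """Apply the gate equations to already-computed pre-activations."""
    r = [sigmoid(a + b) for a, b in zip(i_r, h_r)]
    z = [sigmoid(a + b) for a, b in zip(i_z, h_z)]
    n = [math.tanh(a + rk * b) for a, rk, b in zip(i_n, r, h_n)]
    return [nk + zk * (hk - nk) for nk, zk, hk in zip(n, z, h_prev)]

def step_six_mm(x, h, W_i, W_h, b_i, b_h):
    """Per-gate form: one product per weight matrix, six in total."""
    gi = [[a + b for a, b in zip(matvec(W_i[g], x), b_i[g])] for g in "rzn"]
    gh = [[a + b for a, b in zip(matvec(W_h[g], h), b_h[g])] for g in "rzn"]
    return combine(*gi, *gh, h)

def step_two_mm(x, h, W_ih, W_hh, b_ih, b_hh):
    """Batched form: two products on stacked (3H-row) weights, then slices."""
    H = len(h)
    gi = [a + b for a, b in zip(matvec(W_ih, x), b_ih)]
    gh = [a + b for a, b in zip(matvec(W_hh, h), b_hh)]
    # Order: i_r, i_z, i_n, h_r, h_z, h_n
    slices = [v[k * H:(k + 1) * H] for v in (gi, gh) for k in range(3)]
    return combine(*slices, h)

def rand_vec(n):
    return [random.uniform(-1.0, 1.0) for _ in range(n)]

def rand_mat(rows, cols):
    return [rand_vec(cols) for _ in range(rows)]

random.seed(0)
I, H = 3, 2  # illustrative input and hidden sizes
W_i = {g: rand_mat(H, I) for g in "rzn"}
W_h = {g: rand_mat(H, H) for g in "rzn"}
b_i = {g: rand_vec(H) for g in "rzn"}
b_h = {g: rand_vec(H) for g in "rzn"}
x, h = rand_vec(I), rand_vec(H)

# Stack the per-gate parameters row-wise in r, z, n order.
W_ih = W_i["r"] + W_i["z"] + W_i["n"]
W_hh = W_h["r"] + W_h["z"] + W_h["n"]
b_ih = b_i["r"] + b_i["z"] + b_i["n"]
b_hh = b_h["r"] + b_h["z"] + b_h["n"]

six = step_six_mm(x, h, W_i, W_h, b_i, b_h)
two = step_two_mm(x, h, W_ih, W_hh, b_ih, b_hh)
```

Both functions compute the same hidden state; the batched form trades four of the six products for slice operations, which are in the list of TOSA-supported ops named in the summary.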

Differential Revision: D92058313

@apullin apullin requested a review from digantdesai as a code owner February 3, 2026 07:33
@pytorch-bot

pytorch-bot bot commented Feb 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17137

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit e4f7631 with merge base 2a68e74 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Feb 3, 2026
@github-actions

github-actions bot commented Feb 3, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
Summary:

Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = n_t + z_t * (h_{t-1} - n_t)

Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)
---
> Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/)
[Confucius Session](https://www.internalfb.com/confucius?host=62602.od.fbinfra.net&port=8086&tab=Chat&session_id=e1d1ac52-0014-11f1-9d55-75b7d4e71d8a&entry_name=Code+Assist), [Trace](https://www.internalfb.com/confucius?session_id=e1d1ac52-0014-11f1-9d55-75b7d4e71d8a&tab=Trace)
[Confucius Session](https://www.internalfb.com/confucius?host=25384.od.fbinfra.net&port=8086&tab=Chat&session_id=527ee564-00d3-11f1-a194-8754b726bc51&entry_name=Code+Assist), [Trace](https://www.internalfb.com/confucius?session_id=527ee564-00d3-11f1-a194-8754b726bc51&tab=Trace)

Differential Revision: D92058313
@apullin apullin force-pushed the export-D92058313 branch 2 times, most recently from 9436714 to 2c403e6 Compare February 3, 2026 23:33
@meta-codesync
Contributor

meta-codesync bot commented Feb 3, 2026

@apullin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92058313.

@zingo zingo changed the title from "Add DecomposeGruPass for ARM backend" to "Arm backend: Add DecomposeGruPass" Feb 5, 2026
@zingo zingo added the partner: arm and ciflow/trunk labels Feb 5, 2026
@pytorch-bot

pytorch-bot bot commented Feb 5, 2026

To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

apullin pushed a commit to apullin/executorch that referenced this pull request Feb 25, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 27, 2026
@apullin
Contributor Author

apullin commented Feb 27, 2026

@AdrianLundell @martinlsm I think we're seeking an approval from someone inside ARM for a PR like this. Please have a look.

@AdrianLundell
Collaborator

AdrianLundell commented Feb 27, 2026

I'm a bit confused; it seems like this is also added in #17140? And did you look into what @gggekov said about using torch._decomp.get_decompositions?

@apullin
Contributor Author

apullin commented Feb 27, 2026

added also in #17140
@AdrianLundell
This is my mistake - I have several changes in a stack, and PRs got exported for each one, although they are 3 instances of the same technique: adding explicit decomp passes for common RNN layers.

See comment here:
#17140 (comment)

If we want to hit all 3, then yes, #17140 would cover that, instead of #17137 (this) and #17139 separately. Please advise.

@apullin apullin force-pushed the export-D92058313 branch from 1f0686e to 85801e1 Compare March 2, 2026 16:24
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 2, 2026
@apullin apullin force-pushed the export-D92058313 branch from 85801e1 to 135efca Compare March 2, 2026 16:27
@JakeStevens
Contributor

per @gggekov comment on a different pull request:

apologies for the radio silence over the last weeks. Happy to go with your approach of manually decomposing rather than the get_decompositions route. I see you have a few failing tests and a merge conflict, could you resolve these? Your patch looks good to me. I haven't verified the numerics of the decomposition; I assume they must be alright since you pass the unit tests. It would be great to have the LSTM, GRU & RNN available in the Arm backend.

Based on this and internal testing, LGTM.

Contributor

@JakeStevens JakeStevens left a comment

you need to rebase and resolve merge conflicts

@meta-codesync meta-codesync bot changed the title from "Arm backend: Add DecomposeGruPass" to "Add DecomposeGruPass for ARM backend (#17137)" Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
@apullin
Contributor Author

apullin commented Mar 24, 2026

you need to rebase and resolve merge conflicts

Done now, and the other failed test is expected to be green now. However - do we want to approve this one, or #17140, which includes the RNN & LSTM support, following the same pattern?

apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
@AdrianLundell
Collaborator

I would say go with #17140 and get everything at the same time

Labels

CLA Signed, fb-exported, meta-exported, partner: arm

5 participants