
Add DecomposeRnnPass for ARM backend #17139

Open
apullin wants to merge 2 commits into pytorch:main from apullin:export-D92059152

Conversation

@apullin
Contributor

@apullin apullin commented Feb 3, 2026

Summary:

Adds a decomposition pass that transforms aten.rnn_tanh.input and
aten.rnn_relu.input into elementary ops supported by TOSA.

RNN cell equation per timestep:
h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)

where activation is tanh (rnn_tanh) or relu (rnn_relu).

Features:

  • Multi-layer RNN support
  • Bidirectional RNN support
  • With/without bias
  • batch_first support
  • Both tanh and relu nonlinearities

Differential Revision: D92059152
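As a reading aid, the per-timestep recurrence above can be sketched in plain NumPy. This is an illustrative sketch only: the function names (`rnn_cell`, `rnn_layer`) are hypothetical, and the actual pass emits the equivalent elementary ops (mm, add, tanh/relu) into the graph rather than calling helpers like these.

```python
import numpy as np

def rnn_cell(x_t, h_prev, W_ih, W_hh, b_ih=None, b_hh=None, act=np.tanh):
    """One timestep: h_t = act(x_t @ W_ih.T + b_ih + h_prev @ W_hh.T + b_hh)."""
    pre = x_t @ W_ih.T + h_prev @ W_hh.T
    if b_ih is not None:
        pre = pre + b_ih
    if b_hh is not None:
        pre = pre + b_hh
    return act(pre)

def rnn_layer(x, h0, W_ih, W_hh, b_ih=None, b_hh=None, act=np.tanh):
    """Unroll one layer over x of shape (seq_len, batch, input_size),
    mirroring how the pass expands the RNN op into per-timestep ops."""
    h, outs = h0, []
    for t in range(x.shape[0]):
        h = rnn_cell(x[t], h, W_ih, W_hh, b_ih, b_hh, act)
        outs.append(h)
    return np.stack(outs), h  # (seq_len, batch, hidden), final hidden state
```

For the rnn_relu variant, `act` would be a ReLU such as `lambda v: np.maximum(v, 0.0)` instead of `np.tanh`.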

@apullin apullin requested a review from digantdesai as a code owner February 3, 2026 07:33
@pytorch-bot

pytorch-bot bot commented Feb 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17139

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit 973c777 with merge base 7c79395:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Feb 3, 2026
@github-actions

github-actions bot commented Feb 3, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@apullin apullin force-pushed the export-D92059152 branch 2 times, most recently from 738855e to 8003c8a Compare February 3, 2026 23:22
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@apullin apullin force-pushed the export-D92059152 branch 2 times, most recently from 8003c8a to 466c9ab Compare February 3, 2026 23:33
@meta-codesync
Contributor

meta-codesync bot commented Feb 3, 2026

@apullin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92059152.

apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
pytorch-bot bot pushed a commit that referenced this pull request Feb 6, 2026
@zingo added the "partner: arm" label (for backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm) Feb 6, 2026
@zingo zingo changed the title Add DecomposeRnnPass for ARM backend Arm backend: Add DecomposeRnnPass Feb 6, 2026
@apullin
Contributor Author

apullin commented Feb 6, 2026

@pytorchbot label "release notes: feature"

apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
@apullin apullin force-pushed the export-D92059152 branch 2 times, most recently from 8356482 to 991e144 Compare March 24, 2026 17:59
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
@apullin apullin force-pushed the export-D92059152 branch 2 times, most recently from 3591ca5 to 2ba998d Compare March 24, 2026 18:08
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 24, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 25, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 25, 2026
@apullin apullin force-pushed the export-D92059152 branch 2 times, most recently from ae79c65 to dc4781b Compare March 25, 2026 16:04
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 25, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 25, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 25, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 26, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 26, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Mar 26, 2026
Andrew Pullin and others added 2 commits March 27, 2026 08:50
Summary:
Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = n_t + z_t * (h_{t-1} - n_t)

Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)

Differential Revision: D92058313
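The gate equations and the batched-gate trick (2 mm ops per timestep instead of 6) can be sketched in NumPy, assuming PyTorch's stacked weight layout where W_ih is [W_ir; W_iz; W_in] and W_hh is [W_hr; W_hz; W_hn]. The function name `gru_cell` is illustrative; the pass itself emits the equivalent mm/sigmoid/tanh/mul/add/slice ops into the graph.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_cell(x_t, h_prev, W_ih, W_hh, b_ih, b_hh):
    """One timestep. W_ih: (3H, I) stacking [W_ir; W_iz; W_in];
    W_hh: (3H, H) stacking [W_hr; W_hz; W_hn]."""
    H = h_prev.shape[-1]
    gi = x_t @ W_ih.T + b_ih      # one mm covers the r, z, n input terms
    gh = h_prev @ W_hh.T + b_hh   # one mm covers the r, z, n hidden terms
    i_r, i_z, i_n = gi[:, :H], gi[:, H:2 * H], gi[:, 2 * H:]
    h_r, h_z, h_n = gh[:, :H], gh[:, H:2 * H], gh[:, 2 * H:]
    r = sigmoid(i_r + h_r)        # reset gate
    z = sigmoid(i_z + h_z)        # update gate
    n = np.tanh(i_n + r * h_n)    # candidate state
    return n + z * (h_prev - n)   # h_t
```

Note that only the hidden-side candidate term `h_n` is gated by `r` before the tanh, which is why the batched `gh` matmul must be sliced before the gates are combined.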
Summary:
Pull Request resolved: pytorch#17139

Differential Revision: D92059152

Labels

- CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
- fb-exported
- meta-exported
- partner: arm (for backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)


3 participants