Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@
class ArmPass(ExportPass):
"""Base class for Arm passes."""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
if getattr(cls, "targeted_ops", None) is not None:
return
# Only auto-discover targeted_ops for passes that use the standard
# call_operator() pattern. Passes that override call() use _TARGET_OPS
# for their own graph manipulation logic, not as a fast-copy declaration.
if "call" in cls.__dict__:
return
for attr in ("_TARGET_OPS", "_supported_ops"):
ops = getattr(cls, attr, None)
if ops:
cls.targeted_ops = set(ops) if not isinstance(ops, set) else ops # type: ignore[attr-defined]
return
edge = getattr(cls, "_EDGE_OPS", None)
aten = getattr(cls, "_ATEN_OPS", None)
if edge or aten:
cls.targeted_ops = {*(edge or ()), *(aten or ())} # type: ignore[attr-defined]

def __init__(self, tfa_pass: bool = False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.submodule_depth = 0
Expand Down Expand Up @@ -78,6 +97,34 @@ def get_name(pass_) -> str:
f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute."
)

def should_run(self, graph_module: GraphModule) -> bool:
"""Skip this pass if the graph contains none of its targeted ops.

Subclasses that define a ``targeted_ops`` class attribute (a set of
op overloads) get this check for free via inheritance. Passes
without ``targeted_ops`` always run (the default).

Recursively checks control flow submodules (cond/while_loop) so
passes are not incorrectly skipped when targeted ops are nested.

"""
targeted = getattr(self, "targeted_ops", None)
if targeted is None:
return True

from executorch.exir.graph_module import get_control_flow_submodules

def _has_targeted_op(gm: GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target in targeted:
return True
for _, submod, _ in get_control_flow_submodules(gm):
if _has_targeted_op(submod):
return True
return False

return _has_targeted_op(graph_module)

def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
if not updated:
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/_passes/cast_to_int32_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass

from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -35,6 +33,8 @@ class Conv1dUnsqueezePass(ArmPass):
SizeAdjustInputPass,
}

targeted_ops = {exir_ops.edge.aten.convolution.default}

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.convolution.default:
return super().call_operator(op, args, kwargs, meta)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import cast, Set, Type

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
UnsqueezeBeforeRepeatPass,
Expand Down Expand Up @@ -58,6 +57,8 @@ class ConvertExpandCopyToRepeatPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass}

targeted_ops = {exir_ops.edge.aten.expand_copy.default}

expand_copy = exir_ops.edge.aten.expand_copy.default
repeat = exir_ops.edge.aten.repeat.default

Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_full_like_to_full_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -36,6 +35,8 @@ class ConvertFullLikeToFullPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}

targeted_ops = {exir_ops.edge.aten.full_like.default}

def call_operator(self, op, args, kwargs, meta):
if op not in [
exir_ops.edge.aten.full_like.default,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from typing import Sequence, Set, Tuple, Type

from executorch.backends.arm._passes.arm_pass import ArmPass

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

from torch._ops import OpOverload


Expand All @@ -35,6 +33,8 @@ class ConvertPermuteSingletonToViewPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = set(_PERMUTE_TARGETS)

def call_operator(self, op, args, kwargs, meta):
if op not in _PERMUTE_TARGETS:
return super().call_operator(op, args, kwargs, meta)
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/convert_split_to_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class ConvertSplitToSlicePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split_copy.Tensor,
}

split_ops = (
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split_copy.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/convert_squeezes_to_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class ConvertSqueezesToViewPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass}

targeted_ops = {
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
}

def call_operator(self, op, args, kwargs, meta):
if op not in [
exir_ops.edge.aten.squeeze_copy.dims,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/convert_to_clamp_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from typing import Set, Tuple, Type

from executorch.backends.arm._passes import ArmPass

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
QuantizeClampArgumentsPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -32,6 +30,8 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
class ConvertToClampPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass}

targeted_ops = edge_operators

def call_operator(self, op, args, kwargs, meta):
if op not in edge_operators or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_acosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DecomposeAcoshPass(ArmPass):
MatchArgDtypePass,
}

targeted_ops = {edge_acosh_op}

def call_operator(self, op, args, kwargs, meta, updated=False):

if op is not edge_acosh_op:
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_avg_pool2d_pass import (
DecomposeAvgPool2dPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata

Expand Down Expand Up @@ -48,6 +46,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass}

targeted_ops = {*edge_ops, *aten_ops}

def call_operator(self, op, args, kwargs, meta, updated=False):
if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta, updated)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_add_sub_alpha_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DecomposeAddSubAlphaPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {*_ADD_OPS, *_SUB_OPS}

def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
if op not in _ADD_OPS + _SUB_OPS:
return super().call_operator(op, args, kwargs, meta, updated)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/decompose_addmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
Expand Down Expand Up @@ -50,6 +49,8 @@ class DecomposeAddmmPass(ArmPass):
MatchArgDtypePass,
}

targeted_ops = {edge_addmm, aten_addmm}

def call_operator(self, op, args, kwargs, meta):
if op not in [edge_addmm, aten_addmm] or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
1 change: 0 additions & 1 deletion backends/arm/_passes/decompose_as_strided_copy_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict, Optional, Set, Tuple, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm.common.as_strided_utils import (
contiguous_strides,
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
Expand Down Expand Up @@ -72,6 +71,8 @@ class DecomposeAsinAndAcosPass(ArmPass):
ReplaceScalarWithTensorByProfilePass,
}

targeted_ops = {*edge_asin_op, *edge_acos_op}

def _build_polynomial(
self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str]
) -> torch.Tensor:
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_asinh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DecomposeAsinhPass(ArmPass):
MatchArgDtypePass,
}

targeted_ops = {*edge_asinh_op}

def call_operator(self, op, args, kwargs, meta):
if op not in edge_asinh_op:
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_atan_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class DecomposeAtanPass(ArmPass):
ReplaceScalarWithTensorByProfilePass,
}

targeted_ops = {edge_atan}

def _rational_approximation(self, z, ops, meta):
"""Creates a (2,1) Padé approximation for atan(x) on [-1, 1]."""

Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_atanh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class DecomposeAtanhPass(ArmPass):
ReplaceScalarWithTensorByProfilePass,
}

targeted_ops = {edge_atanh}

def call_operator(self, op, args, kwargs, meta):
if op is not edge_atanh:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def get_decomposition(op) -> tuple:
class DecomposeAvgPool2dPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}

targeted_ops = {*edge_div_ops, *aten_div_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_div_ops + aten_div_ops) or not self.allowed_to_transform(
meta
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_cosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class DecomposeCoshPass(ArmPass):
MatchArgDtypePass,
}

targeted_ops = {edge_cosh}

def call_operator(self, op, args, kwargs, meta, updated=False):
if op is not edge_cosh:
return super().call_operator(op, args, kwargs, meta, updated)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/decompose_cosine_similarity_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
)

from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
Expand Down Expand Up @@ -43,6 +42,8 @@ class DecomposeCosineSimilarityPass(ArmPass):
InsertTableOpsPass,
}

targeted_ops = {*torch_cosine_similarity}

def call_operator(self, op, args, kwargs, meta):
if op not in torch_cosine_similarity or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_div_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class DecomposeDivPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}

targeted_ops = {*edge_div_ops, *aten_div_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_div_ops + aten_div_ops) or not self.allowed_to_transform(
meta
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_div_tensor_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class DecomposeDivTensorModePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass}

targeted_ops = {*edge_div_mode_ops, *aten_div_mode_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in (
edge_div_mode_ops + aten_div_mode_ops
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_elu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DecomposeEluPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {*edge_elu_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in edge_elu_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_expm1_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class DecomposeExpm1Pass(ArmPass):
MatchArgRanksPass,
}

targeted_ops = {*edge_expm1_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in edge_expm1_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_floor_divide_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class DecomposeFloorDividePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}

targeted_ops = {*edge_floor_divide_ops, *aten_floor_divide_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_floor_divide_ops + aten_floor_divide_ops):
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_gelu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class DecomposeGeluPass(ArmPass):
MatchArgRanksPass,
}

targeted_ops = {*torch_gelu, *edge_gelu}

def call_operator(self, op, args, kwargs, meta):
if op not in torch_gelu + edge_gelu:
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_glu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class DecomposeGluPass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}

targeted_ops = {edge_glu, aten_glu}

def call_operator(self, op, args, kwargs, meta):
if op not in [edge_glu, aten_glu] or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
Loading
Loading