Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
QuantizeClampArgumentsPass,
)
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
from .fuse_constant_ops_pass import ( # noqa
ComputeConstantOpsAOTPass,
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
DecorateFp32toInt32CastingPass,
FoldAndAnnotateQParamsPass,
FuseBatchNorm2dPass,
FuseConsecutiveConcatShapesPass,
FuseConsecutiveRescalesPass,
FuseConstantArgsPass,
FuseDuplicateUsersPass,
Expand Down Expand Up @@ -503,6 +504,7 @@ def _tosa_pipeline(
[
CastInt64BuffersToInt32Pass(exported_program),
FuseEqualPlaceholdersPass(exported_program),
FuseConsecutiveConcatShapesPass(),
ToTosaMemoryFormatPass(exported_program),
RemoveNoopPass(),
InsertRescalePass(),
Expand Down
63 changes: 63 additions & 0 deletions backends/arm/_passes/fuse_consecutive_concat_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import NodeMetadata, ProxyValue


class FuseConsecutiveConcatShapesPass(ArmPass):
    """Flatten nested tosa.CONCAT_SHAPE operations into a single one.

    E.g.
    tosa.CONCAT_SHAPE([shape1, tosa.CONCAT_SHAPE([shape2, shape3]), shape4])
    becomes tosa.CONCAT_SHAPE([shape1, shape2, shape3, shape4])

    This is necessary in order for dim-order propagation to work correctly. E.g.
    in the case of dim-order==(0, 2, 3, 1) we would need to permute input shapes
    accordingly. This is much easier if the inputs are flattened.

    """

    _passes_required_after = set()

    def _to_proxy_value(
        self, value: ProxyValue | torch.fx.Node | Any
    ) -> ProxyValue | Any:
        """Wrap a raw fx.Node in a ProxyValue; pass everything else through."""
        if isinstance(value, ProxyValue):
            return value
        if isinstance(value, torch.fx.Node):
            return ProxyValue(value.meta["val"], self.tracer.proxy(value))
        return value

    def call_operator(
        self,
        op: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        meta: NodeMetadata,
        updated: bool | None = False,
    ) -> ProxyValue:
        concat_shape = exir_ops.backend.tosa.CONCAT_SHAPE.default
        if op != concat_shape:
            return super().call_operator(op, args, kwargs, meta)

        flattened: list[Any] = []
        fused_any = False
        for shape_arg in args[0]:
            if hasattr(shape_arg, "node") and shape_arg.node.target == concat_shape:
                # Inline the nested CONCAT_SHAPE's own inputs. Those inputs are
                # raw fx.Nodes on the producer node, so convert each back to a
                # ProxyValue before handing them to call_operator.
                for inner in shape_arg.node.args[0]:
                    flattened.append(self._to_proxy_value(inner))
                fused_any = True
            else:
                flattened.append(shape_arg)

        return super().call_operator(
            op, (flattened,), kwargs, meta, updated=fused_any
        )
110 changes: 110 additions & 0 deletions backends/arm/test/passes/test_fuse_consecutive_concat_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import executorch.backends.arm.tosa.dialect # noqa: F401
from executorch.backends.arm._passes.fuse_consecutive_concat_shapes import (
FuseConsecutiveConcatShapesPass,
)
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


def _graph_module_with_nested_concat():
    """Build a graph where one CONCAT_SHAPE feeds another CONCAT_SHAPE."""
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        builder = GraphBuilder()
        # CONST_SHAPE nodes holding the values 0..3.
        consts = [
            builder.call_operator(
                exir_ops.backend.tosa.CONST_SHAPE.default, ([value],)
            )
            for value in range(4)
        ]
        inner = builder.call_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default,
            ([consts[1], consts[2]],),
        )
        outer = builder.call_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default,
            ([consts[0], inner, consts[3]],),
        )
        builder.output([outer])
        return ExportPass().call(builder.get_graph_module()).graph_module


def _graph_module_with_flat_concat():
    """Build a graph with a single CONCAT_SHAPE over three CONST_SHAPEs."""
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        builder = GraphBuilder()
        # CONST_SHAPE nodes holding the values 4, 5 and 6.
        shapes = [
            builder.call_operator(
                exir_ops.backend.tosa.CONST_SHAPE.default, ([value],)
            )
            for value in (4, 5, 6)
        ]
        concat = builder.call_operator(
            exir_ops.backend.tosa.CONCAT_SHAPE.default, (shapes,)
        )
        builder.output([concat])
        return ExportPass().call(builder.get_graph_module()).graph_module


def _concat_shape_nodes(graph_module):
    """Collect all tosa.CONCAT_SHAPE call_function nodes, in graph order."""
    target = exir_ops.backend.tosa.CONCAT_SHAPE.default
    nodes = []
    for candidate in graph_module.graph.nodes:
        if candidate.op == "call_function" and candidate.target == target:
            nodes.append(candidate)
    return nodes


def _const_shape_values(shape_list_nodes):
return [node.args[0][0] for node in shape_list_nodes]


def _run_fuse_pass(graph_module: GraphModule):
    """Run FuseConsecutiveConcatShapesPass on *graph_module* and prune dead nodes."""
    spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
    with TosaLoweringContext(spec):
        outcome = FuseConsecutiveConcatShapesPass()(graph_module)
    if isinstance(outcome, PassResult):
        graph_module = outcome.graph_module
    # Fused-away inner CONCAT_SHAPE nodes become dead; remove them so the
    # tests can count the remaining CONCAT_SHAPE nodes directly.
    graph_module.graph.eliminate_dead_code()
    return graph_module


def test_fuse_consecutive_concat_shapes_no_target_flattens_nested_concat_inputs():
    fused = _run_fuse_pass(_graph_module_with_nested_concat())

    concat_nodes = _concat_shape_nodes(fused)
    # The inner CONCAT_SHAPE must be fused into the outer one.
    assert len(concat_nodes) == 1

    inputs = concat_nodes[-1].args[0]
    assert _const_shape_values(inputs) == [0, 1, 2, 3]
    const_shape = exir_ops.backend.tosa.CONST_SHAPE.default
    assert all(node.target == const_shape for node in inputs)


def test_fuse_consecutive_concat_shapes_no_target_leaves_flat_concat_unchanged():
    fused = _run_fuse_pass(_graph_module_with_flat_concat())

    concat_nodes = _concat_shape_nodes(fused)
    # Already-flat CONCAT_SHAPE inputs must pass through untouched.
    assert len(concat_nodes) == 1
    assert _const_shape_values(concat_nodes[-1].args[0]) == [4, 5, 6]
Loading