🐛 Describe the bug
Arm RewriteConvPass fails on delegated non-fuseable conv -> relu -> cat branch (INT+FP)
Summary
There is an Arm backend bug in the quantized lowering pipeline for delegated conv -> relu/clamp branches when the activation is not fuseable.
In this case:
- the branch is still delegated under `TOSA-1.0+INT+FP`
- `FoldAndAnnotateQParamsPass` leaves `conv.input_qparams` populated, `conv.output_qparams` empty, and `clamp.output_qparams` populated
- `RewriteConvPass` then assumes the conv itself owns `output_qparams` and crashes
Observed failure:

```
ValueError: RewriteConvPass: No output quantization parameter found in node tosa_conv2d_default
original_aten=aten.convolution.default
```
Why this happens
Arm quantization annotation intentionally treats conv -> relu/hardtanh as:
- conv: quantized inputs
- activation: quantized output
That is fine if the activation can be fused into the conv path.
However, Arm's quantized-activation fusion only accepts activations whose output quantization satisfies `zero_point == qmin`.
When the activation output is affine and ends up with `zp != qmin`, the activation is not fuseable.
The branch then stays as:
conv -> clamp -> quantize
Later, `FoldAndAnnotateQParamsPass` derives qparams from local graph structure:
- `input_qparams` from incoming DQ nodes
- `output_qparams` from direct outgoing Q users
So the metadata becomes:
- conv: `input_qparams` present, `output_qparams` missing
- clamp: `output_qparams` present
Then `RewriteConvPass` rewrites the conv and unconditionally requires `get_output_qparams(conv)[0]`, which crashes.
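The local derivation can be illustrated with a toy model. This is a plain-dict sketch, not the real `FoldAndAnnotateQParamsPass`; node and field names here are assumptions chosen to mirror the description above:

```python
# Toy illustration (NOT the real ExecuTorch pass): qparams are derived
# purely from a node's direct neighbours, so a conv followed by a
# non-fused clamp never receives output_qparams of its own.
graph = [
    {"name": "dq_x",  "op": "dequantize", "inputs": []},
    {"name": "conv",  "op": "conv2d",     "inputs": ["dq_x"]},
    {"name": "clamp", "op": "clamp",      "inputs": ["conv"]},
    {"name": "q_out", "op": "quantize",   "inputs": ["clamp"]},
]
by_name = {n["name"]: n for n in graph}

def fold_qparams(node):
    """Mimic local folding: look only at direct producers and users."""
    meta = {}
    # input_qparams only come from direct dequantize producers
    if any(by_name[i]["op"] == "dequantize" for i in node["inputs"]):
        meta["input_qparams"] = {0: "qparams-from-dq"}
    # output_qparams only come from direct quantize users
    users = [n for n in graph if node["name"] in n["inputs"]]
    if users and all(u["op"] == "quantize" for u in users):
        meta["output_qparams"] = {0: "qparams-from-q"}
    return meta

conv_meta = fold_qparams(by_name["conv"])
clamp_meta = fold_qparams(by_name["clamp"])
print(conv_meta)   # input_qparams present, no output_qparams
print(clamp_meta)  # output_qparams present
# An unconditional conv_meta["output_qparams"][0] lookup then raises,
# analogous to the ValueError from RewriteConvPass.
```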
Minimal Repro
This tiny repro is enough:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.backends.arm.quantizer.arm_quantizer import (
    VgfQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower


class TinyConvReluCat(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(4, 4, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 4, 1)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-0.1, 0.1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        relu_out = F.relu(self.conv1(x))
        merged = torch.cat((relu_out, y), dim=1)
        return self.conv2(merged)


torch.manual_seed(0)
model = TinyConvReluCat().eval()
x = torch.rand(1, 4, 16, 16)
y = torch.rand(1, 4, 16, 16) - 0.065

compile_spec = VgfCompileSpec("TOSA-1.0+INT+FP")
quantizer = VgfQuantizer(compile_spec)
quantizer.set_global(
    get_symmetric_quantization_config(
        is_per_channel=True,
        act_qmin=-127,
        act_qmax=127,
        weight_qmin=-127,
        weight_qmax=127,
    )
)

exported = torch.export.export(model, (x, y)).module(check_guards=False)
quantized = quantizer.quantize_with_submodules(exported, [(x, y)])

for node in quantized.graph.nodes:
    if node.op == "call_function" and "quantize_per_tensor" in str(node.target):
        source = node.args[0]
        if getattr(source, "name", None) == "relu":
            print(f"relu output qparams: scale={float(node.args[1])}, zp={int(node.args[2])}")
            break

exported_q = torch.export.export(quantized, (x, y))
to_edge_transform_and_lower(
    exported_q,
    partitioner=[VgfPartitioner(compile_spec)],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
```

Why the cat is needed
The cat is not incidental. It is what makes the first relu output participate in a shared downstream activation quantization domain instead of staying in the usual fuseable zp == qmin case.
In this tiny repro:
- branch 1 is `conv1 -> relu`
- branch 2 is the sibling tensor `y`
- both feed `cat`
- `cat` uses a shared quantization domain for its inputs
By shifting the calibration distribution of y, the observed qparams for the shared cat domain move away from the fuseable case. In the repro, that gives:
relu output zp = -111
qmin = -127
That makes the activation non-fuseable for Arm's conv + activation fusion logic.
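The effect of the shared domain can be checked with standard affine-quantization arithmetic. This is a sketch; the exact observer in the Arm quantizer may clamp or round differently:

```python
def affine_qparams(rmin, rmax, qmin=-127, qmax=127):
    # Standard affine qparam derivation; observers typically extend
    # the observed range to include zero.
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, zero_point

# relu-only branch: observed range starts at 0, so zp lands on qmin
# and the activation is fuseable.
scale, zp = affine_qparams(0.0, 1.0)
print(zp)  # -127 == qmin

# shared cat domain: the sibling tensor y (rand - 0.065) pulls the
# range minimum below zero, so zp moves away from qmin (non-fuseable).
scale, zp = affine_qparams(-0.065, 0.935)
print(zp)  # e.g. -110, != qmin
```

With a range minimum of roughly -0.065 this lands near the zp = -111, scale ≈ 0.0039 observed in the repro.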
Without the cat, a plain conv -> relu -> conv chain does not reproduce this bug in the same way:
- the first `relu` usually keeps the normal fuseable output qparams
- the backend can treat `conv -> relu` as a fuseable quantized activation pattern
- the first branch lowers successfully
So the cat is needed here because it is the smallest way to force the first activation into a non-fuseable shared activation domain while keeping the rest of the graph tiny.
Repro behavior
This repro deterministically gives the relu output:
relu output qparams: scale=0.003929885104298592, zp=-111
Since:
zp = -111
qmin = -127
the activation is not fuseable.
Lowering then fails with:

```
ValueError: RewriteConvPass: No output quantization parameter found in node tosa_conv2d_default
original_aten=aten.convolution.default
```
Important control case
The same tiny model with `TOSA-1.0+INT` does not reproduce this bug.
Reason:
- in pure INT mode, Arm's partitioner applies stricter quantized-support checks
- the non-fuseable first `conv -> relu` island is rejected during partitioning
- so that branch never reaches `RewriteConvPass`
This is why INT+FP is needed for the minimal repro: it allows the problematic branch to be delegated and lowered.
Root cause
This is an Arm backend pass-pipeline bug:
- quantization annotation intentionally places output quantization on the activation, not the conv
- the activation is non-fuseable because `zp != qmin`
- `FoldAndAnnotateQParamsPass` leaves output qparams on the activation/clamp, not the conv
- `RewriteConvPass` incorrectly assumes every quantized conv owns `output_qparams`
Expected behavior
One of these should happen:
- `RewriteConvPass` should handle `conv -> quantized clamp` correctly by looking through the activation for output qparams.
- Earlier passes should ensure quantized conv nodes always retain compatible `output_qparams` even when followed by a non-fuseable quantized activation.
- Such branches should be rejected consistently before reaching `RewriteConvPass`.
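The first option could look roughly like this hypothetical helper. Names, the dict-based node shape, and the set of look-through ops are assumptions for illustration, not the actual ExecuTorch code:

```python
# Hypothetical sketch (not the real RewriteConvPass): if the conv node
# carries no output_qparams, look through a single following quantized
# activation and borrow its qparams; otherwise signal rejection instead
# of crashing.
LOOK_THROUGH_OPS = {"relu", "clamp", "hardtanh"}

def resolve_output_qparams(node):
    qp = node["meta"].get("output_qparams")
    if qp:
        return qp
    users = node["users"]
    if len(users) == 1 and users[0]["op"] in LOOK_THROUGH_OPS:
        return users[0]["meta"].get("output_qparams")
    return None  # caller should reject the branch rather than raise

clamp = {"op": "clamp", "meta": {"output_qparams": {0: ("scale", "zp")}}, "users": []}
conv = {"op": "conv2d", "meta": {}, "users": [clamp]}
print(resolve_output_qparams(conv))  # borrows the clamp's qparams
```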
Actual behavior
The branch is delegated and reaches `RewriteConvPass`, but the conv has no `output_qparams`, so lowering crashes.
Versions
executorch==1.2.0.dev20260305+cpu
torch==2.10.0
torchao==0.15.0
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell