Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2073,4 +2073,9 @@ def plan(self) -> ExecutionPlan:
self.module.meta["non_const_buffer_sizes"],
),
container_meta_type=self.container_meta_type,
# non_const_buffer_device is set by apply_algo in memory_planning.py
# when device tensors are present. None for CPU-only programs.
non_const_buffer_device=self.module.meta.get(
"non_const_buffer_device", None
),
)
183 changes: 183 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2643,3 +2643,186 @@ def forward(self, a, b):
0,
"No tensor should have CUDA device when model runs entirely on CPU",
)

def test_emit_non_const_buffer_device_populated_for_device_tensors(self) -> None:
    """Check that the emitted ExecutionPlan carries non_const_buffer_device
    entries when device-aware memory planning is enabled and at least one
    planned tensor lives on a non-CPU device."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddOnlySupport(OperatorSupportBase):
        # Accept nothing but the edge-dialect add op.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class CudaTaggingPartitioner(Partitioner):
        def __init__(self):
            super().__init__()
            # Tag delegated partitions with a cuda:0 target so device
            # propagation has non-CPU tensors to plan for.
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            tags = {}
            partitions = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddOnlySupport()),
            )
            for part in partitions:
                tag = f"tag{part.id}"
                for node in part.nodes:
                    node.meta["delegation_tag"] = tag
                tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=tags,
            )

    class AddModule(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge = to_edge(
        export(AddModule(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    et_prog = edge.to_backend(CudaTaggingPartitioner()).to_executorch(
        config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
    )

    plan = et_prog._emitter_output.program.execution_plan[0]
    self.assertIsNotNone(
        plan.non_const_buffer_device,
        "non_const_buffer_device should be set when device tensors are present "
        "and enable_non_cpu_memory_planning is True",
    )
    self.assertGreater(len(plan.non_const_buffer_device), 0)
    # Every emitted device entry must point at cuda:0.
    for entry in plan.non_const_buffer_device:
        self.assertEqual(entry.device_type, schema.DeviceType.CUDA)
        self.assertEqual(entry.device_index, 0)

def test_emit_non_const_buffer_device_none_for_cpu_only(self) -> None:
    """With enable_non_cpu_memory_planning=True but every tensor on CPU,
    the emitted plan's non_const_buffer_device must remain None."""

    class AddModule(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    exported = export(AddModule(), example_inputs)
    et_prog = to_edge(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    ).to_executorch(
        config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
    )

    plan = et_prog._emitter_output.program.execution_plan[0]
    self.assertIsNone(
        plan.non_const_buffer_device,
        "non_const_buffer_device should be None for CPU-only programs",
    )

def test_emit_non_const_buffer_device_none_when_flag_disabled(self) -> None:
    """Even when a delegate tags tensors with a CUDA target device,
    non_const_buffer_device must stay None under the default config
    (enable_non_cpu_memory_planning=False)."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddOnlySupport(OperatorSupportBase):
        # Accept nothing but the edge-dialect add op.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class CudaTaggingPartitioner(Partitioner):
        def __init__(self):
            super().__init__()
            # Same cuda:0 tagging as the positive test; only the backend
            # config differs.
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            tags = {}
            partitions = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddOnlySupport()),
            )
            for part in partitions:
                tag = f"tag{part.id}"
                for node in part.nodes:
                    node.meta["delegation_tag"] = tag
                tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=tags,
            )

    class AddModule(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge = to_edge(
        export(AddModule(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    # Default backend config: enable_non_cpu_memory_planning=False.
    et_prog = edge.to_backend(CudaTaggingPartitioner()).to_executorch()

    plan = et_prog._emitter_output.program.execution_plan[0]
    self.assertIsNone(
        plan.non_const_buffer_device,
        "non_const_buffer_device should be None when "
        "enable_non_cpu_memory_planning is False",
    )
Loading