Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions backends/cadence/aot/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ fbcode_target(_kind = runtime.python_library,
],
typing = True,
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/exir:pass_base",
"//executorch/backends/test:graph_builder",
],
)

Expand All @@ -239,11 +238,7 @@ fbcode_target(_kind = runtime.python_library,
],
typing = True,
deps = [
":graph_builder",
"fbcode//caffe2:torch",
"fbcode//executorch/exir:lib",
"fbcode//executorch/exir:pass_base",
"fbcode//executorch/exir/verification:verifier",
"//executorch/backends/test:program_builder",
],
)

Expand All @@ -254,7 +249,7 @@ fbcode_target(_kind = python_unittest,
],
typing = True,
deps = [
":program_builder",
"//executorch/backends/test:program_builder",
"//caffe2:torch",
"//later:lib",
],
Expand Down Expand Up @@ -398,7 +393,7 @@ fbcode_target(_kind = python_unittest,
":typing_stubs",
":type_dispatch",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand Down Expand Up @@ -438,7 +433,7 @@ fbcode_target(_kind = python_unittest,
deps = [
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand All @@ -459,7 +454,7 @@ fbcode_target(_kind = python_unittest,
":replace_ops",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand All @@ -480,7 +475,7 @@ fbcode_target(_kind = python_unittest,
"//caffe2:torch",
":typing_stubs",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand All @@ -501,7 +496,7 @@ fbcode_target(_kind = python_unittest,
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:fuse_ops",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir/dialects:lib",
Expand All @@ -522,7 +517,7 @@ fbcode_target(_kind = python_unittest,
":compiler",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:remove_ops",
Expand All @@ -542,7 +537,7 @@ fbcode_target(_kind = python_unittest,
":typing_stubs",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:simplify_ops",
Expand All @@ -562,7 +557,7 @@ fbcode_target(_kind = python_unittest,
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:fuse_ops",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:reorder_ops",
Expand Down Expand Up @@ -632,11 +627,11 @@ fbcode_target(_kind = python_unittest,
":typing_stubs",
":ops_registrations",
":pass_utils",
":program_builder",
"//executorch/backends/test:program_builder",
"//caffe2:torch",
"//executorch/exir:memory",
"//executorch/exir/dialects:lib",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/exir/tests:models",
],
)
Expand All @@ -648,8 +643,7 @@ fbcode_target(_kind = python_unittest,
],
typing = True,
deps = [
":program_builder",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/runtime:runtime",
"//later:lib",
Expand Down Expand Up @@ -679,7 +673,7 @@ fbcode_target(_kind = python_unittest,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/test:graph_builder",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/exir:pass_base",
"//pytorch/ao:torchao",
Expand All @@ -694,7 +688,7 @@ fbcode_target(_kind = python_unittest,
typing = True,
deps = [
":ops_registrations",
":program_builder",
"//executorch/backends/test:program_builder",
":to_out_var_pass",
"//caffe2:torch",
"//executorch/exir:lib",
Expand Down
137 changes: 4 additions & 133 deletions backends/cadence/aot/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,137 +6,8 @@

# pyre-strict

import logging
from typing import Optional, Sequence, Union
# This module has moved to executorch.backends.test.graph_builder.
# This re-export exists for backward compatibility.
from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder

import torch
from executorch.exir.pass_base import (
Argument,
ExportPass,
NodeMetadata,
PassResult,
ProxyValue,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Target
from torch.utils import _pytree as pytree


class GraphBuilder(ExportPass):
"""Utility class for creating a graph module with user-specified ops.

This class allows us to create test graph modules with any ops we want
directly, rather than relying on decomposition or passes.

Usage:
builder = GraphBuilder()
# To insert placeholders, use builder.placeholder.
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
# To insert an op, use builder.call_operator.
op = builder.call_operator(
some_op
(x, other_args, ...),
)
# Insert outputs as a list of ProxyValues using builder.output.
builder.output([op])
# Get GraphModule from builder.
gm = builder.get_graph_module()
"""

def __init__(self, fake_tensor_mode: Optional[FakeTensorMode] = None) -> None:
self.exporter = ExportPass()
self.tracer: ExportPass.ExportTracer = self.ExportTracer(
self, torch.fx.graph.CodeGen()
)
self.fake_tensor_mode: FakeTensorMode = fake_tensor_mode or FakeTensorMode(
allow_fallback_kernels=False,
allow_non_fake_inputs=True,
)
self.tracer.fake_tensor_mode = self.fake_tensor_mode

# This will be called to create nodes in tracer.
self.interpreter = torch.fx.Interpreter(
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
)

# pyre-ignore[14]: Inconsistent override.
def placeholder(
self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
) -> ProxyValue:
if not isinstance(fake_tensor, FakeTensor):
fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
logging.debug(f"Creating placeholder {target} => {fake_tensor.shape}")
placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
return placeholder

# pyre-ignore[14]: Inconsistent override.
def output(self, results: list[ProxyValue]) -> ProxyValue:
logging.debug(f"Creating outputs {results}")
return super().output(results, NodeMetadata({}))

def get_graph_module(self) -> torch.fx.GraphModule:
return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

def call_operator(
self,
op, # pyre-ignore
args: tuple[Argument, ...],
kwargs: Optional[dict[str, Argument]] = None,
meta: Optional[NodeMetadata] = None,
) -> ProxyValue:
if meta is None:
meta = NodeMetadata({})
if kwargs is None:
kwargs = {}
return super().call_operator(op, args, kwargs, meta)

def call_submodule(
self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...]
) -> PassResult:
return ExportPass().call(graph_module)

def call_getitem(
self, value: ProxyValue, key: int, meta: Optional[NodeMetadata] = None
) -> ProxyValue:
return super().call_getitem(value, key, meta or NodeMetadata({}))

def _fx(
self,
kind: str,
target: torch.fx.node.Target,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
with self.fake_tensor_mode, enable_python_dispatcher():
return super()._fx(kind, target, args, kwargs, meta)


def single_op_builder(
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
op: Target,
args: Sequence[Argument],
kwargs: Optional[dict[str, Argument]] = None,
) -> torch.fx.GraphModule:
"""Create a graph module with a single op.

Args:
placeholders: Placeholders to be used as inputs to the GraphModule.
op: The op to be inserted.
args: The args to be passed to the op.
kwargs: The kwargs to be passed to the op.

Returns:
A graph module with a single op
"""
builder = GraphBuilder()
op_to_placeholder_dict = {
p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
}
proxy_args, proxy_kwargs = pytree.tree_map_only(
(torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
)
node = builder.call_operator(op, proxy_args, proxy_kwargs)
builder.output([node])
return builder.get_graph_module()
__all__ = ["GraphBuilder", "single_op_builder"]
Loading
Loading