Shortcuts

Source code for torch.export.exported_program

import copy
import dataclasses
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
)


# Names exported by ``from torch.export.exported_program import *``.
__all__ = ["ExportedProgram", "ModuleCallEntry", "ModuleCallSignature"]


# Type of a graph transformation accepted by ``ExportedProgram._transform``:
# it receives a GraphModule and may return a PassResult, or None.
PassType = Callable[[torch.fx.GraphModule], Union[PassResult, None]]


[docs]@dataclasses.dataclass class ModuleCallSignature: inputs: List[ArgumentSpec] outputs: List[ArgumentSpec] in_spec: pytree.TreeSpec out_spec: pytree.TreeSpec
[docs]@dataclasses.dataclass class ModuleCallEntry: fqn: str signature: Optional[ModuleCallSignature] = None
class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    a :class:`torch.fx.Graph` that represents Tensor computation, a state_dict
    containing tensor values of all lifted parameters and buffers, and various
    metadata.

    You can call an ExportedProgram like the original callable traced by
    :func:`export` with the same calling convention.

    To perform transformations on the graph, use the ``.module()`` method to
    obtain a :class:`torch.fx.GraphModule`. You can then use
    `FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        equality_constraints: List[Tuple[Any, Any]],
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        verifier: Optional[Type[Any]] = None,  # TODO Change typing hint to Verifier.
    ):
        """Build the program's GraphModule from ``root``/``graph`` and verify it.

        Raises:
            AssertionError: if ``verifier`` is not a subclass of
                ``torch._export.verifier.Verifier``. Whatever the verifier's
                ``check`` raises is also propagated.
        """
        from torch._export.exported_program import _create_graph_module_for_export
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            InputDim,
        )

        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
        # InputDim is imported above because annotations on attribute targets
        # are evaluated at runtime.
        self._equality_constraints: List[
            Tuple[InputDim, InputDim]
        ] = equality_constraints
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        from torch._export.verifier import Verifier

        if verifier is None:
            verifier = Verifier
        assert issubclass(verifier, Verifier)
        self._verifier = verifier
        # Validate should be always the last step of the constructor.
        self.verifier().check(self)

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        """The GraphModule holding the exported graph."""
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        """The :class:`torch.fx.Graph` of :attr:`graph_module`."""
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        """The ExportGraphSignature describing graph inputs/outputs."""
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        """Mapping from parameter/buffer name to its tensor value."""
        return self._state_dict

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over original module parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over original module buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over original module buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        for buffer_name in self.graph_signature.buffers:
            yield buffer_name, self.state_dict[buffer_name]

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        """Value ranges for the symbols of the program, as passed at construction."""
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def equality_constraints(self):
        """Pairs of input dimensions constrained to be equal."""
        return self._equality_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        """List of ModuleCallEntry; the first entry (if any) is the root module."""
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        """Example inputs given at construction, or None."""
        return self._example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        """Input/output pytree specs of the root module call.

        Both specs are None when the module call graph is empty.
        """
        from torch._export.exported_program import CallSpec

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        """The verifier class used to check this program."""
        return self._verifier

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        """The ``dialect`` attribute of the verifier class."""
        return self._verifier.dialect

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the program with the same calling convention as the exported callable.

        Mutated buffers produced by the graph are written back into
        :attr:`state_dict`. This happens even if rebuilding the user outputs
        fails.

        Raises:
            TypeError: if the inputs do not match the exported input tree spec.
            torch._export.error.InternalError: if the graph outputs cannot be
                unflattened with the exported output tree spec.
        """
        import torch._export.error as error
        from torch._export import combine_args_kwargs

        if self.call_spec.in_spec is not None:
            # Combine outside the ``try`` so the error path below always has
            # ``user_args`` bound when it reports the received tree spec.
            user_args = combine_args_kwargs(args, kwargs)
            try:
                args = fx_pytree.tree_flatten_spec(
                    user_args, self.call_spec.in_spec, exact_structural_match=True
                )  # type: ignore[assignment]
            except Exception:
                _, received_spec = pytree.tree_flatten(user_args)
                raise TypeError(  # noqa: TRY200
                    "Trying to flatten user inputs with exported input tree spec: \n"
                    f"{self.call_spec.in_spec}\n"
                    "but actually got inputs with tree spec of: \n"
                    f"{received_spec}"
                )

        ordered_params = tuple(
            self.state_dict[name] for name in self.graph_signature.parameters
        )
        ordered_buffers = tuple(
            self.state_dict[name] for name in self.graph_signature.buffers
        )
        self._check_input_constraints(*ordered_params, *ordered_buffers, *args)

        # NOTE: calling convention is first params, then buffers, then args as
        # user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        res = torch.fx.Interpreter(self.graph_module).run(
            *ordered_params, *ordered_buffers, *args, enable_io_processing=False
        )

        if self.call_spec.out_spec is not None:
            mutation = self.graph_signature.buffers_to_mutate
            num_mutated = len(mutation)
            # The graph emits mutated buffer values first, then user outputs.
            mutated_buffers = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: TRY200
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                # Write mutated values back in buffers_to_mutate order. Indexing
                # (rather than zip) keeps an IndexError if the graph returned
                # fewer values than expected.
                for ix, buffer in enumerate(mutation.values()):
                    self.state_dict[buffer] = mutated_buffers[ix]
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(print_output=False).replace(
            "\n", "\n "
        )
        string = (
            "ExportedProgram:\n"
            f" {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
            f"Equality constraints: {self.equality_constraints}\n"
        )
        return string

    def module(self, *, flat: bool = True) -> torch.nn.Module:
        """
        Returns a self contained GraphModule with all the parameters/buffers inlined.

        When ``flat`` is False, the result of
        ``torch._export.unflatten.unflatten(self)`` is returned instead.
        """
        from torch._export.exported_program import unlift_exported_program_lifted_states
        from torch._export.unflatten import unflatten

        if flat:
            return unlift_exported_program_lifted_states(self)
        else:
            return unflatten(self)

    def run_decompositions(
        self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and returns a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.

        Note: this deletes the buffers registered on ``self.graph_module``
        before re-tracing it.
        """
        from torch._decomp import core_aten_decompositions
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            _AddRuntimeAssertionsForInlineConstraintsPass,
            InputDim,
        )
        from torch._export.passes.lift_constant_tensor_pass import (
            lift_constant_tensor_pass,
        )
        from torch._export.passes.replace_sym_size_ops_pass import (
            _replace_sym_size_ops_pass,
        )
        from torch._functorch.aot_autograd import aot_export_module

        def _get_placeholders(gm):
            # Leading run of placeholder nodes, in graph order.
            placeholders = []
            for node in gm.graph.nodes:
                if node.op != "placeholder":
                    break
                placeholders.append(node)
            return placeholders

        decomp_table = decomp_table or core_aten_decompositions()

        old_placeholders = _get_placeholders(self.graph_module)
        fake_args = [node.meta["val"] for node in old_placeholders]

        buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
        for name in buffers_to_remove:
            delattr(self.graph_module, name)
        # TODO(zhxhchen17) Return the new graph_signature directly.
        gm, _ = aot_export_module(
            self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
        )

        # Update the signatures with the new placeholder names in case they
        # changed when calling aot_export
        new_placeholders = _get_placeholders(gm)
        assert len(new_placeholders) == len(old_placeholders)
        old_new_placeholder_map = {
            old_node.name: new_node.name
            for old_node, new_node in zip(old_placeholders, new_placeholders)
        }
        old_outputs = list(self.graph.nodes)[-1].args[0]
        new_outputs = list(gm.graph.nodes)[-1].args[0]
        assert len(new_outputs) == len(old_outputs)
        old_new_output_map = {
            old_node.name: new_node.name
            for old_node, new_node in zip(old_outputs, new_outputs)
        }

        def make_argument_spec(old_node, node) -> ArgumentSpec:
            # Nodes without "val" metadata must be unused; fall back to the
            # value recorded on the corresponding node of the old graph.
            if "val" not in node.meta:
                assert len(node.users) == 0
                val = old_node.meta["val"]
            else:
                val = node.meta["val"]
            if isinstance(val, torch.Tensor):
                return TensorArgument(name=node.name)
            elif isinstance(val, torch.SymInt):
                return SymIntArgument(name=node.name)
            else:
                return ConstantArgument(value=val)

        input_specs, output_specs = _sig_to_specs(
            user_inputs={
                old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs
            },
            inputs_to_parameters={
                old_new_placeholder_map[inp]: param
                for inp, param in self.graph_signature.inputs_to_parameters.items()
            },
            inputs_to_buffers={
                old_new_placeholder_map[inp]: buffer
                for inp, buffer in self.graph_signature.inputs_to_buffers.items()
            },
            user_outputs={
                old_new_output_map[out] for out in self.graph_signature.user_outputs
            },
            buffer_mutations={
                old_new_output_map[out]: buffer
                for out, buffer in self.graph_signature.buffers_to_mutate.items()
            },
            grad_params={},
            grad_user_inputs={},
            loss_output=None,
            # Pair each new placeholder with the old one it replaces directly,
            # instead of relying on global node indices.
            inputs=[
                make_argument_spec(old_node, new_node)
                for old_node, new_node in zip(old_placeholders, new_placeholders)
            ],
            outputs=[
                make_argument_spec(old_outputs[i], node)
                for i, node in enumerate(
                    pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
                )
            ],
        )

        new_graph_signature = ExportGraphSignature(
            input_specs=input_specs, output_specs=output_specs
        )
        # NOTE: aot_export adds symint metadata for placeholders with int
        # values; since these become specialized, we replace such metadata with
        # the original values.
        # Also, set the param/buffer metadata back to the placeholders.
        for old_node, new_node in zip(old_placeholders, new_placeholders):
            if not isinstance(old_node.meta["val"], torch.Tensor):
                new_node.meta["val"] = old_node.meta["val"]

            if (
                new_node.target in new_graph_signature.inputs_to_parameters
                or new_node.target in new_graph_signature.inputs_to_buffers
            ):
                new_node.meta.update(old_node.meta)

        # TODO unfortunately preserving graph-level metadata is not
        # working well with aot_export. So we manually copy it.
        # (The node-level meta is addressed above.)
        gm.meta.update(self.graph_module.meta)

        new_range_constraints = _get_updated_range_constraints(gm)

        new_equality_constraints = [
            (
                InputDim(old_new_placeholder_map[inp_dim1.input_name], inp_dim1.dim),
                InputDim(old_new_placeholder_map[inp_dim2.input_name], inp_dim2.dim),
            )
            for inp_dim1, inp_dim2 in self.equality_constraints
        ]

        state_dict = self.state_dict.copy()
        lift_constant_tensor_pass(gm, new_graph_signature, state_dict)
        _replace_sym_size_ops_pass(gm)
        exported_program = ExportedProgram(
            gm,
            gm.graph,
            new_graph_signature,
            state_dict,
            new_range_constraints,
            new_equality_constraints,
            copy.deepcopy(self.module_call_graph),
            self.example_inputs,
            self.verifier,
        )

        if len(new_range_constraints) > 0 or len(new_equality_constraints) > 0:
            exported_program = exported_program._transform(
                _AddRuntimeAssertionsForInlineConstraintsPass(
                    new_range_constraints, new_equality_constraints
                )
            )

        return exported_program

    def _transform(self, *passes: PassType) -> "ExportedProgram":
        """Run ``passes`` through a PassManager and wrap the result.

        Returns ``self`` when the passes produced no result, or left the same
        GraphModule unmodified. Otherwise returns a new ExportedProgram built
        from the transformed GraphModule.
        """
        pm = PassManager(list(passes))
        res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        # ``res`` may be None; guard before reading ``res.modified``.
        if transformed_gm is self.graph_module and (res is None or not res.modified):
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_graph_inputs = [
                node.name for node in new_gm.graph.nodes if node.op == "placeholder"
            ]
            num_inputs = (
                len(old_signature.parameters)
                + len(old_signature.buffers)
                + len(
                    [
                        s
                        for s in old_signature.input_specs
                        if s.kind == InputKind.USER_INPUT
                    ]
                )
            )

            assert len(new_graph_inputs) == num_inputs, (
                f"Number of input nodes changed from {num_inputs} "
                f"to {len(new_graph_inputs)} after transformation. This transformation "
                "is currently not supported."
            )
            num_param_buffers = len(old_signature.buffers) + len(
                old_signature.parameters
            )
            new_user_inputs = new_graph_inputs[num_param_buffers:]

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"
            new_graph_outputs = [arg.name for arg in output_node.args[0]]

            num_outputs = len(old_signature.buffers_to_mutate) + len(
                [
                    s
                    for s in old_signature.output_specs
                    if s.kind == OutputKind.USER_OUTPUT
                ]
            )
            assert len(new_graph_outputs) == num_outputs, (
                f"Number of output nodes changed from {num_outputs} "
                f"to {len(new_graph_outputs)} "
                "after transformation. This transformation is currently not supported."
            )
            new_user_outputs = new_graph_outputs[len(old_signature.buffers_to_mutate) :]

            def make_argument_spec(node) -> ArgumentSpec:
                val = node.meta["val"]
                if isinstance(val, torch.Tensor):
                    return TensorArgument(name=node.name)
                elif isinstance(val, torch.SymInt):
                    return SymIntArgument(name=node.name)
                else:
                    return ConstantArgument(value=val)

            input_specs, output_specs = _sig_to_specs(
                user_inputs=set(new_user_inputs),
                inputs_to_parameters=old_signature.inputs_to_parameters,
                inputs_to_buffers=old_signature.inputs_to_buffers,
                user_outputs=set(new_user_outputs),
                buffer_mutations=old_signature.buffers_to_mutate,
                grad_params={},
                grad_user_inputs={},
                loss_output=None,
                inputs=[
                    make_argument_spec(node)
                    for node in new_gm.graph.nodes
                    if node.op == "placeholder"
                ],
                outputs=[
                    make_argument_spec(node)
                    for node in pytree.tree_leaves(output_node.args)
                ],
            )
            new_signature = ExportGraphSignature(
                input_specs=input_specs, output_specs=output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            transformed_gm,
            transformed_gm.graph,
            _get_updated_graph_signature(self.graph_signature, transformed_gm),
            self.state_dict,
            _get_updated_range_constraints(transformed_gm),
            copy.deepcopy(self.equality_constraints),
            copy.deepcopy(self._module_call_graph),
            self.example_inputs,
            self.verifier,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        # ``res`` is non-None here: a None result returned ``self`` above.
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, *args):
        """Check ``args`` against the range/equality constraints.

        Builds a throwaway graph that mirrors this program's placeholders, runs
        the runtime-assertion pass over it, and then executes that graph on
        ``args``.
        """
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            _AddRuntimeAssertionsForConstraintsPass,
        )

        # TODO(zhxchen17) Don't generate a runtime graph on the fly.
        _assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
        for p in self.graph.nodes:
            if p.op != "placeholder":
                continue
            new_p = _assertion_graph.graph.placeholder(p.name)
            new_p.meta = p.meta
        _assertion_graph.graph.output(())
        _assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
            self.range_constraints,
            self.equality_constraints,
        )(_assertion_graph)
        assert _assertion_graph_res is not None
        _assertion_graph = _assertion_graph_res.graph_module
        _assertion_graph(*args)

    def _validate(self):
        """Re-run the verifier on this program."""
        self.verifier().check(self)
def _get_updated_range_constraints( gm: torch.fx.GraphModule, ) -> "Dict[sympy.Symbol, Any]": def get_shape_env(gm): vals = [ node.meta["val"] for node in gm.graph.nodes if node.meta.get("val", None) is not None ] from torch._guards import detect_fake_mode fake_mode = detect_fake_mode(vals) if fake_mode is not None: return fake_mode.shape_env for v in vals: if isinstance(v, torch.SymInt): return v.node.shape_env shape_env = get_shape_env(gm) if shape_env is None: return {} range_constraints = { k: v for k, v in shape_env.var_to_range.items() if k not in shape_env.replacements } return range_constraints

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources