torch.escape-hatch ====================== assume_constant_result ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch import torch._dynamo as torchdynamo class AssumeConstantResult(torch.nn.Module): """ Applying the `assume_constant_result` decorator to burn in non-traceable code as a constant. """ def __init__(self): super().__init__() @torchdynamo.assume_constant_result def get_item(self, y): return y.int().item() def forward(self, x, y): return x[: self.get_item(y)] Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]): # slice_1: f32[3, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4); arg0_1 = None return (slice_1,) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['slice_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Symbol to range: {} constrain_as_size_example ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>`, :doc:`torch.dynamic-value <torch.dynamic-value>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from torch._export.constraints import constrain_as_size def constrain_as_size_example(x): """ If the value is not known at tracing time, you can provide a hint so that we can trace further. Please look at constrain_as_value and constrain_as_size APIs. constrain_as_size is used for values that NEED to be used for constructing a tensor. """ a = x.item() constrain_as_size(a, min=0, max=5) return torch.ones((a, 5)) Result: ..
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: i64[]): # _local_scalar_dense: Sym(i4) = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge: Sym(i4 >= 0) = _local_scalar_dense >= 0 scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None le: Sym(i4 <= 5) = _local_scalar_dense <= 5 scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(le); le = None _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5) full: f32[i4, 5] = torch.ops.aten.full.default([_local_scalar_dense, 5], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None sym_size: Sym(i4) = torch.ops.aten.sym_size.int(full, 0) ge_1: Sym(i4 >= 0) = sym_size >= 0 scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(ge_1); ge_1 = None _assert_async_2 = torch.ops.aten._assert_async.msg(scalar_tensor_2, 'full.shape[0] is outside of inline constraint [0, 5].'); scalar_tensor_2 = None le_1: Sym(i4 <= 5) = sym_size <= 5; sym_size = None scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(le_1); le_1 = None _assert_async_3 = torch.ops.aten._assert_async.msg(scalar_tensor_3, 'full.shape[0] is outside of inline constraint [0, 5].'); scalar_tensor_3 = None return (full,) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['full'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Symbol to range: {i0: RangeConstraint(min_val=2, max_val=5), i1: 
RangeConstraint(min_val=2, max_val=5), i2: RangeConstraint(min_val=2, max_val=5), i3: RangeConstraint(min_val=2, max_val=5), i4: RangeConstraint(min_val=2, max_val=5)} constrain_as_value_example ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>`, :doc:`torch.dynamic-value <torch.dynamic-value>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from torch._export.constraints import constrain_as_value def constrain_as_value_example(x, y): """ If the value is not known at tracing time, you can provide a hint so that we can trace further. Please look at constrain_as_value and constrain_as_size APIs. constrain_as_value is used for values that don't need to be used for constructing a tensor. """ a = x.item() constrain_as_value(a, min=0, max=5) if a < 6: return y.sin() return y.cos() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: i64[], arg1_1: f32[5, 5]): # _local_scalar_dense: Sym(i4) = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge: Sym(i4 >= 0) = _local_scalar_dense >= 0 scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None le: Sym(i4 <= 5) = _local_scalar_dense <= 5 scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(le); le = None _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5); _local_scalar_dense = None sin: f32[5, 5] = torch.ops.aten.sin.default(arg1_1); arg1_1 = None return (sin,) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['sin'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={},
backward_signature=None, assertion_dep_token=None) Symbol to range: {i0: RangeConstraint(min_val=0, max_val=5), i1: RangeConstraint(min_val=0, max_val=5), i2: RangeConstraint(min_val=0, max_val=5), i3: RangeConstraint(min_val=0, max_val=5), i4: RangeConstraint(min_val=0, max_val=5)}