python.control-flow
dynamic_shape_if_guard
Original source code:
import torch
class DynamicShapeIfGuard(torch.nn.Module):
    """
    An `if` statement whose predicate depends on a backed dynamic shape will be
    specialized into one particular branch and generate a guard. However,
    export will fail if the dimension is marked as a dynamic shape from a
    higher-level API.
    """

    def forward(self, x):
        # The predicate reads a concrete size from the input's shape, so
        # tracing follows exactly one of the two return paths below.
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2, 2]"):
cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
return (cos,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}
Equality constraints: []
list_unpack
Original source code:
from typing import List
import torch
def list_unpack(args: List[torch.Tensor]):
    """
    Lists are treated as a static construct, therefore unpacking should be
    erased after tracing.

    Returns the sum of the first element of `args` and the second element.
    Any further elements are collected into `y` but are not used.
    """
    # Star-unpacking: `x` takes the head, `y` the (list of) remaining items.
    x, *y = args
    return x + y[0]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]", arg2_1: "i64[]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
Equality constraints: []
static_for_loop
Original source code:
import torch
class StaticForLoop(torch.nn.Module):
    """
    A for loop with a constant number of iterations should be unrolled in the
    exported graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a list of ten results, `i + x` for each `i` in `range(10)`."""
        ret = []
        # The explicit loop is intentional: it is the construct this example
        # demonstrates, so it is not rewritten as a comprehension.
        for i in range(10):  # constant
            ret.append(i + x)
        return ret
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 0)
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1)
add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 2)
add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 3)
add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 4)
add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 5)
add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 6)
add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 7)
add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 8)
add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 9); arg0_1 = None
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
Range constraints: {}
Equality constraints: []
static_if
Original source code:
import torch
class StaticIf(torch.nn.Module):
    """
    An `if` statement with a static predicate value should be traced through
    with the taken branch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Add a broadcast tensor of ones to 3-D inputs; return others as-is."""
        # The predicate depends only on the number of dimensions of `x`,
        # not on any of its sizes or values.
        if len(x.shape) == 3:
            return x + torch.ones(1, 1, 1)

        return x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2, 2]"):
ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(arg0_1, ones); arg0_1 = ones = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
Equality constraints: []