Tensor Parallelism - torch.distributed.tensor.parallel¶
Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides several parallelism styles: Rowwise and Colwise Parallelism.
Warning
Tensor Parallelism APIs are experimental and subject to change.
The entrypoint to parallelize your nn.Module
using Tensor Parallelism is:
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)[source]¶
Apply Tensor Parallelism (TP) in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
We parallelize the module or sub_modules based on a parallelize_plan. The parallelize_plan contains ParallelStyle objects, which indicate how the user wants the module or sub_module to be parallelized. The user can also specify a different parallel style per module fully qualified name (FQN). The API supports 2D parallelism natively by accepting an n-dimensional device_mesh; users just need to specify the dimension on which tensor parallelism is performed.
- Parameters
module (nn.Module) – Module to be parallelized.
device_mesh (DeviceMesh) – Object which describes the mesh topology of devices for the DTensor.
parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – The plan used to parallelize the module. It can be either a ParallelStyle object, which describes how to prepare input/output for Tensor Parallelism, or a dict mapping module FQNs to their corresponding ParallelStyle objects.
tp_mesh_dim (int) – The dimension of device_mesh on which we perform Tensor Parallelism.
- Returns
A parallelized nn.Module object.
- Return type
Module
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>>
>>> # Define the module.
>>> m = Model(...)
>>> m = parallelize_module(m, ColwiseParallel())
>>>
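As noted above, the plan can also be a dict keyed by module FQN, combined with an explicit device mesh. The following is a minimal sketch, not part of the API reference: the submodule names "net1"/"net2" are hypothetical, the DeviceMesh import path may differ between releases, and m is assumed to be an already-constructed model as in the example above.
>>> from torch.distributed._tensor import DeviceMesh  # import path may vary by release
>>> from torch.distributed.tensor.parallel import (
>>>     parallelize_module, ColwiseParallel, RowwiseParallel,
>>> )
>>>
>>> # 1-D mesh over 4 GPUs; "net1"/"net2" are hypothetical submodule FQNs of m.
>>> mesh = DeviceMesh("cuda", [0, 1, 2, 3])
>>> plan = {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
>>> m = parallelize_module(m, mesh, plan)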
Warning
Currently, there are some constraints that make it hard for complicated modules like MultiheadAttention to work out of the box for Tensor or Sequence Parallelism. We recommend that users try ColwiseParallel and RowwiseParallel for each parameter or submodule; some code changes may be needed for now.
Tensor Parallelism supports the following parallel styles:
- class torch.distributed.tensor.parallel.style.RowwiseParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=Shard(dim=-1), output_layouts=Replicate(), use_local_output=True)[source]¶
Partition the rows of a module.
We assume the input to be a sharded DTensor and the output to be a torch.Tensor.
- Parameters
input_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout of the input tensor(s), upon which the DTensor will be created.
output_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout of the output tensor(s), to which the created DTensor will be redistributed.
use_local_output (bool) – Whether to convert the DTensor to a local torch.Tensor.
- Returns
None.
- Return type
None
Warning
RowwiseParallel currently only supports nn.Linear. Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
Warning
_prepare_input and _prepare_output are only for internal usage and will be removed from the constructor soon. Please use input_layouts and output_layouts instead.
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> ...
>>> parallelize_plan = {
>>>     "wo": RowwiseParallel(),   # The input of Linear will be converted to a sharded DTensor
>>>                                # and a replicated :class:`torch.Tensor` will be returned as output.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block,  # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
- class torch.distributed.tensor.parallel.style.ColwiseParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=Replicate(), output_layouts=Shard(dim=-1), use_local_output=True)[source]¶
Partition the columns of a tensor or module.
We assume the input to be a replicated DTensor and the output to be a sharded torch.Tensor.
- Parameters
input_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout of the input tensor(s), upon which the DTensor will be created.
output_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout of the output tensor(s), to which the created DTensor will be redistributed.
use_local_output (bool) – Whether to convert the DTensor to a local torch.Tensor.
- Returns
None.
- Return type
None
Warning
ColwiseParallel currently only supports nn.Linear and nn.Embedding. Users can compose it with RowwiseParallel to achieve the sharding of more complicated modules (see the combined sketch after the example below).
Warning
_prepare_input and _prepare_output are only for internal usage and will be removed from the constructor soon. Please use input_layouts and output_layouts instead.
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> ...
>>> parallelize_plan = {
>>>     "w1": ColwiseParallel(),   # The input of Linear will be converted to a replicated DTensor
>>>                                # and a sharded :class:`torch.Tensor` will be returned as output.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block,  # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
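The two styles are typically composed in a single plan, e.g. an MLP-style block whose first linear layer is column-sharded and whose second is row-sharded. The following is a sketch only: the FQNs "w1" and "wo" follow the examples above and are assumed to exist in block.
>>> from torch.distributed.tensor.parallel import (
>>>     parallelize_module, ColwiseParallel, RowwiseParallel,
>>> )
>>> ...
>>> parallelize_plan = {
>>>     "w1": ColwiseParallel(),  # replicated input -> column-sharded activation
>>>     "wo": RowwiseParallel(),  # consumes the sharded activation -> replicated output
>>> }
>>> parallelize_module(
>>>     module=block,  # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )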
Warning
We are deprecating the styles below and will remove them soon:
- class torch.distributed.tensor.parallel.style.PairwiseParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=None, output_layouts=None, use_local_output=True)[source]¶
PairwiseParallel concatenates colwise and rowwise styles as a fixed pair.
Similar to what Megatron-LM (https://arxiv.org/abs/1909.08053) does. We assume both the input and output need to be replicated DTensors.
Warning
PairwiseParallel can be decomposed into ColwiseParallel and RowwiseParallel. We recommend that users directly use the latter instead; we are deprecating this style and will remove it soon.
- class torch.distributed.tensor.parallel.style.SequenceParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=None, output_layouts=None, use_local_output=True)[source]¶
SequenceParallel concatenates colwise and rowwise styles as a fixed pair, together with sequence parallelism.
Similar to what Megatron-LM sequence parallelism (https://arxiv.org/pdf/2205.05198.pdf) does. We assume both the input and output need to be sharded DTensors.
Warning
SequenceParallel can be decomposed into ColwiseParallel and RowwiseParallel. We recommend that users directly use the latter instead; we are deprecating this style and will remove it soon.
Since Tensor Parallelism is built on top of DTensor, we need to specify the DTensor layout of the input and output of the module so it can interact with the module parameters and with the modules that follow. Users can achieve this by specifying input_layouts and output_layouts, which annotate inputs as DTensors and redistribute the outputs, if needed.
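For instance, the following sketch (an illustration, not from the reference above; it assumes Replicate is importable from torch.distributed._tensor, which may vary by release) overrides ColwiseParallel's default Shard(dim=-1) output layout so that the next, non-parallelized module receives a replicated tensor:
>>> from torch.distributed._tensor import Replicate  # import path may vary by release
>>> from torch.distributed.tensor.parallel import ColwiseParallel
>>> ...
>>> # Redistribute the column-sharded output back to a replicated layout.
>>> style = ColwiseParallel(output_layouts=Replicate())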
If users only want to annotate the DTensor layouts for inputs/outputs and do not need to distribute the module's parameters, the following classes can be used in the parallelize_plan of parallelize_module:
- class torch.distributed.tensor.parallel.style.PrepareModuleInput(input_layouts=Shard(dim=0), output_layouts=Replicate(), use_local_output=False)[source]¶
Annotate Tensor inputs with layouts for conversion to DTensor, redistributing based on specified output layouts.
PrepareModuleInput enables users to annotate torch.Tensor or DTensor inputs with input_layouts and output_layouts so that each input can be converted to a DTensor based on the annotation. Specifically, a DTensor will be created from the input tensor based on input_layouts and then redistributed to another DTensor based on output_layouts.
When the input is neither a torch.Tensor nor a DTensor, it will be a no-op if no layout is specified; otherwise, it will throw an error.
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> ...
>>> parallelize_plan = {
>>>     "attn": PrepareModuleInput(),   # The input of attn will be converted to a sharded DTensor
>>>                                     # and then redistributed to a replicated DTensor.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block,  # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
- class torch.distributed.tensor.parallel.style.PrepareModuleOutput(input_layouts=Replicate(), output_layouts=Shard(dim=0), use_local_output=True)[source]¶
Enable annotation of DTensor outputs for flexible conversion to torch.Tensor based on specified layouts.
PrepareModuleOutput enables users to annotate DTensor outputs with output_layouts and use_local_output so that each output can be converted to a DTensor or torch.Tensor based on the annotation. Specifically, a DTensor will be redistributed to another DTensor based on output_layouts, and the use_local_output flag decides whether to convert the DTensor to a local torch.Tensor.
When the output is not a DTensor, it will be a no-op if no layout is specified; otherwise, it will throw an error.
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> ...
>>> parallelize_plan = {
>>>     "submodule": PrepareModuleOutput(),   # The output of submodule will be converted to a replicated DTensor
>>>                                           # if it is not a DTensor, then redistributed to a sharded local tensor.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block,  # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
Warning
We are deprecating the methods below and will remove them soon:
- torch.distributed.tensor.parallel.style.make_input_replicate_1d(input, device_mesh=None)[source]¶
Warning
This method has been deprecated; please specify input_layouts instead.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_input_reshard_replicate(input, device_mesh)[source]¶
Warning
This method has been deprecated; please specify input_layouts instead.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_input_shard_1d(input, device_mesh=None, dim=0)[source]¶
Warning
This method has been deprecated; please specify input_layouts instead.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_input_shard_1d_last_dim(input, device_mesh=None)[source]¶
Warning
This method has been deprecated; please specify input_layouts instead.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_output_replicate_1d(output, device_mesh=None)[source]¶
Warning
This method has been deprecated; please specify output_layouts instead.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_output_reshard_tensor(output, device_mesh=None)[source]¶
Warning
This method has been deprecated; please specify output_layouts instead.
- Return type
Tensor
- torch.distributed.tensor.parallel.style.make_output_shard_1d(output, device_mesh=None, dim=0)[source]¶
Warning
This method has been deprecated; please specify output_layouts instead.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_output_tensor(output, device_mesh=None)[source]¶
Warning
This method has been deprecated; please specify output_layouts instead.
- Return type
Tensor
Currently, there are some constraints that make it hard for the MultiheadAttention module to work out of the box for Tensor Parallelism, so we recommend that users try ColwiseParallel and RowwiseParallel for each parameter. There might be some code changes needed for now, since we parallelize on the head dim of the MultiheadAttention module.
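A per-parameter plan for such an attention block might look like the following sketch. The FQNs "attn.wq"/"attn.wk"/"attn.wv"/"attn.wo" are hypothetical and should be adapted to your module; model and device_mesh are assumed to already exist.
>>> from torch.distributed.tensor.parallel import (
>>>     parallelize_module, ColwiseParallel, RowwiseParallel,
>>> )
>>> ...
>>> attn_plan = {
>>>     "attn.wq": ColwiseParallel(),  # shard the q/k/v projections along the head dim
>>>     "attn.wk": ColwiseParallel(),
>>>     "attn.wv": ColwiseParallel(),
>>>     "attn.wo": RowwiseParallel(),  # gather back to a replicated output
>>> }
>>> parallelize_module(model, device_mesh, attn_plan)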
We also support 2D parallelism, where we compose tensor parallelism with data parallelism.
To integrate with FullyShardedDataParallel, users just need to call the following API explicitly:
- torch.distributed.tensor.parallel.fsdp.enable_2d_with_fsdp()[source]¶
Register the extension which is needed for Tensor Parallelism (TP) to work with FullyShardedDataParallel (FSDP).
We first parallelize parameters within one module or sub_modules based on a parallelize_plan and will let FSDP reshard the local tensor of the distributed parameter, which is essentially a DTensor.
- Returns
A bool indicating whether the extension registration succeeded.
- Return type
bool
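A rough end-to-end sketch of 2D parallelism with FSDP follows. The DeviceMesh import path, the mesh shape, and the FQNs "w1"/"wo" are assumptions for illustration (model is assumed to already exist); consult your PyTorch release for the exact APIs.
>>> import torch
>>> from torch.distributed._tensor import DeviceMesh  # import path may vary by release
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>> from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
>>>
>>> enable_2d_with_fsdp()  # register the TP + FSDP extension before wrapping
>>> # 2-D mesh over 8 GPUs: dim 0 for data parallelism, dim 1 for tensor parallelism.
>>> mesh_2d = DeviceMesh("cuda", torch.arange(8).reshape(4, 2))
>>> model = parallelize_module(
>>>     model, mesh_2d, {"w1": ColwiseParallel(), "wo": RowwiseParallel()}, tp_mesh_dim=1
>>> )
>>> model = FSDP(model)  # FSDP then shards along the remaining (data-parallel) dim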
To integrate with DistributedDataParallel, users just need to call the following API explicitly:
- torch.distributed.tensor.parallel.ddp.pre_dp_module_transform(module)[source]¶
Enable the composability between Tensor Parallelism (TP) and Data Parallelism (DP) in PyTorch when using DDP. We need to convert parameters which are DTensors to local tensors before wrapping with the data parallelism API. We then register two hooks: one to convert local tensors back to DTensors before the forward pass, and one to convert DTensors back to local tensors after the forward pass. By integrating this way, we avoid any special handling of DTensor parameters by DDP and get DTensor's gradients propagated back to DP, e.g. the gradient buckets of DDP.
For now, this API only works with DistributedDataParallel. It will later support other DP methods such as FSDP.
- Parameters
module (nn.Module) – Module on which TP has been applied.
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform
>>>
>>> # Define the module.
>>> m = module(...)
>>> parallelize_module(m, PairwiseParallel())
>>> m = pre_dp_module_transform(m)
>>> m = DDP(m)
>>>