Tensor Parallelism - torch.distributed.tensor.parallel

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides parallelism styles such as Rowwise and Colwise Parallelism.

Warning

Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your nn.Module using Tensor Parallelism is:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)[source]

Apply Tensor Parallelism (TP) in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

We parallelize the module or sub-modules based on a parallelize_plan. The parallelize_plan contains ParallelStyle objects, which indicate how the user wants the module or sub-module to be parallelized.

Users can also specify a different parallel style per module fully qualified name (FQN). The API supports 2D parallelism natively by accepting an n-dimensional device_mesh; users just need to specify the dimension on which tensor parallelism is performed.

Parameters
  • module (nn.Module) – Module to be parallelized.

  • device_mesh (DeviceMesh) – Object which describes the mesh topology of devices for the DTensor.

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – The plan used to parallelize the module. It can be either a ParallelStyle object, which describes how we prepare input/output for Tensor Parallelism, or a dict mapping module FQNs to their corresponding ParallelStyle objects.

  • tp_mesh_dim (int) – The dimension of device_mesh on which we perform Tensor Parallelism.

Returns

An nn.Module object that has been parallelized.

Return type

Module

Example::
>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed._tensor import DeviceMesh
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>>
>>> # Define the module.
>>> m = Model(...)
>>> # Create a 1-D device mesh over all ranks (assumes the default process group is initialized).
>>> device_mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
>>> m = parallelize_module(m, device_mesh, ColwiseParallel())
>>>
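
When parallelize_plan is a dict keyed by module FQN, different sub-modules can be given different parallel styles, and an n-dimensional device_mesh can be combined with tp_mesh_dim to pick the mesh dimension used for Tensor Parallelism. Below is a minimal sketch under the assumption that the model has linear sub-modules named "w1" and "w2" and that the world size factors into dp_size * tp_size; these names and sizes are illustrative, not part of the API:

>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed._tensor import DeviceMesh
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>>
>>> # A 2-D mesh: dim 0 for data parallelism, dim 1 for tensor parallelism.
>>> tp_size = 2                                  # illustrative
>>> dp_size = dist.get_world_size() // tp_size
>>> mesh_2d = DeviceMesh("cuda", torch.arange(dist.get_world_size()).reshape(dp_size, tp_size))
>>>
>>> m = Model(...)
>>> m = parallelize_module(
>>>     m,
>>>     mesh_2d,
>>>     {"w1": ColwiseParallel(), "w2": RowwiseParallel()},  # per-FQN styles
>>>     tp_mesh_dim=1,  # perform Tensor Parallelism on the second mesh dimension
>>> )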

Warning

Currently, there are some constraints which make it hard for complicated modules like MultiheadAttention to work out of the box with Tensor or Sequence Parallelism. We recommend that users try ColwiseParallel and RowwiseParallel for each parameter or submodule; some code changes may be needed for now.

Tensor Parallelism supports the following parallel styles:

class torch.distributed.tensor.parallel.style.RowwiseParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=Shard(dim=-1), output_layouts=Replicate(), use_local_output=True)[source]

Partition a module row-wise.

We assume the input to be a DTensor sharded on the last dimension and the output to be a replicated torch.Tensor.

Parameters
  • input_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout(s) of the input tensor(s), used to construct the DTensor(s) from the input(s).

  • output_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout(s) that the output DTensor(s) will be redistributed to before being returned.

  • use_local_output (bool) – Whether to convert the output DTensor to a local torch.Tensor.

Returns

None.

Return type

None

Warning

RowwiseParallel currently only supports nn.Linear. Users can compose it with ColwiseParallel to shard more complicated modules.

Warning

_prepare_input and _prepare_output are for internal use only and will be removed from the constructor soon. Please use input_layouts and output_layouts instead.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> ...
>>> parallelize_plan = {
>>>     "wo": RowwiseParallel(),   # The input of this Linear will be converted to a sharded DTensor
>>>                                # and a replicated torch.Tensor will be returned as output.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
class torch.distributed.tensor.parallel.style.ColwiseParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=Replicate(), output_layouts=Shard(dim=-1), use_local_output=True)[source]

Partition a module column-wise.

We assume the input to be a replicated DTensor and the output to be a torch.Tensor sharded on the last dimension.

Parameters
  • input_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout(s) of the input tensor(s), used to construct the DTensor(s) from the input(s).

  • output_layouts (Union[Placement, Tuple[Placement, ...]]) – The layout(s) that the output DTensor(s) will be redistributed to before being returned.

  • use_local_output (bool) – Whether to convert the output DTensor to a local torch.Tensor.

Returns

None.

Return type

None

Warning

ColwiseParallel currently only supports nn.Linear and nn.Embedding. Users can compose it with RowwiseParallel to shard more complicated modules.

Warning

_prepare_input and _prepare_output are for internal use only and will be removed from the constructor soon. Please use input_layouts and output_layouts instead.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> ...
>>> parallelize_plan = {
>>>     "w1": ColwiseParallel(),   # The input of this Linear will be converted to a replicated DTensor
>>>                                # and a sharded torch.Tensor will be returned as output.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
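
ColwiseParallel and RowwiseParallel are typically composed: in a two-layer feed-forward block, the first linear is sharded column-wise so that its output stays sharded on the last dimension, and the second linear is sharded row-wise so that it consumes that sharded activation and returns a replicated output. A minimal sketch, assuming the block has linear sub-modules named "w1" and "w2" and that device_mesh is a 1-D DeviceMesh created beforehand (names and mesh are illustrative):

>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>>
>>> parallelize_plan = {
>>>     "w1": ColwiseParallel(),  # output is a torch.Tensor sharded on the last dim
>>>     "w2": RowwiseParallel(),  # consumes the sharded activation, returns a replicated output
>>> }
>>> block = parallelize_module(
>>>     module=block,
>>>     device_mesh=device_mesh,
>>>     parallelize_plan=parallelize_plan,
>>> )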

Warning

We are deprecating the styles below and will remove them soon:

class torch.distributed.tensor.parallel.style.PairwiseParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=None, output_layouts=None, use_local_output=True)[source]

PairwiseParallel concatenates the colwise and rowwise styles as a fixed pair.

Similar to what Megatron-LM (https://arxiv.org/abs/1909.08053) does. We assume both the input and the output need to be replicated DTensors.

Warning

PairwiseParallel can be decomposed into ColwiseParallel and RowwiseParallel. We recommend that users use the latter directly instead; we are deprecating this style and will remove it soon.

class torch.distributed.tensor.parallel.style.SequenceParallel(_prepare_input=None, _prepare_output=None, *, input_layouts=None, output_layouts=None, use_local_output=True)[source]

SequenceParallel concatenates the colwise and rowwise styles as a fixed pair, together with sequence parallelism.

Similar to what Megatron-LM sequence parallelism (https://arxiv.org/pdf/2205.05198.pdf) does. We assume both the input and the output need to be sharded DTensors.

Warning

SequenceParallel can be decomposed into ColwiseParallel and RowwiseParallel. We recommend that users use the latter directly instead; we are deprecating this style and will remove it soon.

Since Tensor Parallelism is built on top of DTensor, we need to specify the DTensor layout of the input and output of the module so that they can interact properly with the module's parameters and with the modules that follow. Users can achieve this by specifying the input_layouts and output_layouts, which annotate inputs as DTensors and redistribute the outputs, if needed.
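
For instance, if the tensor feeding a column-parallel linear arrives sharded on the batch dimension rather than replicated, or if a replicated output is desired instead of the default sharded one, the defaults can be overridden. A small sketch, assuming Shard and Replicate are imported from torch.distributed._tensor:

>>> from torch.distributed._tensor import Replicate, Shard
>>> from torch.distributed.tensor.parallel import ColwiseParallel
>>>
>>> # Annotate the incoming tensor as sharded on dim 0 so the input DTensor is
>>> # created with the right layout, and redistribute the output to Replicate
>>> # instead of the default Shard(-1).
>>> style = ColwiseParallel(input_layouts=Shard(0), output_layouts=Replicate())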

If users only want to annotate the DTensor layouts of inputs/outputs and do not need to distribute the module's parameters, the following classes can be used in the parallelize_plan of parallelize_module:

class torch.distributed.tensor.parallel.style.PrepareModuleInput(input_layouts=Shard(dim=0), output_layouts=Replicate(), use_local_output=False)[source]

Annotate Tensor inputs with layouts for conversion to DTensor, redistributing based on specified output layouts.

PrepareModuleInput enables users to annotate torch.Tensor or DTensor inputs with input_layouts and output_layouts so that each input can be converted to a DTensor based on the annotation. Specifically, a DTensor will be created from the input tensor based on input_layouts and then redistributed to another DTensor based on output_layouts.

If an input is neither a torch.Tensor nor a DTensor and no layout is specified for it, it is passed through unchanged; otherwise an error is raised.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> ...
>>> parallelize_plan = {
>>>     "attn": PrepareModuleInput(),   # The input of attn will be converted to a DTensor sharded on dim 0
>>>                                     # and then redistributed to a replicated DTensor.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...
class torch.distributed.tensor.parallel.style.PrepareModuleOutput(input_layouts=Replicate(), output_layouts=Shard(dim=0), use_local_output=True)[source]

Enable annotation of DTensor outputs for flexible conversion to torch.Tensor based on specified layouts.

PrepareModuleOutput enables users to annotate DTensor outputs with output_layouts and use_local_output so that each output can be converted to a DTensor or a torch.Tensor based on the annotation. Specifically, a DTensor will be redistributed to another DTensor based on output_layouts, and the use_local_output flag decides whether to convert the DTensor to a local torch.Tensor.

If an output is not a DTensor and no layout is specified for it, it is passed through unchanged; otherwise an error is raised.

Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> ...
>>> parallelize_plan = {
>>>     "submodule": PrepareModuleOutput(),   # The output of submodule will be annotated as a replicated DTensor,
>>>                                           # then redistributed to a sharded DTensor and converted to a local tensor.
>>>     ...
>>> }
>>> parallelize_module(
>>>     module=block, # this can be a submodule or module
>>>     ...,
>>>     parallelize_plan=parallelize_plan,
>>> )
>>> ...

Warning

We are deprecating the methods below and will remove them soon:

torch.distributed.tensor.parallel.style.make_input_replicate_1d(input, device_mesh=None)[source]

Warning

This method is deprecated; please specify input_layouts instead.

Return type

DTensor

torch.distributed.tensor.parallel.style.make_input_reshard_replicate(input, device_mesh)[source]

Warning

This method is deprecated; please specify input_layouts instead.

Return type

DTensor

torch.distributed.tensor.parallel.style.make_input_shard_1d(input, device_mesh=None, dim=0)[source]

Warning

This method is deprecated; please specify input_layouts instead.

Return type

DTensor

torch.distributed.tensor.parallel.style.make_input_shard_1d_last_dim(input, device_mesh=None)[source]

Warning

This method is deprecated; please specify input_layouts instead.

Return type

DTensor

torch.distributed.tensor.parallel.style.make_output_replicate_1d(output, device_mesh=None)[source]

Warning

This method is deprecated; please specify output_layouts instead.

Return type

DTensor

torch.distributed.tensor.parallel.style.make_output_reshard_tensor(output, device_mesh=None)[source]

Warning

This method is deprecated; please specify output_layouts instead.

Return type

Tensor

torch.distributed.tensor.parallel.style.make_output_shard_1d(output, device_mesh=None, dim=0)[source]

Warning

This method is deprecated; please specify output_layouts instead.

Return type

DTensor

torch.distributed.tensor.parallel.style.make_output_tensor(output, device_mesh=None)[source]

Warning

This method is deprecated; please specify output_layouts instead.

Return type

Tensor

Currently, there are some constraints which make it hard for the MultiheadAttention module to work out of the box with Tensor Parallelism, so we recommend that users try ColwiseParallel and RowwiseParallel for each parameter. Some code changes may be needed for now, since we parallelize on the head dimension of the MultiheadAttention module.
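
For example, for a hand-written attention block whose query/key/value and output projections are separate nn.Linear sub-modules, one common pattern is to shard the projections column-wise and the output projection row-wise, so that each rank only holds a slice of the attention heads. A sketch under the assumption that the sub-modules are named "wq", "wk", "wv" and "wo" and that device_mesh is a 1-D DeviceMesh (names and mesh are illustrative):

>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>>
>>> attn_plan = {
>>>     "wq": ColwiseParallel(),  # each rank holds a slice of the heads
>>>     "wk": ColwiseParallel(),
>>>     "wv": ColwiseParallel(),
>>>     "wo": RowwiseParallel(),  # combines the per-rank partial outputs into a replicated tensor
>>> }
>>> attn = parallelize_module(attn, device_mesh, attn_plan)

As the warning above notes, the module's own reshaping logic may need to change, since each rank now only sees num_heads / tp_size heads locally.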

We also support 2D parallelism, where we compose tensor parallelism with data parallelism. To integrate with FullyShardedDataParallel, users just need to call the following API explicitly:

torch.distributed.tensor.parallel.fsdp.enable_2d_with_fsdp()[source]

Register the extension which is needed for Tensor Parallelism (TP) to work with FullyShardedDataParallel (FSDP).

We first parallelize parameters within a module or its sub-modules based on a parallelize_plan, and then let FSDP reshard the local tensor of each distributed parameter, which is essentially a DTensor.

Returns

A bool indicating whether the extension registration succeeded.

Return type

bool
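
A minimal sketch of the 2D flow, assuming a 2-D DeviceMesh whose dim 0 is used for FSDP and dim 1 for TP, and assuming DeviceMesh.get_dim_groups() is available to obtain the per-dimension process groups (module names and mesh sizes are illustrative):

>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed._tensor import DeviceMesh
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
>>> from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp
>>>
>>> assert enable_2d_with_fsdp()  # register the TP extension before wrapping with FSDP
>>> tp_size = 2                   # illustrative
>>> dp_size = dist.get_world_size() // tp_size
>>> mesh_2d = DeviceMesh("cuda", torch.arange(dist.get_world_size()).reshape(dp_size, tp_size))
>>>
>>> m = Model(...)
>>> m = parallelize_module(m, mesh_2d, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}, tp_mesh_dim=1)
>>> dp_pg = mesh_2d.get_dim_groups()[0]  # process group of mesh dim 0 (data parallel)
>>> m = FSDP(m, process_group=dp_pg)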

To integrate with DistributedDataParallel, users just need to call the following API explicitly:

torch.distributed.tensor.parallel.ddp.pre_dp_module_transform(module)[source]

Enable composability between Tensor Parallelism (TP) and Data Parallelism (DP) in PyTorch when using DDP. We need to convert parameters that are DTensors to local tensors before wrapping them with the data parallelism API. We then register two hooks: one that converts local tensors back to DTensors before the forward pass, and one that converts DTensors back to local tensors after the forward pass. By integrating this way, we avoid any special handling of DTensor parameters by DDP and get DTensor's gradients propagated back to DP, e.g. the gradient buckets of DDP.

For now, this API only works with DistributedDataParallel. It will later support other DP methods such as FSDP.

Parameters

module (nn.Module) – Module to which TP has been applied.

Example::
>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed._tensor import DeviceMesh
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform
>>>
>>> # Define the module.
>>> m = Model(...)
>>> # A 1-D device mesh over all ranks (assumes the default process group is initialized).
>>> device_mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
>>> parallelize_module(m, device_mesh, PairwiseParallel())
>>> pre_dp_module_transform(m)  # transforms the TP-applied module in place
>>> m = DDP(m)
>>>
