
Source code for torch.distributed.tensor.parallel.style

# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, Union

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel._utils import (
    _deprecate_warnings,
    _prepare_input_validate,
    _prepare_output_validate,
    _PrepareInputType,
    _PrepareOutputType,
    LayoutsType,
)

__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "ColwiseParallel",
    "PairwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleOutput",
    "SequenceParallel",
    "make_input_replicate_1d",
    "make_input_reshard_replicate",
    "make_input_shard_1d",
    "make_input_shard_1d_last_dim",
    "make_sharded_output_tensor",
    "make_output_replicate_1d",
    "make_output_reshard_tensor",
    "make_output_tensor",
    "make_output_shard_1d",
]


class ParallelStyle(ABC):
    """
    The parallel style that the user wants the module or submodule to be parallelized with.

    Users can extend this class to build their own parallel style with customized input/output preparations.

    .. warning::
        ``_prepare_input`` and ``_prepare_output`` are for internal usage only and will be
        removed from the constructor soon. Please use ``input_layouts`` and ``output_layouts`` instead.
    """

    _prepare_input: _PrepareInputType
    _prepare_output: _PrepareOutputType
    input_layouts: LayoutsType
    output_layouts: LayoutsType
    use_local_output: bool

    @abstractmethod
    def __init__(
        self,
        _prepare_input,
        _prepare_output,
        *,
        input_layouts,
        output_layouts,
        use_local_output,
    ) -> None:
        self.input_layouts = input_layouts
        self.output_layouts = output_layouts
        self.use_local_output = use_local_output
        self._prepare_input = _prepare_input  # type: ignore[assignment, misc]
        self._prepare_output = _prepare_output  # type: ignore[assignment, misc]


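# A minimal sketch (not part of this module) of what a user-defined style could look like,
# as the docstring above suggests. It reuses the private ``_redistribute_per_layout`` helper
# defined later in this file, the same way the concrete styles below do. The class name and
# the chosen layouts are illustrative assumptions, not an officially supported style.
class _ExampleReplicateOutputStyle(ParallelStyle):
    """Hypothetical style: leave inputs untouched, replicate the output and return a local tensor."""

    def __init__(self) -> None:
        super().__init__(
            _prepare_input=None,  # accept whatever the module already receives
            _prepare_output=functools.partial(
                _redistribute_per_layout, Replicate(), True
            ),
            input_layouts=None,
            output_layouts=Replicate(),
            use_local_output=True,
        )

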
class PairwiseParallel(ParallelStyle):
    """
    PairwiseParallel concatenates colwise and rowwise styles as a fixed pair,
    similar to what Megatron-LM (https://arxiv.org/abs/1909.08053) does.
    We assume both input and output need to be replicate DTensors.

    .. warning::
        PairwiseParallel can be decomposed into ColwiseParallel and RowwiseParallel.
        We recommend users to use the latter directly instead; this style is deprecated
        and will be removed soon.
    """

    def __init__(
        self,
        _prepare_input=None,
        _prepare_output=None,
        *,
        input_layouts=None,
        output_layouts=None,
        use_local_output=True,
    ) -> None:
        _deprecate_warnings(
            "PairwiseParallel", "Use ColwiseParallel and RowwiseParallel instead."
        )
        _prepare_input = (
            make_input_replicate_1d if _prepare_input is None else _prepare_input
        )
        _prepare_output = (
            make_output_tensor if _prepare_output is None else _prepare_output
        )

        super().__init__(
            _prepare_input,
            _prepare_output,
            input_layouts=input_layouts,
            output_layouts=output_layouts,
            use_local_output=use_local_output,
        )


class SequenceParallel(PairwiseParallel):
    """
    SequenceParallel concatenates colwise and rowwise styles as a fixed pair,
    together with sequence parallelism, similar to what Megatron-LM sequence
    parallelism (https://arxiv.org/pdf/2205.05198.pdf) does.
    We assume both input and output need to be sharded DTensors.

    .. warning::
        SequenceParallel can be decomposed into ColwiseParallel and RowwiseParallel.
        We recommend users to use the latter directly instead; this style is deprecated
        and will be removed soon.
    """

    def __init__(
        self,
        _prepare_input=None,
        _prepare_output=None,
        *,
        input_layouts=None,
        output_layouts=None,
        use_local_output=True,
    ) -> None:
        _deprecate_warnings(
            "SequenceParallel", "Use ColwiseParallel and RowwiseParallel instead."
        )
        super().__init__(  # type: ignore[misc]
            make_input_reshard_replicate,
            make_output_reshard_tensor,
            input_layouts=input_layouts,
            output_layouts=output_layouts,
            use_local_output=use_local_output,
        )


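# Hedged usage sketch (not part of this module): the replacement that the deprecation
# messages above point to, expressing the old pairwise / sequence-parallel behavior
# directly with ColwiseParallel and RowwiseParallel. The submodule names ("net1", "net2")
# and the mesh argument are illustrative assumptions about the user's model.
def _example_pairwise_replacement(mlp: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    from torch.distributed.tensor.parallel import parallelize_module

    plan = {
        "net1": ColwiseParallel(),  # first linear: shard columns, keep the activation sharded
        "net2": RowwiseParallel(),  # second linear: shard rows, replicate the final output
    }
    return parallelize_module(mlp, mesh, plan)

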
@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_shard_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
    dim: int = 0,
) -> DTensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``input_layouts`` instead.
    """
    _deprecate_warnings("make_input_shard_1d", "Specify input_layouts instead.")
    shard_spec = [Shard(dim)]
    if isinstance(input, DTensor):
        return input.redistribute(device_mesh, shard_spec)
    elif isinstance(input, torch.Tensor):
        return DTensor.from_local(input, device_mesh, shard_spec, run_check=False)
    else:
        raise RuntimeError(
            "Tensor parallel module expects torch.Tensor or DTensor input but"
            f" received {type(input)}!"
        )


@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_shard_1d_last_dim(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``input_layouts`` instead.
    """
    _deprecate_warnings(
        "make_input_shard_1d_last_dim", "Specify input_layouts instead."
    )
    return make_input_shard_1d(input, device_mesh, dim=input.dim() - 1)  # type: ignore[call-arg, misc]


@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_reshard_replicate(
    input: torch.Tensor,
    device_mesh: DeviceMesh,
) -> DTensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``input_layouts`` instead.
    """
    _deprecate_warnings(
        "make_input_reshard_replicate", "Specify input_layouts instead."
    )
    return make_input_replicate_1d(  # type: ignore[call-arg, misc]
        make_input_shard_1d(input, device_mesh, dim=0), device_mesh  # type: ignore[call-arg, misc]
    )


@_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_input_replicate_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``input_layouts`` instead.
    """
    _deprecate_warnings("make_input_replicate_1d", "Specify input_layouts instead.")
    replicate = [Replicate()]
    if isinstance(input, DTensor):
        return input.redistribute(device_mesh, replicate)
    elif isinstance(input, torch.Tensor):
        return DTensor.from_local(input, device_mesh, replicate, run_check=False)
    else:
        raise RuntimeError(
            "Tensor parallel module expects torch.Tensor or DTensor input but"
            f" received {type(input)}!"
        )


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_shard_1d(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0
) -> DTensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``output_layouts`` instead.
    """
    _deprecate_warnings("make_output_shard_1d", "Specify output_layouts instead.")
    return output.redistribute(device_mesh, [Shard(dim)])


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_replicate_1d(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> DTensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``output_layouts`` instead.
    """
    _deprecate_warnings("make_output_replicate_1d", "Specify output_layouts instead.")
    return output.redistribute(device_mesh, [Replicate()])


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_tensor(
    output: DTensor, device_mesh: Optional[DeviceMesh] = None
) -> torch.Tensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``output_layouts`` instead.
    """
    _deprecate_warnings("make_output_tensor", "Specify output_layouts instead.")
    return make_output_replicate_1d(  # type: ignore[attr-defined, misc]
        output, device_mesh
    ).to_local()  # type: ignore[call-arg]


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_sharded_output_tensor(
    output: DTensor, _device_mesh: Optional[DeviceMesh] = None
) -> torch.Tensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``output_layouts`` instead.
    """
    _deprecate_warnings("make_sharded_output_tensor", "Specify output_layouts instead.")
    return output.to_local()  # type: ignore[call-arg]


@_prepare_output_validate  # type: ignore[arg-type] # pyre-ignore[56]
def make_output_reshard_tensor(
    output: DTensor,
    device_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:  # noqa: D205, D400
    """
    .. warning::
        This method is deprecated; please specify ``output_layouts`` instead.
    """
    _deprecate_warnings("make_output_reshard_tensor", "Specify output_layouts instead.")
    return make_output_shard_1d(output, device_mesh).to_local()  # type: ignore[call-arg, attr-defined, misc]


def _redistribute_per_both_layouts(t, input_layout, output_layout, device_mesh):
    dst_placements = (output_layout,)
    if isinstance(t, DTensor):
        return t.redistribute(device_mesh, dst_placements)
    elif isinstance(t, torch.Tensor):
        dtensor = DTensor.from_local(t, device_mesh, (input_layout,), run_check=False)
        return dtensor.redistribute(device_mesh, dst_placements)
    else:
        if input_layout is not None:
            raise RuntimeError(
                "Tensor parallel module expects DTensor or tensor"
                f" when layout specified but received {type(t)}!"
            )
        else:
            return t


def _redistribute_per_layout(layout, use_local_output, t, device_mesh):
    dst_placements = (layout,)
    if isinstance(t, DTensor):
        dtensor = t.redistribute(device_mesh, dst_placements)
        return dtensor.to_local() if use_local_output else dtensor
    else:
        if layout is not None:
            raise RuntimeError(
                "Tensor parallel module expects DTensor or tensor"
                f" when layout specified but received {type(t)}!"
            )
        else:
            return t


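# Illustrative sketch (not part of this module): what the two helpers above do to a plain
# torch.Tensor on a 1-D device mesh. The function name, tensor, and mesh are assumptions
# made only for this example.
def _example_redistribute(t: torch.Tensor, mesh: DeviceMesh) -> torch.Tensor:
    # View the local tensor as a DTensor sharded on its last dim, then redistribute it
    # to a replicated DTensor (an all-gather on a 1-D mesh).
    replicated = _redistribute_per_both_layouts(t, Shard(-1), Replicate(), mesh)
    # Keep the replicated layout and materialize a local torch.Tensor
    # (use_local_output=True), as the parallel styles below do for their outputs.
    return _redistribute_per_layout(Replicate(), True, replicated, mesh)

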
class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module.
    We assume the input to be a sharded :class:`DTensor` and the output to be a :class:`torch.Tensor`.

    Args:
        input_layouts (Union[Placement, Tuple[Placement, ...]]):
            The layout of the input tensor(s) which the DTensor will be created upon.
        output_layouts (Union[Placement, Tuple[Placement, ...]]):
            The layout which the created DTensor will be redistributed to.
        use_local_output (bool):
            Whether to convert the DTensor to a local :class:`torch.Tensor`.

    Returns:
        None.

    .. warning::
        RowwiseParallel currently only supports ``nn.Linear``. Users can compose it with
        ColwiseParallel to achieve the sharding of more complicated modules.

    .. warning::
        ``_prepare_input`` and ``_prepare_output`` are for internal usage only and will be removed
        from the constructor soon. Please use ``input_layouts`` and ``output_layouts`` instead.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
        >>> ...
        >>> parallelize_plan = {
        >>>     "wo": RowwiseParallel(),   # The input of Linear will be converted to a sharded DTensor
        >>>                                # and we will return a replicated :class:`torch.Tensor` as output.
        >>>     ...
        >>> }
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan=parallelize_plan,
        >>> )
        >>> ...
    """

    def __init__(
        self,
        _prepare_input=None,
        _prepare_output=None,
        *,
        input_layouts=Shard(-1),
        output_layouts=Replicate(),
        use_local_output=True,
    ) -> None:
        if isinstance(input_layouts, tuple) or isinstance(output_layouts, tuple):
            raise NotImplementedError(
                "RowwiseParallel only supports single input/output."
            )

        prepare_input_fn = None
        if _prepare_input is not None:
            prepare_input_fn = _prepare_input
        elif input_layouts is not None:
            prepare_input_fn = functools.partial(
                RowwiseParallel._prepare_input_fn,
                input_layouts,
                Shard(-1),
            )

        prepare_output_fn = None
        if _prepare_output is not None:
            prepare_output_fn = _prepare_output
        elif output_layouts is not None:
            prepare_output_fn = functools.partial(
                _redistribute_per_layout, output_layouts, use_local_output
            )

        super().__init__(
            input_layouts=input_layouts,
            output_layouts=output_layouts,
            use_local_output=use_local_output,
            _prepare_input=prepare_input_fn,
            _prepare_output=prepare_output_fn,
        )

    @staticmethod
    def _prepare_input_fn(input_layouts, output_layouts, inputs, device_mesh=None):
        return (
            _redistribute_per_both_layouts(
                inputs[0], input_layouts, output_layouts, device_mesh
            ),
        )


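# Hedged usage sketch (not part of this module): applying RowwiseParallel on its own to a
# single nn.Linear named "wo". The explicit arguments simply restate the defaults above:
# the incoming activation is treated as Shard(-1) and the output is returned as a
# replicated local torch.Tensor. The name "wo", the block, and the mesh are assumptions.
def _example_rowwise_only(block: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    from torch.distributed.tensor.parallel import parallelize_module

    plan = {"wo": RowwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate())}
    return parallelize_module(block, mesh, plan)

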
class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a tensor or module.
    We assume the input to be a replicated :class:`DTensor` and the output to be a sharded :class:`torch.Tensor`.

    Args:
        input_layouts (Union[Placement, Tuple[Placement, ...]]):
            The layout of the input tensor(s) which the DTensor will be created upon.
        output_layouts (Union[Placement, Tuple[Placement, ...]]):
            The layout which the created DTensor will be redistributed to.
        use_local_output (bool):
            Whether to convert the DTensor to a local :class:`torch.Tensor`.

    Returns:
        None.

    .. warning::
        ColwiseParallel currently only supports ``nn.Linear`` and ``nn.Embedding``. Users can
        compose it with RowwiseParallel to achieve the sharding of more complicated modules.

    .. warning::
        ``_prepare_input`` and ``_prepare_output`` are for internal usage only and will be removed
        from the constructor soon. Please use ``input_layouts`` and ``output_layouts`` instead.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>> ...
        >>> parallelize_plan = {
        >>>     "w1": ColwiseParallel(),   # The input of Linear will be converted to a replicated DTensor
        >>>                                # and we will return a sharded :class:`torch.Tensor` as output.
        >>>     ...
        >>> }
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan=parallelize_plan,
        >>> )
        >>> ...
    """

    def __init__(
        self,
        _prepare_input=None,
        _prepare_output=None,
        *,
        input_layouts=Replicate(),
        output_layouts=Shard(-1),
        use_local_output=True,
    ) -> None:
        if isinstance(input_layouts, tuple) or isinstance(output_layouts, tuple):
            raise NotImplementedError(
                "ColwiseParallel only supports single input/output."
            )

        prepare_input_fn = None
        if _prepare_input is not None:
            prepare_input_fn = _prepare_input
        elif input_layouts is not None:
            prepare_input_fn = functools.partial(
                ColwiseParallel._prepare_input_fn,
                input_layouts,
                Replicate(),
            )

        prepare_output_fn = None
        if _prepare_output is not None:
            prepare_output_fn = _prepare_output
        elif output_layouts is not None:
            prepare_output_fn = functools.partial(
                _redistribute_per_layout, output_layouts, use_local_output
            )

        super().__init__(
            _prepare_input=prepare_input_fn,
            _prepare_output=prepare_output_fn,
            input_layouts=input_layouts,
            output_layouts=output_layouts,
            use_local_output=use_local_output,
        )

    @staticmethod
    def _prepare_input_fn(input_layouts, output_layouts, inputs, device_mesh=None):
        return (
            _redistribute_per_both_layouts(
                inputs[0], input_layouts, output_layouts, device_mesh
            ),
        )


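# Hedged usage sketch (not part of this module): ColwiseParallel applied to an nn.Embedding
# (listed as supported in the warning above), keeping the sharded activation as a DTensor
# (use_local_output=False) so a later RowwiseParallel linear, which expects Shard(-1) inputs
# by default, can consume it directly. The name "tok_embeddings" and the mesh are assumptions.
def _example_colwise_embedding(model: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    from torch.distributed.tensor.parallel import parallelize_module

    plan = {
        "tok_embeddings": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
    }
    return parallelize_module(model, mesh, plan)

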
class PrepareModuleInput(ParallelStyle):
    """
    Annotate Tensor inputs with layouts for conversion to DTensor, redistributing based on specified output layouts.

    :class:`PrepareModuleInput` enables users to annotate :class:`torch.Tensor` or :class:`DTensor` inputs with
    ``input_layouts`` and ``output_layouts`` so that each input can be converted to :class:`DTensor` based on
    the annotation. Specifically, a DTensor will be created from the input Tensor based on ``input_layouts``
    and then redistributed to another DTensor based on ``output_layouts``.

    When an input is not a :class:`torch.Tensor` or :class:`DTensor`, it is a no-op if no layout
    is specified; otherwise, an error is raised.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
        >>> ...
        >>> parallelize_plan = {
        >>>     "attn": PrepareModuleInput(),   # The input of attn will be converted to a sharded DTensor
        >>>                                     # and redistributed to a replicated DTensor.
        >>>     ...
        >>> }
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan=parallelize_plan,
        >>> )
        >>> ...
    """

    def __init__(
        self,
        input_layouts: LayoutsType = Shard(0),
        output_layouts: LayoutsType = Replicate(),
        use_local_output: bool = False,
    ) -> None:  # noqa: D205, D400
        """
        Args:
            input_layouts (Union[Placement, Tuple[Placement, ...]]):
                The layout of the input tensor(s) which the DTensor will be created upon.
            output_layouts (Union[Placement, Tuple[Placement, ...]]):
                The layout which the created DTensor will be redistributed to.
            use_local_output (bool):
                Whether to convert the DTensor to a local :class:`torch.Tensor`.

        Returns:
            None.
        """
        if not isinstance(input_layouts, (list, tuple)):
            input_layouts = (input_layouts,)  # type: ignore[assignment]
        if not isinstance(output_layouts, (list, tuple)):
            output_layouts = (output_layouts,)  # type: ignore[assignment]
        super().__init__(
            input_layouts=input_layouts,
            output_layouts=output_layouts,
            use_local_output=use_local_output,
            _prepare_input=functools.partial(
                PrepareModuleInput._make_input_redistribute_1d,
                input_layouts,
                output_layouts,
            ),
            _prepare_output=None,
        )

    @staticmethod
    def _make_input_redistribute_1d(
        input_layouts: LayoutsType,
        output_layouts: LayoutsType,
        inputs: Tuple[Any, ...],
        device_mesh: Optional[DeviceMesh] = None,
    ) -> Optional[Any]:
        """Redistribute inputs over a device mesh."""
        # Always assume layouts are tuples.
        results = []
        for input, input_layout, output_layout in zip(
            inputs, input_layouts, output_layouts  # type: ignore[arg-type]
        ):
            if input_layout is None:
                results.append(input)
            else:
                results.append(
                    _redistribute_per_both_layouts(
                        input, input_layout, output_layout, device_mesh
                    )
                )
        return tuple(results)


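# Hedged usage sketch (not part of this module): annotating a submodule whose forward takes
# (hidden_states, attention_mask), where only the first argument should be converted.
# Per-argument ``None`` layouts leave the corresponding input untouched, as handled in
# _make_input_redistribute_1d above. The submodule name "attn" and the mesh are assumptions.
def _example_prepare_module_input(block: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    from torch.distributed.tensor.parallel import parallelize_module

    plan = {
        "attn": PrepareModuleInput(
            input_layouts=(Shard(0), None),      # hidden_states arrives sharded on dim 0
            output_layouts=(Replicate(), None),  # convert it to a replicated DTensor
        ),
    }
    return parallelize_module(block, mesh, plan)

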
class PrepareModuleOutput(ParallelStyle):
    """
    Annotate DTensor outputs for flexible conversion to :class:`torch.Tensor` based on specified layouts.

    :class:`PrepareModuleOutput` enables users to annotate :class:`DTensor` outputs with ``output_layouts``
    and ``use_local_output`` so that each output can be converted to :class:`DTensor` or :class:`torch.Tensor`
    based on the annotation. Specifically, a DTensor will be redistributed to another DTensor based on
    ``output_layouts``, and the flag ``use_local_output`` decides whether to convert the DTensor to a local
    :class:`torch.Tensor`.

    When the output is not a :class:`DTensor`, it is a no-op if no layout is specified;
    otherwise, an error is raised.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
        >>> ...
        >>> parallelize_plan = {
        >>>     "submodule": PrepareModuleOutput(),   # The output DTensor of submodule will be
        >>>                                           # redistributed to Shard(0) and converted to
        >>>                                           # a sharded local tensor.
        >>>     ...
        >>> }
        >>> parallelize_module(
        >>>     module=block, # this can be a submodule or module
        >>>     ...,
        >>>     parallelize_plan=parallelize_plan,
        >>> )
        >>> ...
    """

    def __init__(
        self,
        input_layouts: LayoutsType = Replicate(),
        output_layouts: LayoutsType = Shard(0),
        use_local_output: bool = True,
    ) -> None:  # noqa: D205, D400
        """
        Args:
            input_layouts (Union[Placement, Tuple[Placement, ...]]):
                The layout of the module's output DTensor(s).
            output_layouts (Union[Placement, Tuple[Placement, ...]]):
                The layout which the output DTensor(s)/tensor(s) will be redistributed to.
            use_local_output (bool):
                Whether to convert the DTensor to a local :class:`torch.Tensor`.

        Returns:
            None.
        """
        if isinstance(output_layouts, (list, tuple)):
            if len(output_layouts) != 1:
                raise NotImplementedError(
                    "Only support single output redistribute now."
                )
            else:
                output_layouts = output_layouts[0]
        super().__init__(
            input_layouts=input_layouts,
            output_layouts=output_layouts,
            use_local_output=use_local_output,
            _prepare_input=None,
            _prepare_output=functools.partial(
                _redistribute_per_layout, output_layouts, use_local_output
            ),
        )
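

# Hedged usage sketch (not part of this module): converting a submodule's DTensor output
# back into a replicated local torch.Tensor at the boundary between the tensor-parallel
# region and the rest of the model. The submodule name "norm" and the mesh are assumptions;
# the submodule is expected to already produce a DTensor, otherwise the layout check in
# _redistribute_per_layout raises an error.
def _example_prepare_module_output(block: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    from torch.distributed.tensor.parallel import parallelize_module

    plan = {
        "norm": PrepareModuleOutput(output_layouts=Replicate(), use_local_output=True),
    }
    return parallelize_module(block, mesh, plan)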
