Tensor Parallelism - torch.distributed.tensor.parallel
Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor (DTensor) and provides several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.
Warning
Tensor Parallelism APIs are experimental and subject to change.
The entrypoint to parallelize your nn.Module using Tensor Parallelism is:
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)
The API to apply Tensor Parallelism (TP) in PyTorch. It parallelizes a module or its sub-modules based on a parallelize_plan. The parallelize_plan contains a ParallelStyle, which indicates how the user wants the module or sub-module to be parallelized. The user can also specify a different parallel style per module fully qualified name (FQN). The API supports 2D parallelism natively by accepting an n-dimensional device_mesh; users only need to specify the dimension on which tensor parallelism is performed.
- Parameters
module (nn.Module) – Module to be parallelized.
device_mesh (DeviceMesh) – Object which describes the mesh topology of devices for the DTensor.
parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – The plan used to parallelize the module. It can be either a ParallelStyle object, which contains how we prepare input/output for Tensor Parallelism, or a dict of module FQN and its corresponding ParallelStyle object.
tp_mesh_dim (int) – The dimension of device_mesh on which we perform Tensor Parallelism.
- Returns
An nn.Module object that has been parallelized.
- Return type
nn.Module
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>>
>>> # Define the module.
>>> m = Model(...)
>>> # device_mesh is a DeviceMesh created beforehand (construction not shown).
>>> m = parallelize_module(m, device_mesh, PairwiseParallel())
>>>
Warning
PairwiseParallel comes with constraints for now. If you need finer granularity, pass in a dict of module FQN and parallel style instead.
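The dict form of parallelize_plan can be pictured as a lookup from fully qualified names to styles. The sketch below is a plain-Python illustration only: it does not call torch, the TinyModule class and the style strings are hypothetical stand-ins for nn.Module and ParallelStyle objects, and the matching logic inside parallelize_module may differ.

```python
from typing import Dict, Iterator, Tuple


class TinyModule:
    """Hypothetical stand-in for nn.Module: a named tree of children."""

    def __init__(self, **children: "TinyModule") -> None:
        self.children = children
        self.style = None  # filled in by apply_plan

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "TinyModule"]]:
        # Yield (fqn, module) pairs using the dotted-name convention.
        yield prefix, self
        for name, child in self.children.items():
            fqn = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(fqn)


def apply_plan(root: TinyModule, plan: Dict[str, str]) -> Dict[str, str]:
    """Attach the style named in `plan` to each submodule whose FQN matches."""
    applied = {}
    for fqn, module in root.named_modules():
        if fqn in plan:
            module.style = plan[fqn]
            applied[fqn] = plan[fqn]
    return applied


model = TinyModule(mlp=TinyModule(fc1=TinyModule(), fc2=TinyModule()), head=TinyModule())
# Strings stand in for ParallelStyle objects (an assumption for illustration).
plan = {"mlp.fc1": "colwise", "mlp.fc2": "rowwise"}
applied = apply_plan(model, plan)
print(applied)
```

Submodules whose FQN is absent from the plan, such as head here, are left untouched in this sketch.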
Tensor Parallelism supports the following parallel styles:
- class torch.distributed.tensor.parallel.style.RowwiseParallel(_prepare_input=<function make_input_shard_1d_last_dim>, _prepare_output=<function make_output_tensor>)
Partitions the rows of a module. We assume the input to be a sharded DTensor and the output to be a torch.Tensor.
- class torch.distributed.tensor.parallel.style.ColwiseParallel(_prepare_input=<function make_input_replicate_1d>, _prepare_output=<function make_sharded_output_tensor>)
Partitions the columns of a tensor or module. We assume the input to be a replicated DTensor and the output to be a sharded torch.Tensor.
- class torch.distributed.tensor.parallel.style.PairwiseParallel(_prepare_input=None, _prepare_output=None)
PairwiseParallel concatenates colwise and rowwise styles as a fixed pair, as Megatron-LM (https://arxiv.org/abs/1909.08053) does. We assume both input and output need to be replicated DTensors.
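Why a colwise layer followed by a rowwise layer yields a replicated result can be checked numerically. The following is a minimal plain-Python sketch, not the PairwiseParallel implementation: it simulates two ranks with nested lists, and the weights, the ReLU activation, and the helper names are all chosen for illustration.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(a: Matrix) -> Matrix:
    return [[max(v, 0) for v in row] for row in a]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


x = [[1, -2], [3, 4]]                    # replicated input, shape (2, 2)
w1 = [[1, 0, -1, 2], [0, 1, 1, -1]]      # first weight, shape (2, 4)
w2 = [[1, 2], [0, 1], [-1, 0], [2, 1]]   # second weight, shape (4, 2)

# Reference: the unsharded two-layer MLP.
expected = matmul(relu(matmul(x, w1)), w2)

# Colwise: rank r keeps columns [2r, 2r + 2) of w1.
w1_shards = [[row[0:2] for row in w1], [row[2:4] for row in w1]]
# Rowwise: rank r keeps rows [2r, 2r + 2) of w2.
w2_shards = [w2[0:2], w2[2:4]]

# Each rank computes a partial output from its own shards only.
partials = [matmul(relu(matmul(x, a)), b) for a, b in zip(w1_shards, w2_shards)]

# Summing the partials (an all-reduce in a real run) gives the replicated output.
result = add(partials[0], partials[1])
print(result == expected)
```

The pairing works because the activation is elementwise, so it can be applied to each column shard independently before the rowwise layer consumes it.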
Warning
PairwiseParallel does not support nn.MultiheadAttention or nn.Transformer well at this moment. One workaround is to apply ColwiseParallel and RowwiseParallel to the components of the transformer. We recommend using PairwiseParallel only for MLPs with an even number of layers for now.
Warning
Sequence Parallelism is still experimental and no evaluation has been done.
- class torch.distributed.tensor.parallel.style.SequenceParallel
SequenceParallel concatenates colwise and rowwise styles as a fixed pair together with sequence parallelism, as Megatron-LM Sequence Parallel (https://arxiv.org/pdf/2205.05198.pdf) does. We assume both input and output need to be sharded DTensors.
Warning
SequenceParallel does not support nn.MultiheadAttention or nn.Transformer well at this moment. One workaround is to apply ColwiseParallel and RowwiseParallel to the components of the transformer. We recommend using SequenceParallel only for MLPs with an even number of layers for now.
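The collective pattern usually associated with sequence parallelism replaces one all-reduce with a reduce-scatter followed by an all-gather, so that activations between the pair stay sharded along the sequence. The sketch below only simulates those collectives on Python lists; it does not use SequenceParallel or torch, and whether this class follows exactly this pattern is an assumption.

```python
from typing import List

Vec = List[int]


def all_reduce(per_rank: List[Vec]) -> List[Vec]:
    """Every rank ends up with the elementwise sum of all ranks' data."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter(per_rank: List[Vec]) -> List[Vec]:
    """Rank r ends up with only the r-th equal chunk of the elementwise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    n = len(per_rank)
    chunk = len(total) // n  # assumes the length divides evenly
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]


def all_gather(per_rank: List[Vec]) -> List[Vec]:
    """Every rank ends up with the concatenation of all ranks' chunks."""
    full = [v for chunk in per_rank for v in chunk]
    return [list(full) for _ in per_rank]


# Partial outputs of a rowwise layer on 2 ranks, indexed by sequence position.
partials = [[1, 2, 3, 4], [10, 20, 30, 40]]

seq_sharded = reduce_scatter(partials)  # each rank holds part of the sequence
regathered = all_gather(seq_sharded)    # gathered again before the next colwise layer
print(regathered == all_reduce(partials))
```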
Since Tensor Parallelism is built on top of DTensor, we need to specify the input and output placements of the module with DTensors so that it interacts as expected with the modules before and after it. The following functions are used for input/output preparation:
- torch.distributed.tensor.parallel.style.make_input_replicate_1d(input, device_mesh=None)
Replicate input tensor over a 1-D device mesh. This function will be used in ParallelStyle.
- Parameters
input (Union[torch.Tensor, DTensor]) – This input tensor will be replicated over the 1-D DeviceMesh.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be replicated. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If DeviceMesh is not 1-D, an exception will be thrown. Default: None
- Returns
A DTensor replicated over device_mesh.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_input_reshard_replicate(input, device_mesh)
Construct a sharded DTensor from a tensor on different ranks and then convert it to a replicated DTensor.
- Parameters
input (torch.Tensor) – The input tensor on each rank. Together these form a global DTensor sharded on dimension 0 over the 1-D DeviceMesh, and the sharded DTensor is then converted to a replicated DTensor.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If DeviceMesh is not 1-D, an exception will be thrown. Default: None
- Returns
A DTensor sharded on dimension 0 over device_mesh and then converted to replicate.
- Return type
DTensor
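The two steps described for make_input_reshard_replicate can be pictured with nested lists. This is a simulation under stated assumptions rather than the library's code: ranks are modelled as list entries, reshard_then_replicate is a hypothetical helper, and replication is modelled as copying the concatenated tensor to every rank.

```python
from typing import List

Rows = List[List[int]]


def reshard_then_replicate(local_per_rank: List[Rows]) -> List[Rows]:
    """Treat each rank's tensor as a dim-0 shard, then give every rank the full tensor."""
    # Step 1: the global tensor is the rank-ordered concatenation along dimension 0.
    global_rows = [row for local in local_per_rank for row in local]
    # Step 2: replication (an all-gather in a real run) copies it to every rank.
    return [[list(row) for row in global_rows] for _ in local_per_rank]


locals_ = [[[1, 2]], [[3, 4]]]  # rank 0 and rank 1 each hold one row
replicated = reshard_then_replicate(locals_)
print(replicated)
```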
- torch.distributed.tensor.parallel.style.make_input_shard_1d(input, device_mesh=None, dim=0)
Shard input tensor on dim over a 1-D device mesh. This function will be used in ParallelStyle.
- Parameters
input (Union[torch.Tensor, DTensor]) – Single tensor to be sharded on dimension dim over the 1-D DeviceMesh.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If DeviceMesh is not 1-D, an exception will be thrown. Default: None
dim (int, optional) – The sharding dimension of the input tensor. Default: 0
- Returns
A DTensor sharded on dimension dim over device_mesh.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_input_shard_1d_last_dim(input, device_mesh=None)
Wrapper function of make_input_shard_1d with dim = -1.
- Parameters
input (Union[torch.Tensor, DTensor]) – This single tensor will be sharded on the last dimension over the 1-D DeviceMesh.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If DeviceMesh is not 1-D, an exception will be thrown. Default: None
- Returns
A DTensor sharded on the last dimension over device_mesh.
- Return type
DTensor
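What sharding on dim = 0 versus dim = -1 means for the local pieces can be shown on a small 2-D example. This is a plain-Python sketch under assumptions: shard_2d is a hypothetical helper, ranks are list positions, and the contiguous ceiling-division chunking is an illustrative choice that may not match how DTensor splits uneven sizes.

```python
from typing import List

Rows = List[List[int]]


def shard_2d(tensor: Rows, num_ranks: int, dim: int = 0) -> List[Rows]:
    """Split a 2-D nested list into `num_ranks` contiguous shards along `dim`."""
    ndim = 2
    dim = dim % ndim  # normalise negative dims, so -1 means the last dimension
    size = len(tensor) if dim == 0 else len(tensor[0])
    chunk = -(-size // num_ranks)  # ceiling division; trailing shards may be smaller
    shards = []
    for rank in range(num_ranks):
        lo, hi = rank * chunk, min((rank + 1) * chunk, size)
        if dim == 0:
            shards.append([list(row) for row in tensor[lo:hi]])
        else:
            shards.append([row[lo:hi] for row in tensor])
    return shards


t = [[1, 2, 3, 4], [5, 6, 7, 8]]
by_rows = shard_2d(t, num_ranks=2, dim=0)   # analogous to dim=0
by_last = shard_2d(t, num_ranks=2, dim=-1)  # analogous to the last-dim wrapper
print(by_rows)
print(by_last)
```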
- torch.distributed.tensor.parallel.style.make_output_replicate_1d(output, device_mesh=None)
Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle.
- Parameters
output (DTensor) – Output of module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to replicate the output. It needs to be a 1D device_mesh; an exception is thrown if a non-1D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
- Returns
A replicated DTensor object.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_output_reshard_tensor(output, device_mesh=None)
Convert Output DTensor to a sharded DTensor and return the local tensor.
- Parameters
output (DTensor) – Output of module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to shard the output. It needs to be a 1D device_mesh; an exception is thrown if a non-1D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
- Returns
A torch.Tensor object converted from the output DTensor.
- Return type
torch.Tensor
- torch.distributed.tensor.parallel.style.make_output_shard_1d(output, device_mesh=None, dim=0)
Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle.
- Parameters
output (DTensor) – Output of module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to shard the output. It needs to be a 1D device_mesh; an exception is thrown if a non-1D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
dim (int) – Sharding dim for output. Default: 0
- Returns
A DTensor object sharded on the given dim.
- Return type
DTensor
- torch.distributed.tensor.parallel.style.make_output_tensor(output, device_mesh=None)
Convert Output DTensor to a replicated DTensor first and then convert it to Tensor.
- Parameters
output (DTensor) – Output of module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to replicate the output. It needs to be a 1D device_mesh; an exception is thrown if a non-1D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
- Returns
A torch.Tensor object converted from the output DTensor.
- Return type
torch.Tensor
Currently, there are some constraints which make it hard for the MultiheadAttention
module to work out of the box with Tensor Parallelism, so we recommend that users try ColwiseParallel
and RowwiseParallel for each parameter. Some code changes might be needed for now,
since we are parallelizing on the head dim of the MultiheadAttention module.
We also support 2D parallelism, where we compose tensor parallelism with data parallelism.
To integrate with FullyShardedDataParallel,
users just need to call the following API explicitly:
- torch.distributed.tensor.parallel.fsdp.enable_2d_with_fsdp()
The API registers the extension which is needed for Tensor Parallelism (TP) to work with FullyShardedDataParallel (FSDP). We first parallelize parameters within one module or sub_modules based on a parallelize_plan, and then let FSDP reshard the local tensor of each distributed parameter, which is essentially a DTensor.
- Returns
A bool indicating whether the extension registration succeeded.
- Return type
bool
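In a 2D setup, one dimension of the device mesh groups ranks for tensor parallelism and the other groups them for data parallelism. The sketch below is a hypothetical illustration in plain Python: the 2 x 4 layout, the choice of dim 1 for TP, and the groups_along_dim helper are assumptions and do not use DeviceMesh.

```python
from typing import List


def groups_along_dim(mesh: List[List[int]], dim: int) -> List[List[int]]:
    """Return the rank groups formed along `dim` of a 2-D mesh of rank ids."""
    if dim == 0:
        # Ranks that differ only in their dim-0 index: the columns of the mesh.
        return [list(col) for col in zip(*mesh)]
    # Ranks that differ only in their dim-1 index: the rows of the mesh.
    return [list(row) for row in mesh]


# Hypothetical 2 x 4 mesh over 8 ranks: dim 0 for data parallelism, dim 1 for TP.
mesh = [[0, 1, 2, 3], [4, 5, 6, 7]]
tp_groups = groups_along_dim(mesh, dim=1)
dp_groups = groups_along_dim(mesh, dim=0)
print(tp_groups)
print(dp_groups)
```

With this layout, each rank belongs to exactly one TP group and one data-parallel group, which is the property a 2D composition relies on.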
To integrate with DistributedDataParallel,
users just need to call the following API explicitly:
- torch.distributed.tensor.parallel.ddp.pre_dp_module_transform(module)
Enable the composability between Tensor Parallelism (TP) and Data Parallelism (DP) in PyTorch when using DDP. We need to convert Parameters which are DTensors to local tensors before wrapping with the data parallelism API. We then register two hooks, one to convert local tensors back to DTensors before forward and one to convert DTensors back to local tensors after forward. By integrating this way, we avoid any special handling of DTensor parameters by DDP and get DTensor’s gradients propagated back to DP, e.g. gradient buckets of DDP.
For now, this API only works with DistributedDataParallel. It will later support other DP methods such as FSDP.
- Parameters
module (nn.Module) – Module to which TP has been applied.
- Example::
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform
>>>
>>> # Define the module.
>>> m = module(...)
>>> # device_mesh is a DeviceMesh created beforehand (construction not shown).
>>> parallelize_module(m, device_mesh, PairwiseParallel())
>>> m = pre_dp_module_transform(m)
>>> m = DDP(m)
>>>
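The hook arrangement described for pre_dp_module_transform can be mimicked without torch. In this sketch every name is hypothetical: FakeDTensor stands in for a DTensor parameter, TinyModule for an nn.Module with hook lists, and pre_dp_transform_sketch for the transform. It shows only the ordering of the conversions, not the library's implementation.

```python
from typing import Callable, List


class FakeDTensor:
    """Hypothetical wrapper marking a value as 'distributed'."""

    def __init__(self, local: float) -> None:
        self.local = local


class TinyModule:
    """Hypothetical stand-in for an nn.Module with one parameter and hooks."""

    def __init__(self, weight) -> None:
        self.weight = weight
        self.pre_hooks: List[Callable[["TinyModule"], None]] = []
        self.post_hooks: List[Callable[["TinyModule"], None]] = []
        self.seen_in_forward = None

    def forward(self, x: float) -> float:
        # Record which representation forward actually sees.
        self.seen_in_forward = type(self.weight).__name__
        return self.weight.local * x

    def __call__(self, x: float) -> float:
        for hook in self.pre_hooks:
            hook(self)
        out = self.forward(x)
        for hook in self.post_hooks:
            hook(self)
        return out


def to_local(m: TinyModule) -> None:
    if isinstance(m.weight, FakeDTensor):
        m.weight = m.weight.local


def to_dtensor(m: TinyModule) -> None:
    if not isinstance(m.weight, FakeDTensor):
        m.weight = FakeDTensor(m.weight)


def pre_dp_transform_sketch(m: TinyModule) -> None:
    to_local(m)                     # the data-parallel wrapper sees plain values
    m.pre_hooks.append(to_dtensor)  # restored just before forward
    m.post_hooks.append(to_local)   # converted back right after forward


m = TinyModule(FakeDTensor(3.0))
pre_dp_transform_sketch(m)
before = type(m.weight).__name__
out = m(2.0)
after = type(m.weight).__name__
print(before, m.seen_in_forward, after, out)
```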