Tensor Parallelism - torch.distributed.tensor.parallel
Tensor Parallelism (TP) is built on top of DistributedTensor (DTensor) and provides several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.
Warning
Tensor Parallelism APIs are experimental and subject to change.
The entry point to parallelize your nn.Module using Tensor Parallelism is:
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan, tp_mesh_dim=0)
The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize a module or its submodules based on a parallelize_plan. The parallelize_plan contains a ParallelStyle, which indicates how the user wants the module or submodule to be parallelized. Users can also specify a different parallel style per module fully qualified name (FQN). The API supports 2D parallelism natively by accepting an n-dimensional device_mesh; users just need to specify the mesh dimension on which tensor parallelism is performed.
- Parameters:
module (nn.Module) – Module to be parallelized.
device_mesh (DeviceMesh) – Object which describes the mesh topology of devices for the DTensor.
parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]) – The plan used to parallelize the module. It can be either a ParallelStyle object, which describes how we prepare input/output for Tensor Parallelism, or a dict mapping module FQNs to their corresponding ParallelStyle objects.
tp_mesh_dim (int) – The dimension of device_mesh on which we perform Tensor Parallelism. Default: 0
- Returns:
A parallelized nn.Module object.
- Return type:
nn.Module
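- Example::
>>> import torch.distributed as dist
>>> from torch.distributed._tensor import DeviceMesh
>>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
>>>
>>> # Define the module.
>>> m = Model(...)
>>> # Build a 1-D mesh over all ranks (assumes the default process group is already initialized).
>>> mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
>>> m = parallelize_module(m, mesh, PairwiseParallel())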
Warning
PairwiseParallel comes with constraints for now. If you need finer granularity, pass in a dict mapping module FQNs to parallel styles instead.
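To picture what tp_mesh_dim selects, here is a minimal pure-Python sketch; it is not the DeviceMesh API. A 2-D mesh is modeled as a nested list of ranks, and the hypothetical helper mesh_groups lists the ranks that would form one group along a chosen mesh dimension. With TP on mesh dim 1, the groups along dim 0 are the ones left over for data parallelism in a 2D setup.

```python
from typing import List


def mesh_groups(mesh: List[List[int]], dim: int) -> List[List[int]]:
    """Return the rank groups formed along `dim` of a 2-D mesh.

    Hypothetical helper: ranks that differ only in their coordinate
    along `dim` end up in the same group.
    """
    rows, cols = len(mesh), len(mesh[0])
    if dim == 0:
        # Walk down each column: members differ only in their row index.
        return [[mesh[r][c] for r in range(rows)] for c in range(cols)]
    if dim == 1:
        # Walk along each row: members differ only in their column index.
        return [list(row) for row in mesh]
    raise ValueError("a 2-D mesh only has dims 0 and 1")


# 4 ranks arranged as a 2 x 2 mesh.
mesh = [[0, 1], [2, 3]]
tp_groups = mesh_groups(mesh, dim=1)  # tensor parallelism along mesh dim 1
dp_groups = mesh_groups(mesh, dim=0)  # the remaining dim, e.g. for FSDP
print("TP groups:", tp_groups)  # [[0, 1], [2, 3]]
print("DP groups:", dp_groups)  # [[0, 2], [1, 3]]
```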
Tensor Parallelism supports the following parallel styles:
- class torch.distributed.tensor.parallel.style.RowwiseParallel
Partitions the rows of a module. We assume the input to be a sharded DTensor and the output to be a replicated DTensor.
- class torch.distributed.tensor.parallel.style.ColwiseParallel
Partitions the columns of a tensor or module. We assume the input to be a replicated DTensor and the output to be a sharded DTensor.
- class torch.distributed.tensor.parallel.style.PairwiseParallel
PairwiseParallel combines the colwise and rowwise styles into a fixed pair, as Megatron-LM (https://arxiv.org/abs/1909.08053) does. We assume both the input and the output to be replicated DTensors.
Warning
PairwiseParallel only supports nn.MultiheadAttention, nn.Transformer, or MLPs with an even number of layers for now.
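The sketch below simulates the three styles in pure Python, with lists standing in for per-rank DTensor shards and a plain sum standing in for the all-reduce; it does not call the TP API, and all helper names are hypothetical. It uses the Megatron-LM convention Y = XA, so "columns" and "rows" refer to A, which is the transpose of the (out_features, in_features) weight that nn.Linear stores. Evenly divisible sizes are assumed.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain matrix product x @ w."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*w)] for row in x]


def split_cols(m: Matrix, n: int) -> List[Matrix]:
    """n equal column blocks of m (assumes even divisibility)."""
    k = len(m[0]) // n
    return [[row[i * k:(i + 1) * k] for row in m] for i in range(n)]


def split_rows(m: Matrix, n: int) -> List[Matrix]:
    """n equal row blocks of m (assumes even divisibility)."""
    k = len(m) // n
    return [m[i * k:(i + 1) * k] for i in range(n)]


def concat_cols(blocks: List[Matrix]) -> Matrix:
    """Join column blocks side by side (stands in for an all-gather)."""
    return [sum((b[r] for b in blocks), []) for r in range(len(blocks[0]))]


def add_all(blocks: List[Matrix]) -> Matrix:
    """Element-wise sum of per-rank partial results (stands in for an all-reduce)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*blocks)]


def relu(m: Matrix) -> Matrix:
    return [[max(v, 0.0) for v in row] for row in m]


def pairwise_mlp(x: Matrix, a: Matrix, b: Matrix, n: int) -> Matrix:
    """relu(x @ a) @ b computed on n simulated ranks, Megatron-LM style."""
    # Colwise: every rank sees the full x and owns a column block of a.
    hidden = [relu(matmul(x, a_i)) for a_i in split_cols(a, n)]
    # Rowwise: each rank multiplies its hidden shard by its row block of b.
    partials = [matmul(h_i, b_i) for h_i, b_i in zip(hidden, split_rows(b, n))]
    return add_all(partials)


x = [[1.0, -2.0, 3.0, 1.0]]  # one input row, in_features = 4
a = [[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0],
     [2.0, -1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]  # 4 x 4
b = [[1.0, 2.0], [0.0, -1.0], [3.0, 1.0], [-2.0, 1.0]]  # 4 x 2

# ColwiseParallel: replicated input in, column-sharded output out.
assert concat_cols([matmul(x, a_i) for a_i in split_cols(a, 2)]) == matmul(x, a)
# RowwiseParallel: input sharded on its last dim, partial sums reduced to a replicated output.
assert add_all([matmul(x_i, a_i) for x_i, a_i in zip(split_cols(x, 2), split_rows(a, 2))]) == matmul(x, a)
# PairwiseParallel: the colwise/rowwise pair reproduces the unsharded MLP.
assert pairwise_mlp(x, a, b, 2) == matmul(relu(matmul(x, a)), b)
```

Because ReLU acts element-wise, each rank can apply it to its own column shard of the hidden activation, which is why the colwise/rowwise pair needs only one reduction at the end.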
Since Tensor Parallelism is built on top of DTensor, we need to specify the input and output placements of the module with DTensors so that it interacts as expected with the modules before and after it. The following functions are used for input/output preparation:
- torch.distributed.tensor.parallel.style.make_input_replicate_1d(input, device_mesh=None)
Replicates the input tensor over a 1-D device mesh. This function is used in ParallelStyle.
- Parameters:
input (Union[torch.Tensor, DTensor]) – This input tensor will be replicated over the 1-D DeviceMesh.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be replicated. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If the DeviceMesh is not 1-D, an exception will be thrown. Default: None
- Returns:
A DTensor replicated over device_mesh.
- Return type:
DTensor
- torch.distributed.tensor.parallel.style.make_input_shard_1d(input, device_mesh=None, dim=0)
Shards the input tensor on dim over a 1-D device mesh. This function is used in ParallelStyle.
- Parameters:
input (Union[torch.Tensor, DTensor]) – Single tensor that will be sharded on dimension dim over the 1-D DeviceMesh.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If the DeviceMesh is not 1-D, an exception will be thrown. Default: None
dim (int, optional) – The sharding dimension of the input tensor. Default: 0
- Returns:
A DTensor sharded on dimension dim over device_mesh.
- Return type:
DTensor
- torch.distributed.tensor.parallel.style.make_input_shard_1d_last_dim(input, device_mesh=None)
Wrapper function of make_input_shard_1d with dim = -1.
- Parameters:
input (Union[torch.Tensor, DTensor]) – This single tensor will be sharded on its last dimension over the 1-D DeviceMesh.
device_mesh (DeviceMesh, optional) – The 1-D device mesh where input will be sharded. If no DeviceMesh is passed and input is a DTensor, input.device_mesh will be used. If the DeviceMesh is not 1-D, an exception will be thrown. Default: None
- Returns:
A DTensor sharded on the last dimension over device_mesh.
- Return type:
DTensor
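As a rough mental model of the placements these helpers produce, the following pure-Python sketch maps each rank of a 1-D mesh to the data it would hold locally. The helpers replicate_1d and shard_1d are hypothetical stand-ins, not DTensor calls, and only 2-D inputs with evenly divisible sizes are handled; make_input_shard_1d_last_dim corresponds to dim=-1.

```python
from typing import Dict, List

Matrix = List[List[float]]


def replicate_1d(t: Matrix, world_size: int) -> Dict[int, Matrix]:
    """Every rank of a 1-D mesh holds a full copy of t."""
    return {rank: [row[:] for row in t] for rank in range(world_size)}


def shard_1d(t: Matrix, world_size: int, dim: int = 0) -> Dict[int, Matrix]:
    """Rank r holds the r-th equal chunk of the 2-D matrix t along dim."""
    if dim not in (0, 1, -1, -2):
        raise ValueError("a 2-D matrix only has dims 0/1 (or -2/-1)")
    dim %= 2  # normalize negative dims, so -1 -> 1 (the last dim)
    size = len(t) if dim == 0 else len(t[0])
    if size % world_size:
        raise ValueError("this sketch only handles evenly divisible sizes")
    k = size // world_size
    if dim == 0:
        return {r: [row[:] for row in t[r * k:(r + 1) * k]] for r in range(world_size)}
    return {r: [row[r * k:(r + 1) * k] for row in t] for r in range(world_size)}


t = [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate_1d(t, 2)[1])      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(shard_1d(t, 2, dim=0)[1])   # [[4, 5, 6, 7]]
print(shard_1d(t, 2, dim=-1)[1])  # [[2, 3], [6, 7]]
```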
- torch.distributed.tensor.parallel.style.make_output_replicate_1d(output, device_mesh=None)
Converts the output DTensor to a replicated DTensor. This is used in ParallelStyle.
- Parameters:
output (DTensor) – Output of the module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to replicate the output. It must be a 1-D device_mesh; an exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
- Returns:
A replicated DTensor object.
- Return type:
DTensor
- torch.distributed.tensor.parallel.style.make_output_tensor(output, device_mesh=None)
Converts the output DTensor to a replicated DTensor first and then converts it to a torch.Tensor.
- Parameters:
output (DTensor) – Output of the module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to replicate the output. It must be a 1-D device_mesh; an exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
- Returns:
A torch.Tensor object converted from the output DTensor.
- Return type:
torch.Tensor
- torch.distributed.tensor.parallel.style.make_output_shard_1d(output, device_mesh=None, dim=0)
Converts the output DTensor to a sharded DTensor. This is used in ParallelStyle.
- Parameters:
output (DTensor) – Output of the module to be converted.
device_mesh (DeviceMesh, optional) – Object needed to shard the output. It must be a 1-D device_mesh; an exception is thrown if a non-1-D device_mesh is passed in. If no device_mesh is passed in, the one from output is reused. Default: None
dim (int) – Sharding dimension for the output. Default: 0
- Returns:
A DTensor object sharded on the given dim.
- Return type:
DTensor
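On the output side, the communication needed to replicate an output depends on its current placement: shards along a dimension are gathered, while partial sums (such as the per-rank results of a row-wise partitioned matmul) are reduced. The sketch below models both cases in pure Python with per-rank dictionaries; the placement strings and helper names are hypothetical and are not DTensor's Placement types.

```python
from typing import Dict, List

Matrix = List[List[float]]


def all_gather(local: Dict[int, Matrix], dim: int = 0) -> Matrix:
    """Rebuild the full matrix from per-rank shards along dim (0, or 1/-1)."""
    shards = [local[r] for r in sorted(local)]
    if dim % 2 == 0:
        # Stack row shards on top of each other.
        return [row for shard in shards for row in shard]
    # Join column shards side by side, row by row.
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]


def all_reduce(local: Dict[int, Matrix]) -> Matrix:
    """Sum per-rank partial results element-wise."""
    parts = [local[r] for r in sorted(local)]
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]


def to_replicated(local: Dict[int, Matrix], placement: str, dim: int = 0) -> Dict[int, Matrix]:
    """Give every rank the full result, as a replicated output would."""
    if placement == "shard":
        full = all_gather(local, dim)
    elif placement == "partial":
        full = all_reduce(local)
    else:
        raise ValueError(f"unknown placement: {placement}")
    return {r: full for r in local}


sharded = {0: [[1, 2]], 1: [[3, 4]]}    # column shards of [[1, 2, 3, 4]]
partial = {0: [[1, 2]], 1: [[10, 20]]}  # partial sums awaiting reduction
print(to_replicated(sharded, "shard", dim=-1)[0])  # [[1, 2, 3, 4]]
print(to_replicated(partial, "partial")[1])        # [[11, 22]]
```

In this picture, make_output_tensor would hand back the full replicated copy as a plain torch.Tensor, while make_output_shard_1d goes the other way and keeps only each rank's chunk along dim.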
Currently, some constraints make it hard for the nn.MultiheadAttention module to work out of the box with Tensor Parallelism, so we built this multihead_attention module for Tensor Parallelism users. Also, when PairwiseParallel is specified, parallelize_module automatically swaps nn.MultiheadAttention for this custom module.
- class torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None, tp_size=1, self_attention=True)
Multi-head Attention block from Transformer models. Since we need some customizations for the attention layer, we provide a customized attention module that is mathematically equivalent to the one defined in torch.nn.
Note: we currently only support self-attention with a limited set of input arguments, and we assume the input tensor is three-dimensional. Although we do implement the logic for multihead attention, it has not been fully tested.
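Attention fits the colwise/rowwise pattern because heads are computed independently. The pure-Python sketch below illustrates that Megatron-style idea; it is not the implementation of TensorParallelMultiheadAttention. Each of tp_size simulated ranks gets num_heads // tp_size heads, uses only those heads' columns of the query/key/value projections, and multiplies its head outputs by the matching rows of the output projection; the per-rank results are then summed. It assumes separate Wq/Wk/Wv matrices, no biases, a single unbatched (seq_len, embed_dim) input, and num_heads divisible by tp_size.

```python
import math
from typing import Iterable, List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain matrix product x @ w."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*w)] for row in x]


def cols(m: Matrix, start: int, stop: int) -> Matrix:
    """Columns start..stop-1 of m."""
    return [row[start:stop] for row in m]


def softmax_rows(m: Matrix) -> Matrix:
    """Numerically stable softmax over each row."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attend(x: Matrix, wq: Matrix, wk: Matrix, wv: Matrix,
           heads: Iterable[int], head_dim: int) -> Matrix:
    """Self-attention outputs of the given heads, concatenated on the feature dim."""
    outs = []
    for h in heads:
        lo, hi = h * head_dim, (h + 1) * head_dim
        q, k, v = (matmul(x, cols(w, lo, hi)) for w in (wq, wk, wv))
        k_t = [list(c) for c in zip(*k)]
        scores = [[s / math.sqrt(head_dim) for s in row] for row in matmul(q, k_t)]
        outs.append(matmul(softmax_rows(scores), v))
    return [sum((o[i] for o in outs), []) for i in range(len(x))]


def full_attention(x: Matrix, wq: Matrix, wk: Matrix, wv: Matrix,
                   wo: Matrix, num_heads: int) -> Matrix:
    head_dim = len(wq[0]) // num_heads
    return matmul(attend(x, wq, wk, wv, range(num_heads), head_dim), wo)


def head_parallel_attention(x: Matrix, wq: Matrix, wk: Matrix, wv: Matrix,
                            wo: Matrix, num_heads: int, tp_size: int) -> Matrix:
    head_dim = len(wq[0]) // num_heads
    per_rank = num_heads // tp_size  # assumes num_heads % tp_size == 0
    partials = []
    for rank in range(tp_size):
        heads = range(rank * per_rank, (rank + 1) * per_rank)
        # Colwise: this rank only touches its own heads' projection columns.
        local = attend(x, wq, wk, wv, heads, head_dim)
        # Rowwise: the matching rows of the output projection.
        lo = rank * per_rank * head_dim
        partials.append(matmul(local, wo[lo:lo + per_rank * head_dim]))
    # Summing the partials stands in for the all-reduce.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def weights(seed: int, n: int = 4) -> Matrix:
    """Small deterministic n x n matrix with arbitrary demo values."""
    return [[0.1 * ((i * 3 + j * seed) % 7 - 3) for j in range(n)] for i in range(n)]


x = [[0.1 * (i + j) for j in range(4)] for i in range(3)]  # seq_len 3, embed_dim 4
wq, wk, wv, wo = weights(1), weights(2), weights(3), weights(5)
full = full_attention(x, wq, wk, wv, wo, num_heads=2)
tp = head_parallel_attention(x, wq, wk, wv, wo, num_heads=2, tp_size=2)
assert all(abs(p - q) < 1e-9 for rf, rt in zip(full, tp) for p, q in zip(rf, rt))
```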
We also enable 2D parallelism through integration with FullyShardedDataParallel (FSDP).
Users just need to call the following API explicitly:
- torch.distributed.tensor.parallel.fsdp.enable_2d_with_fsdp()
The API registers the extension which is needed for Tensor Parallelism (TP) to work with FullyShardedDataParallel (FSDP). We first parallelize the parameters within a module or its submodules based on a parallelize_plan, and then let FSDP reshard the local tensor of each distributed parameter (which is a DTensor).
- Returns:
A bool indicating whether the extension registration succeeded.
- Return type:
bool
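To illustrate the resharding described above, the sketch below computes which rows of a weight each rank of a (dp, tp) mesh would hold in a simplified 2D scheme: TP first splits the rows across the tp dimension, and FSDP then splits each local TP shard across the dp dimension. It is a hypothetical pure-Python model; real FSDP flattens and pads parameters, which is ignored here.

```python
from typing import Dict, List, Tuple


def shard_2d(num_rows: int, dp_size: int, tp_size: int) -> Dict[Tuple[int, int], List[int]]:
    """Row indices each (dp, tp) mesh coordinate holds in a simplified 2D scheme.

    TP splits the rows across the tp dim; FSDP then splits each TP shard
    across the dp dim. Assumes even divisibility at both levels.
    """
    if num_rows % tp_size or (num_rows // tp_size) % dp_size:
        raise ValueError("this sketch only handles evenly divisible sizes")
    tp_chunk = num_rows // tp_size   # rows per TP shard
    dp_chunk = tp_chunk // dp_size   # rows per rank after FSDP resharding
    layout = {}
    for d in range(dp_size):
        for t in range(tp_size):
            start = t * tp_chunk + d * dp_chunk
            layout[(d, t)] = list(range(start, start + dp_chunk))
    return layout


for coord, rows in sorted(shard_2d(8, dp_size=2, tp_size=2).items()):
    print(coord, rows)
# (0, 0) [0, 1]
# (0, 1) [4, 5]
# (1, 0) [2, 3]
# (1, 1) [6, 7]
```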