torch.distributed.tensor
Note
torch.distributed.tensor is currently in alpha state and under
development. We are committed to backward compatibility for most of the APIs listed
in this document, but there might be API changes if necessary.
PyTorch DTensor (Distributed Tensor)
PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed
logic, including sharded storage, operator computation and collective communications across devices/hosts.
DTensor can be used to build different parallelism solutions and supports sharded state_dict representation
when working with multi-dimensional sharding.
Please see examples from the PyTorch native parallelism solutions that are built on top of DTensor:
DTensor follows the SPMD (single program, multiple data) programming model to empower users to
write distributed programs as if they were single-device programs with the same convergence property. It
provides a uniform tensor sharding layout (DTensor Layout) through specifying the DeviceMesh
and Placement:
- DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array.
- Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial.
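To make the "n-dimensional array" picture concrete, the pure-Python sketch below lays ranks out row-major in a mesh shape, the way a mesh built over torch.arange(8).reshape(2, 4) would, and lists which ranks form a communicator along each mesh dimension. The helpers rank_at and mesh_dim_groups are hypothetical illustrations, not DeviceMesh's implementation.

```python
from itertools import product


def rank_at(coord, shape):
    """Row-major rank of a mesh coordinate, e.g. (1, 2) in a (2, 4) mesh -> 6."""
    rank = 0
    for c, size in zip(coord, shape):
        rank = rank * size + c
    return rank


def mesh_dim_groups(shape, mesh_dim):
    """Groups of ranks that differ only in their `mesh_dim` coordinate.

    Each group corresponds to one communicator along that mesh dimension.
    """
    other_ranges = [range(size) for d, size in enumerate(shape) if d != mesh_dim]
    groups = []
    for rest in product(*other_ranges):
        group = []
        for i in range(shape[mesh_dim]):
            coord = list(rest)
            coord.insert(mesh_dim, i)  # put the varying coordinate back in place
            group.append(rank_at(coord, shape))
        groups.append(group)
    return groups


# A 2 x 4 mesh of 8 ranks: mesh dim 0 pairs ranks across the two rows,
# mesh dim 1 groups the 4 ranks inside each row.
print(mesh_dim_groups((2, 4), 0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(mesh_dim_groups((2, 4), 1))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

A Placement is then chosen per mesh dimension, so on this mesh a layout such as [Shard(0), Replicate()] shards across the pairs of mesh dim 0 and replicates within each row of mesh dim 1.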
DTensor Class APIs
DTensor is a torch.Tensor subclass. This means that once a DTensor is created, it can be
used in much the same way as a torch.Tensor, including running different types of PyTorch operators as if
running them on a single device, with DTensor carrying out the proper distributed computation for those operators.
In addition to the existing torch.Tensor methods, it also offers a set of additional methods to interact with
torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content
on all devices, etc.
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)
DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction for programming with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement:
Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension
Replicate: Tensor replicated on the devices of the DeviceMesh dimension
Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension
When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor transforms or propagates the placements (DTensor Layout) properly (based on the semantics of the operator itself) and generates new DTensor outputs.
To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator to be a DTensor.
- property device_mesh: DeviceMesh
The DeviceMesh attribute that is associated with this DTensor object.
Note
device_mesh is a read-only property; it cannot be set.
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)
Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified.
- Parameters
local_tensor (torch.Tensor) – local torch.Tensor on each rank.
device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor. If not specified, this must be called under a DeviceMesh context manager. Default: None
placements (List[Placement], optional) – the placements that describe how to place the local torch.Tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim.
- Keyword Arguments
run_check (bool, optional) – at the cost of extra communication, perform a sanity check across ranks on each local tensor's meta information to ensure correctness. If placements contains Replicate, the data on the first rank of that device mesh dimension will be broadcast to the other ranks. Default: False
shape (torch.Size, optional) – a list of ints that specifies the size of the DTensor built on top of local_tensor. This needs to be provided if the shapes of local_tensor differ across ranks. If not provided, shape will be computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
stride (tuple, optional) – a list of ints that specifies the stride of the DTensor. If not provided, stride will be computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
- Returns
A DTensor object
- Return type
DTensor
Note
When run_check=False, it is the user's responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim) placement or replicated for the Replicate() placement). If not, the behavior of the created DTensor is undefined.
Note
from_local is differentiable; the requires_grad of the created DTensor object depends on whether local_tensor requires grad.
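When shape is omitted, from_local assumes even sharding. The pure-Python sketch below shows the arithmetic that assumption implies: each Shard(dim) multiplies that tensor dimension by the size of its mesh dimension, while Replicate and Partial leave it unchanged. The helper and its tuple encoding of placements are hypothetical. If ranks hold differently sized shards, pass shape and stride explicitly instead.

```python
def infer_global_shape(local_shape, mesh_shape, placements):
    """Global shape implied by an evenly sharded local tensor.

    Hypothetical helper: placements are modeled as ("shard", dim),
    ("replicate",) or ("partial",), one entry per mesh dimension.
    """
    if len(placements) != len(mesh_shape):
        raise ValueError("need exactly one placement per mesh dimension")
    global_shape = list(local_shape)
    for mesh_dim, placement in enumerate(placements):
        if placement[0] == "shard":
            # Even sharding: every rank on this mesh dim holds an equal slice.
            global_shape[placement[1]] *= mesh_shape[mesh_dim]
        # Replicate / Partial: the local tensor already has the full size.
    return tuple(global_shape)


print(infer_global_shape((2, 8), (4,), [("shard", 0)]))                    # (8, 8)
print(infer_global_shape((3, 2), (2, 4), [("shard", 0), ("shard", 1)]))    # (6, 8)
print(infer_global_shape((3, 2), (2, 4), [("replicate",), ("shard", 1)]))  # (3, 8)
```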
- full_tensor(*, grad_placements=None)
Return the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenates them together. It is syntactic sugar for the following code:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
- Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full Tensor returned from this function. full_tensor converts the DTensor to a full torch.Tensor, and the returned torch.Tensor might not be used with the original replicated DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, we assume the gradient layout of the full tensor is replicated.
- Returns
A torch.Tensor object that represents the full tensor of this DTensor.
- Return type
Tensor
Note
full_tensor is differentiable.
- property placements: Tuple[Placement, ...]
The placements attribute of this DTensor that describes the layout of this DTensor on its DeviceMesh.
Note
placements is a read-only property; it cannot be set.
- redistribute(device_mesh=None, placements=None, *, async_op=False)
redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, we can turn a sharded DTensor into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh.
When redistributing from the current to the new placements on one device mesh dimension, we perform the following collective communication or local operations:
Shard(dim) -> Replicate(): all_gather
Shard(src_dim) -> Shard(dst_dim): all_to_all
Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)
Partial() -> Replicate(): all_reduce
Partial() -> Shard(dim): reduce_scatter
redistribute correctly figures out the necessary redistribute steps for DTensors that are created on either a 1-D or an N-D DeviceMesh.
- Parameters
device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor. If not specified, the current DTensor's DeviceMesh is used. Default: None
placements (List[Placement], optional) – the new placements that describe how to place the DTensor into the DeviceMesh; must have the same number of elements as device_mesh.ndim. Default: replicate on all mesh dimensions
- Keyword Arguments
async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously. Default: False
- Returns
A DTensor object
- Return type
DTensor
Note
redistribute is differentiable, which means the user does not need to worry about the backward formula of the redistribute operation.
Note
redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
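The per-mesh-dimension rules above can be read as a lookup table. The pure-Python sketch below is a hypothetical model, not DTensor's planner. It maps each mesh dimension's (source, target) pair to the step listed above and rejects pairs the list does not cover. DTensor itself may order or combine steps differently across mesh dimensions.

```python
# Per-mesh-dimension steps from the list above, keyed by (source kind, target kind).
TRANSITIONS = {
    ("shard", "replicate"): "all_gather",
    ("shard", "shard"): "all_to_all",  # only reached when the tensor dims differ
    ("replicate", "shard"): "local chunk",  # torch.chunk, no communication
    ("partial", "replicate"): "all_reduce",
    ("partial", "shard"): "reduce_scatter",
}


def plan_steps(src, dst):
    """Return one step per mesh dim (None where the placement is unchanged).

    Hypothetical model: placements are ("shard", dim), ("replicate",) or ("partial",).
    """
    if len(src) != len(dst):
        raise ValueError("src and dst must cover the same mesh dimensions")
    steps = []
    for s, d in zip(src, dst):
        if s == d:
            steps.append(None)
            continue
        try:
            steps.append(TRANSITIONS[(s[0], d[0])])
        except KeyError:
            raise ValueError(f"{s} -> {d} is not in the list above") from None
    return steps


# 2-D mesh: gather mesh dim 0's shards and reduce-scatter mesh dim 1's partial sums.
print(plan_steps([("shard", 0), ("partial",)], [("replicate",), ("shard", 1)]))
# ['all_gather', 'reduce_scatter']
```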
- to_local(*, grad_placements=None)
Get the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.
- Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts the DTensor to a local tensor, and the returned local tensor might not be used with the original DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, we assume the gradient layout remains the same as the original DTensor and use that for gradient computation.
- Returns
A torch.Tensor or AsyncCollectiveTensor object representing the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, the local tensor is not ready yet (i.e. communication has not finished). In this case, the user needs to call wait to wait for the local tensor to be ready.
- Return type
Tensor
Note
to_local is differentiable; the requires_grad of the returned local tensor depends on whether the DTensor requires grad.
DeviceMesh as the distributed communicator
DeviceMesh was developed as part of DTensor, as the abstraction that describes the cluster's device topology and represents
multi-dimensional communicators (on top of ProcessGroup). To see the details of how to create and use a DeviceMesh,
please refer to the DeviceMesh recipe.
DTensor Placement Types
DTensor supports the following types of Placement on each DeviceMesh dimension:
- class torch.distributed.tensor.placement_types.Shard(dim)
The Shard(dim) placement describes the DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantic, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (e.g. distribute_tensor, from_local, etc.)
- Parameters
dim (int) – The tensor dimension along which the DTensor is sharded over its corresponding DeviceMesh dimension.
Warning
Sharding on a tensor dimension whose size is not evenly divisible by the DeviceMesh dimension size is currently experimental and subject to change.
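The torch.chunk behaviour behind uneven sharding can be modeled in a few lines of plain Python. torch.chunk uses a chunk size of ceil(dim_size / num_chunks) and may return fewer chunks than requested. In the sketch below (a hypothetical helper), the missing trailing shards are counted as empty, which is how the last ranks on the mesh dimension end up with nothing.

```python
import math


def shard_sizes(dim_size, num_shards):
    """Per-rank sizes along one tensor dim under torch.chunk semantics.

    Chunk size is ceil(dim_size / num_shards); ranks past the last
    non-empty chunk receive an empty (size 0) shard.
    """
    chunk = math.ceil(dim_size / num_shards)
    sizes, remaining = [], dim_size
    for _ in range(num_shards):
        size = min(chunk, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


print(shard_sizes(8, 4))  # [2, 2, 2, 2]  evenly divisible
print(shard_sizes(5, 4))  # [2, 2, 1, 0]  last shard is empty
print(shard_sizes(6, 4))  # [2, 2, 2, 0]  torch.chunk yields only 3 chunks
```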
- class torch.distributed.tensor.placement_types.Replicate
The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (e.g. distribute_tensor, DTensor.from_local, etc.)
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')
The Partial(reduce_op) placement describes a DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a partial value of the global Tensor. Users can redistribute the Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute, which triggers the necessary communication operations under the hood (e.g. allreduce, reduce_scatter).
- Parameters
reduce_op (str, optional) – The reduction op used to turn the partial DTensor into a replicated or sharded DTensor. Only element-wise reduction operations are supported, including: "sum", "avg", "product", "max", "min". Default: "sum".
Note
The Partial placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local API.
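What "pending reduction" means can be sketched in plain Python. Each rank holds a partial value for every element, and resolving the Partial placement (for example through the all_reduce triggered by redistributing to Replicate) combines those values position by position with the chosen reduce_op. The helper below is a hypothetical model of that arithmetic, not DTensor code.

```python
import operator
from functools import reduce

# Element-wise reductions named by Partial(reduce_op), applied across ranks.
REDUCE_OPS = {
    "sum": sum,
    "avg": lambda vals: sum(vals) / len(vals),
    "product": lambda vals: reduce(operator.mul, vals, 1),
    "max": max,
    "min": min,
}


def resolve_partial(per_rank_values, reduce_op="sum"):
    """Combine each rank's partial values position by position."""
    op = REDUCE_OPS[reduce_op]
    return [op(list(column)) for column in zip(*per_rank_values)]


# Two ranks, each holding a partial value for the same two elements.
partials = [[1.0, 2.0], [3.0, 4.0]]
print(resolve_partial(partials))         # [4.0, 6.0]
print(resolve_partial(partials, "avg"))  # [2.0, 3.0]
print(resolve_partial(partials, "max"))  # [3.0, 4.0]
```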
- class torch.distributed.tensor.placement_types.Placement
The base class for the Placement type, which describes how a DTensor is placed onto the
DeviceMesh. Placement and DeviceMesh together describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.
This class is not meant to be used directly; it mainly serves as a typing stub.
Different ways to create a DTensor
There are three ways to construct a DTensor:
- distribute_tensor() creates a DTensor from a logical or "global" torch.Tensor on each rank. This can be used to shard the leaf torch.Tensors (i.e. model parameters/buffers and inputs).
- DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank, which can be used to create a DTensor from non-leaf torch.Tensors (i.e. intermediate activation tensors during forward/backward).
- DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) that create a DTensor by directly specifying the DeviceMesh and Placement. Compared to distribute_tensor(), this can directly materialize the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory.
Create DTensor from a logical torch.Tensor
The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes
(e.g. via torchrun) to execute the same program. This means that the model inside the program is first
initialized on each process (e.g. the model might be initialized on CPU, on the meta device, or directly
on GPU if there is enough memory).
DTensor offers a distribute_tensor() API that can shard the model weights or Tensors to DTensors,
creating a DTensor from the "logical" Tensor on each process. This ensures that the created
DTensors comply with single-device semantics, which is critical for numerical correctness.
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None)
Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or "global" tensor, and the API uses the tensor from the first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead.
- Parameters
tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use torch.chunk semantics to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change.
device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor. If not specified, this must be called under a DeviceMesh context manager. Default: None
placements (List[Placement], optional) – the placements that describe how to place the tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim. If not specified, the tensor is replicated across the device_mesh by default, using the first rank of each dimension of the device_mesh as the source.
- Returns
A DTensor or XLAShardedTensor object.
- Return type
DTensor or XLAShardedTensor
Note
When initializing the DeviceMesh with the xla device_type, distribute_tensor returns an XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier
sharding at the nn.Module level:
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)
This function exposes three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution, specify the partition_fn (i.e. allow the user to convert Module parameters to DTensor parameters according to the partition_fn specified).
2. To control the inputs or outputs of the module during runtime execution, specify the input_fn and output_fn (i.e. convert the input to DTensor, convert the output back to torch.Tensor).
- Parameters
module (nn.Module) – user module to be partitioned.
device_mesh (DeviceMesh) – the device mesh to place the module.
partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, by default we replicate all module parameters of module across the mesh.
input_fn (Callable) – specify the input distribution, i.e. control how the input of the module is sharded. input_fn will be installed as a module forward_pre_hook (pre forward hook).
output_fn (Callable) – specify the output distribution, i.e. control how the output is sharded, or convert it back to torch.Tensor. output_fn will be installed as a module forward_hook (post forward hook).
- Returns
A module that contains parameters/buffers that are all DTensors.
- Return type
nn.Module
Note
When initializing the DeviceMesh with the xla device_type, distribute_module returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.
DTensor Factory Functions
DTensor also provides dedicated tensor factory functions to allow creating a DTensor directly
using torch.Tensor-like factory function APIs (e.g. torch.ones, torch.empty, etc.), by additionally
specifying the DeviceMesh and Placement for the DTensor created:
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with the scalar value 0.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
- Keyword Arguments
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.
- Parameters
size – a sequence of integers defining the shape of the output DTensor.
fill_value – the value to fill the output DTensor with.
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
DTensor
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.
- Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
- Keyword Arguments
dtype (torch.dtype, optional) – the desired data type of the returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
layout (torch.layout, optional) – the desired layout of the returned DTensor. Default: torch.strided.
requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
device_mesh – DeviceMesh type, contains the mesh info of ranks.
placements – a sequence of Placement type: Shard, Replicate
- Returns
A DTensor object on each rank
- Return type
DTensor
Debugging
Logging
When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:
TORCH_LOGS=+dtensor will display logging.DEBUG messages and all levels above it.
TORCH_LOGS=dtensor will display logging.INFO messages and above.
TORCH_LOGS=-dtensor will display logging.WARNING messages and above.
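The three forms differ only in their prefix. The pure-Python sketch below restates that rule for a comma-separated TORCH_LOGS value. It is a hypothetical helper that models the three documented forms, not torch._logging's own parser.

```python
import logging

# Prefix -> minimum level shown for DTensor logs, per the three forms listed above.
DTENSOR_FORMS = {
    "+dtensor": logging.DEBUG,
    "dtensor": logging.INFO,
    "-dtensor": logging.WARNING,
}


def dtensor_log_level(torch_logs):
    """Return the level implied for DTensor, or None if dtensor is not mentioned."""
    for token in torch_logs.split(","):
        token = token.strip()
        if token in DTENSOR_FORMS:
            return DTENSOR_FORMS[token]
    return None


print(logging.getLevelName(dtensor_log_level("+dtensor")))  # DEBUG
```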
Debugging Tools
To debug a program that uses DTensor and understand more details about which collectives happened under the
hood, DTensor provides a CommDebugMode:
- class torch.distributed.tensor.debug.CommDebugMode
CommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.
Example usage
from torch.distributed.tensor.debug import CommDebugMode
mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)
Generates a detailed table displaying operations and collective tracing information at the module level. The amount of information depends on noise_level:
0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations, module information
2. prints operations not included in trivial operations
3. prints all operations
- generate_json_dump(file_name='comm_mode_log.json', noise_level=3)
Creates a JSON file used to build a browser visual. The amount of information depends on noise_level:
0. prints module-level collective counts
1. prints DTensor operations not included in trivial operations
2. prints operations not included in trivial operations
3. prints all operations
To visualize the sharding of a DTensor that has fewer than 3 dimensions, DTensor provides visualize_sharding():
Experimental Features
DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic
functionality is done but we are looking for user feedback. Please submit an issue to PyTorch if you have feedback on
these features.
- torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)
context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions: 1) it patches SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one, and 2) it shards buffers along the sequence dimension, with each rank keeping its corresponding shard according to mesh.
- Parameters
mesh (DeviceMesh) – the device mesh for the context parallelism.
buffers (Optional[List[torch.Tensor]]) – buffers whose usage depends on the sequence dimension. Examples are the input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure accuracy. The sharding happens in-place, so the buffer's shape changes within the context. The buffers are restored after the context finishes. no_restore_buffers can be used to specify which buffers don't need to be restored. Note that buffers should not contain any nn.Parameter.
buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.
no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in this set won't be restored after the context exits. This set must be a subset of buffers. If some buffers won't be used after the context exits, they can be put in this set to avoid extra restore time.
- Return type
Generator[None, None, None]
Warning
torch.distributed._tensor.experimental.attention.context_parallel is a prototype feature in PyTorch. The API is subject to change.
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)
local_map() is an experimental API that allows users to pass DTensors to a function that is written to be applied on torch.Tensors. It does so by extracting the local components of the DTensors, calling the function, and wrapping the outputs as DTensors according to out_placements.
- Parameters
func (Callable) – the function to be applied on each local shard of DTensors.
out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensors in func's flattened output. If the flattened output is a single value, out_placements should be of type PlacementType. Otherwise, if the flattened output has multiple values, out_placements should be a tuple of PlacementType values mapping 1:1 to the flattened output. For Tensor output, we use the PlacementType as its placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType should be None. The only exception is when no DTensor argument is passed in. In this case, even if out_placements is not None, the resulting function should ignore the desired placements because the function is not running with DTensors.
in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensors in the flattened inputs of func. If in_placements is specified, local_map() checks whether the placements of each DTensor argument are the same as the required placements. If the placements differ and redistribute_inputs is False, an exception is raised. Otherwise, if redistribute_inputs is True, the argument is first redistributed to the required sharding placements before its local tensor is passed to func. The only exception is when the required placements are not None and the argument is a torch.Tensor. In this case, the placements check is skipped and the argument is passed directly to func. If in_placements is None, no placements check is performed. Default: None
device_mesh (DeviceMesh, optional) – the device mesh that all the DTensors are placed on. If not specified, this is inferred from the input DTensors' device mesh. local_map requires every DTensor to be placed on the same device mesh. Default: None.
redistribute_inputs (bool, optional) – whether to reshard the input DTensors when their placements differ from the required input placements. If this value is False and some DTensor input has a different placement, an exception is raised. Default: False.
- Returns
A Callable that applies func to each local shard of the input DTensors and returns a DTensor constructed from the return value of func.
- Raises
AssertionError – If the input DTensors are not placed on the same device mesh, or if they are placed on a different device mesh than the device_mesh argument passed in.
AssertionError – For any non-DTensor output, we require its corresponding output placement in out_placements to be None. An AssertionError is raised if this is not the case.
ValueError – If redistribute_inputs=False but an input DTensor needs a redistribution according to in_placements.
Example
>>> import torch
>>> import torch.distributed._functional_collectives as funcol
>>> from torch.distributed.device_mesh import init_device_mesh
>>> from torch.distributed.tensor import Replicate, Shard, distribute_tensor
>>> from torch.distributed.tensor.experimental import local_map
>>>
>>> device_mesh = init_device_mesh("cuda", (4,))  # assumed: a 1-D mesh of 4 GPUs under torchrun
>>>
>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=(None, col_wise, row_wise),  # one entry per flattened input; device_mesh is not a DTensor
>>>     device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, (col_wise))  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, (row_wise))  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply local_mm_allreduce_forward to DTensors
Note
This API is currently experimental and subject to change
- torch.distributed.tensor.experimental.register_sharding(op)
register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensors. It can be useful when: (1) there is no default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor; (2) users would like to override the default sharding strategies of existing operators.
- Parameters
op (Union[OpOverload, List[OpOverload]]) – An op or a list of ops to register the customized sharding function.
- Returns
A function decorator which can be used to wrap a function that defines the sharding strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and their corresponding input placements.
Example
>>> import torch
>>> from torch.distributed.tensor import Replicate, Shard
>>> from torch.distributed.tensor.experimental import register_sharding
>>> aten = torch.ops.aten
>>>
>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>>
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>>
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)],
>>>                 [Shard(sharding_dim), None, None],
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>>
>>>     return acceptable_shardings
Note
This API is currently experimental and subject to change