Distributed Optimizers
class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source]

This class wraps an arbitrary optim.Optimizer and shards its states across ranks in the group as described by ZeRO. The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size of the parameters and hence only needs to keep 1 / world_size of the optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.

Parameters
    params (Iterable) – an Iterable of torch.Tensors giving all parameters, which will be sharded across ranks.

Keyword Arguments
    optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.
    process_group (ProcessGroup, optional) – the torch.distributed ProcessGroup (default: dist.group.WORLD, as initialized by torch.distributed.init_process_group()).
    parameters_as_bucket_view (bool, optional) – if True, parameters are packed into buckets to speed up communication, and param.data fields point to bucket views at different offsets; if False, each individual parameter is communicated separately, and each param.data stays intact (default: False).
    overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel's gradient synchronization; this requires (1) either a functional optimizer for the optimizer_class argument or one with a functional equivalent and (2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py. Parameters are packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs disjointly after the backward pass, as usual (default: False).
    **defaults – any trailing arguments, which are forwarded to the local optimizer.
Example:
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()
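The example above assumes that the default process group has already been initialized and that rank, world_size, and inputs exist in each spawned process. A minimal sketch of that surrounding setup is shown below; the backend choice, address, and port are illustrative assumptions rather than requirements:

>>> import os
>>> import torch.distributed as dist
>>> os.environ["MASTER_ADDR"] = "localhost"  # assumed single-node rendezvous
>>> os.environ["MASTER_PORT"] = "29500"      # assumed free port
>>> dist.init_process_group("nccl", rank=rank, world_size=world_size)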
Warning
Currently, ZeroRedundancyOptimizer requires that all of the passed-in parameters are the same dense type.

Warning
If you pass overlap_with_ddp=True, be wary of the following: given the way that overlapping DistributedDataParallel with ZeroRedundancyOptimizer is currently implemented, the first two or three training iterations do not perform parameter updates in the optimizer step, depending on whether static_graph=False or static_graph=True, respectively. This is because ZeroRedundancyOptimizer needs information about the gradient bucketing strategy used by DistributedDataParallel, which is not finalized until the second forward pass if static_graph=False or until the third forward pass if static_graph=True. To adjust for this, one option is to prepend dummy inputs.
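For example, one hedged way to absorb these warm-up iterations is to run the forward and backward passes on placeholder data before real training starts. In the sketch below, dummy_inputs and the iteration count are illustrative assumptions, and the explicit step() call is omitted on the assumption that the overlapped optimizer step is driven by the registered DDP communication hook:

>>> dummy_inputs = torch.zeros_like(inputs)  # hypothetical placeholder batch
>>> for _ in range(3):                       # three warm-up passes cover both static_graph settings
>>>     ddp(dummy_inputs).sum().backward()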
Warning

ZeroRedundancyOptimizer is experimental and subject to change.
add_param_group(param_group)[source]

Add a parameter group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters
    param_group (dict) – specifies the parameters to be optimized and group-specific optimization options.
Warning
This method handles updating the shards on all partitions but needs to be called on all ranks. Calling it on only a subset of the ranks will cause training to hang, because communication primitives are called depending on the managed parameters and expect all ranks to participate on the same set of parameters.
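For example, to make a previously frozen layer trainable mid-training, every rank must execute the same call; new_layer and the learning rate below are illustrative assumptions:

>>> for p in new_layer.parameters():
>>>     p.requires_grad_(True)
>>> opt.add_param_group({"params": list(new_layer.parameters()), "lr": 1e-4})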
consolidate_state_dict(to=0)[source]

Consolidate a list of state_dicts (one per rank) on the target rank.

Parameters
    to (int) – the rank that receives the optimizer states (default: 0).

Raises
    RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.
Warning
This needs to be called on all ranks.
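A sketch of the typical checkpointing pattern, assuming opt is the ZeroRedundancyOptimizer from the example above; the file name is an illustrative assumption:

>>> import torch.distributed as dist
>>> opt.consolidate_state_dict(to=0)  # collective call: every rank participates
>>> if dist.get_rank() == 0:
>>>     torch.save(opt.state_dict(), "zero_opt_ckpt.pt")  # hypothetical path
>>> dist.barrier()  # keep ranks in step before training resumes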
join_hook(**kwargs)[source]

Returns the ZeRO join hook, which enables training on uneven inputs by shadowing the collective communications in the optimizer step.

Gradients must be properly set before this hook is called.

Parameters
    kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

This hook does not support any keyword arguments; i.e. kwargs is unused.
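This hook is typically not invoked directly. Instead, both the DistributedDataParallel module and the ZeroRedundancyOptimizer are passed to the Join context manager, which collects and runs their join hooks. A minimal sketch, assuming ddp and opt from the example above and a loader that may yield an uneven number of batches per rank:

>>> from torch.distributed.algorithms.join import Join
>>> with Join([ddp, opt]):
>>>     for inputs in loader:
>>>         ddp(inputs).sum().backward()
>>>         opt.step()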
load_state_dict(state_dict)[source]

Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed.

Parameters
    state_dict (dict) – optimizer state; should be an object returned from a call to state_dict().

Raises
    RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.
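For example, to restore the state saved via consolidate_state_dict() and state_dict() on rank 0, each rank can load the full checkpoint and let load_state_dict() extract the portion pertaining to its own shard; the file name is an illustrative assumption matching the save-side sketch above:

>>> full_osd = torch.load("zero_opt_ckpt.pt", map_location="cpu")  # hypothetical path
>>> opt.load_state_dict(full_osd)  # each rank loads only the state pertaining to it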
state_dict()[source]

Returns the last global optimizer state known to this rank.

Raises
    RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt; or if this method is called without a preceding call to consolidate_state_dict().