使用Join上下文管理器进行输入不均匀的分布式训练¶
创建日期:2021年8月4日 | 最后更新日期:2023年1月9日 | 最后验证日期:2024年11月5日
作者: Andrew Gu
注意
查看和编辑此教程在 github。
注意
Join 是在 PyTorch 1.10 中作为原型功能引入的。此
API 可能会发生变化。
在这个教程中,你将看到:
一个关于 Join 上下文管理器的概述。
使用上下文管理器的一个示例,带有
DistributedDataParallel。使用上下文管理器的示例,同时包含
DistributedDataParallel和ZeroRedundancyOptimizer。传递关键字参数给上下文管理器的一个示例。
深入了解 Join 上下文管理器的工作原理。
一个示例,展示如何使一个玩具类与上下文管理器兼容。
需求¶
PyTorch 1.10+
什么是 Join?¶
在开始使用分布式数据并行 - 基本用例中,您看到了使用DistributedDataParallel进行数据并行训练的一般框架。这会在每次反向传播时隐式调度所有归约操作,以同步各进程间的梯度。此类集体通信需要进程中所有进程的参与,因此如果某个进程的输入较少,则其他进程将挂起或报错(取决于后端)。更广泛地说,对于任何执行每迭代同步集体通信的类,这个问题都会存在。
Join 是一个上下文管理器,用于在每个进程的训练循环周围,以支持输入不均匀的训练。该上下文管理器允许那些提前耗尽输入的进程(即提前 join)影子化尚未加入的进程执行的集体通信。通信如何被影子化由钩子指定。
使用 Join 与 DistributedDataParallel¶
PyTorch 的 DistributedDataParallel 可与 Join
上下文管理器无缝配合使用。以下是一个使用示例:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
with Join([model]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
这将产生以下输出(其中 print() 来自 rank 0 和 rank 1 的内容可能任意排序):
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
注意
DistributedDataParallel 提供了自己的 join() 上下文管理器
在引入此通用 Join 上下文管理器之前。在上面的例子中,使用 with Join([model]): 等效于使用
with model.join():。现有 DistributedDataParallel.join() 的一个限制是它不允许多个
参与的类,例如 DistributedDataParallel 和
ZeroRedundancyOptimizer 一起。
使用 Join 与 DistributedDataParallel 和 ZeroRedundancyOptimizer¶
The Join context manager works not only with a single class but also with
multiple classes together. PyTorch’s ZeroRedundancyOptimizer is also
compatible with the context manager, so here, we examine how to modify the
previous example to use both DistributedDataParallel and
ZeroRedundancyOptimizer:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
optim = ZeRO(model.parameters(), Adam, lr=0.01)
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
# Pass both `model` and `optim` into `Join()`
with Join([model, optim]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
optim.step()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
这将产生与之前相同的输出。值得注意的变化是
另外将 ZeroRedundancyOptimizer 实例传递到
Join()。
传递关键字参数¶
类可能提供关键字参数,这些参数会在上下文管理器运行时修改其行为。例如,DistributedDataParallel 提供了一个参数 divide_by_initial_world_size,它决定了梯度是除以初始世界大小还是有效世界大小(即非连接排名的数量)。此类关键字参数可以直接传递到上下文管理器中。
with Join([model, optim], divide_by_initial_world_size=False):
for input in inputs:
...
警告
传递给上下文管理器的关键字参数会在所有参与类之间共享。这不应该是一个限制,因为我们不期望出现多个 Joinable 需要相同参数的不同设置的情况。尽管如此,这一点仍需注意。
如何运作 Join?¶
现在我们已经看到了一些使用 Join
上下文管理器的初步示例,让我们更深入地了解它是如何工作的。这将为您提供更深入的理解,了解它所提供的全部功能,并为您创建自己的自定义类做好准备。在这里,我们将介绍 Join 类以及支持类 Joinable 和 JoinHook。
Joinable¶
首先,与 Join 上下文管理器兼容的类必须继承
自抽象基类 Joinable。特别是,Joinable 必须
实现:
join_hook(self, **kwargs) -> JoinHook
这将返回 JoinHook 实例用于 Joinable,确定如何
加入的进程应影子 Joinable 执行的每次迭代集体通信
join_device(self) -> torch.device
这将返回一个设备,供Join上下文管理器使用以执行
集体通信,例如torch.device("cuda:0")或
torch.device("cpu")。
join_process_group(self) -> ProcessGroup
这将返回由 Join 上下文管理器用于执行集体通信的过程组。
特别是,join_device和join_process_group是必需的属性,以确保上下文管理器可以在加入和未加入进程之间安排集体通信。一种用法是在每次迭代中使用all-reduce来计算每个进程上的未加入进程数量。另一种用法是实现throw_on_early_termination=True所需的机制,我们将在下面进一步解释。
DistributedDataParallel 和 ZeroRedundancyOptimizer 已经继承
自 Joinable 并实现了上述方法,这就是为什么我们可以在之前的示例中直接使用它们。
Joinable 类应该确保调用 Joinable 构造函数
因为它初始化了一个 JoinConfig 实例,该实例由
上下文管理器内部使用以确保正确性。这将被保存在每个
Joinable 中作为一个字段 _join_config。
JoinHook¶
接下来,让我们分解 JoinHook 类。一个 JoinHook 提供了两个
进入上下文管理器的入口点:
main_hook(self) -> None
此钩子会在每个加入的秩存在未加入的秩时被重复调用。它的目的是模拟 Joinable 在每个训练迭代中执行的集体通信(例如在一个前向传递、反向传递和优化器步骤中)。
post_hook(self, is_last_joiner: bool) -> None
此钩子在所有进程都已加入后调用。它会传递一个额外的
bool 参数 is_last_joiner,该参数表示该进程是否是最后一个加入的。此参数可能对同步有帮助。
为了给出这些钩子可能是什么样子的具体示例,提供的
ZeroRedundancyOptimizer 主钩子在每次正常训练步骤中执行一次优化器步骤,
因为加入的秩仍然负责更新和同步其参数的分片,而提供的 DistributedDataParallel 后钩子
将最终更新的模型从最后一个加入的秩之一广播出去,以确保所有秩之间的一致性。
Join¶
最后,让我们看看这些如何融入 Join 类本身。
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
正如我们在前面的例子中看到的,构造函数接收一个参与训练循环的
Joinable 的列表。这些应该是每个迭代中执行集体通信的
类。
enable 是一个 bool,如果你知道不会有奇数输入,可以将其设置为 False,此时上下文管理器变得空洞,类似于 contextlib.nullcontext()。这也会在参与的 Joinable 中禁用与 join 相关的计算。
throw_on_early_termination 是一个 bool,可以设置为 True,以便在检测到不均匀输入时,每个等级立即引发异常。
这对于不符合上下文管理器要求的情况很有用,这在大多数情况下是当有来自不同类别的集体通信时发生,例如在使用 DistributedDataParallel 与具有 SyncBatchNorm 层的模型时。在这种情况下,应将此参数设置为 True,以便应用程序逻辑可以捕获异常并确定如何继续。
核心逻辑发生在
__exit__()方法中,该方法在存在未加入的rank时循环,调用每个Joinable的主钩子,然后在所有rank都加入后调用它们的后钩子。主钩子和后钩子按照Joinable的传入顺序进行迭代。上下文管理器需要未加入进程的心跳信号。因此, 每个
Joinable类应在每次迭代的集体通信之前调用Join.notify_join_context()。上下文管理器将确保只有第一个Joinable实际发送心跳信号。
警告
如上文提到的 throw_on_early_termination,Join 上下文管理器与某些类组合不兼容。Joinable 的 JoinHook 必须可序列化,因为每个钩子在执行完后才会继续下一个。换句话说,两个钩子不能重叠。此外,目前主钩子和后钩子都是以确定性的顺序进行迭代的。如果这似乎是一个主要限制,我们可以修改API以允许自定义顺序。
让一个玩具类与 Join¶ 一起工作
由于上一节介绍了几个概念,让我们通过一个玩具示例来实际看看。在这里,我们将实现一个类,该类统计在它的秩加入之前所有秩所看到的输入数量。这应该为您提供如何使您自己的类与 Join 上下文管理器兼容的基本思路。
具体来说,以下代码中,每个进程会打印出(1)它加入之前所有进程中看到的输入数量,以及(2)所有进程中的总输入数量。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
r"""
Join hook for :class:`Counter`.
Arguments:
counter (Counter): the :class:`Counter` object using this hook.
sync_max_count (bool): whether to sync the max count once all ranks
join.
"""
def __init__(
self,
counter,
sync_max_count
):
self.counter = counter
self.sync_max_count = sync_max_count
def main_hook(self):
r"""
Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
"""
t = torch.zeros(1, device=self.counter.device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Synchronizes the max count across all :class:`Counter` s if
``sync_max_count=True``.
"""
if not self.sync_max_count:
return
rank = dist.get_rank(self.counter.process_group)
common_rank = self.counter.find_common_rank(rank, is_last_joiner)
if rank == common_rank:
self.counter.max_count = self.counter.count.detach().clone()
dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
r"""
Example :class:`Joinable` that counts the number of training iterations
that it participates in.
"""
def __init__(self, device, process_group):
super(Counter, self).__init__()
self.device = device
self.process_group = process_group
self.count = torch.tensor([0], device=device).float()
self.max_count = torch.tensor([0], device=device).float()
def __call__(self):
r"""
Counts the number of inputs processed on this iteration by all ranks
by all-reducing a dim-1 one tensor; increments its own internal count.
"""
Join.notify_join_context(self)
t = torch.ones(1, device=self.device).float()
dist.all_reduce(t)
self.count += t
def join_hook(self, **kwargs) -> JoinHook:
r"""
Return a join hook that shadows the all-reduce in :meth:`__call__`.
This join hook supports the following keyword arguments:
sync_max_count (bool, optional): whether to synchronize the maximum
count across all ranks once all ranks join; default is ``False``.
"""
sync_max_count = kwargs.get("sync_max_count", False)
return CounterJoinHook(self, sync_max_count)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self):
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
Returns the max rank of the ones to consider over the process group.
"""
common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
return common_rank
def worker(rank):
assert torch.cuda.device_count() >= WORLD_SIZE
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
with Join([counter], sync_max_count=True):
for _ in inputs:
counter()
print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
由于 rank 0 看到 5 个输入,而 rank 1 看到 6 个,因此输出结果为:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
一些需要强调的关键点:
一个
Counter实例每次迭代执行一次全部归约操作,因此 主钩子也执行一次全部归约操作以与其保持同步。这
Counter类在其__call__()方法的开头调用Join.notify_join_context(), 因为这是在其每次迭代的集体通信(即 all-reduce)之前的一个位置。The
is_last_joinerargument is used to determine the broadcast source in the post-hooks.我们传递
sync_max_count关键字参数给上下文管理器, 然后传递给Counter的 join 钩子。