目录

使用Join上下文管理器进行输入不均匀的分布式训练

创建日期:2021年8月4日 | 最后更新日期:2023年1月9日 | 最后验证日期:2024年11月5日

作者: Andrew Gu

注意

edit 查看和编辑此教程在 github

注意

Join 是在 PyTorch 1.10 中作为原型功能引入的。此 API 可能会发生变化。

在这个教程中,你将看到:

  • 一个关于 Join 上下文管理器的概述。

  • 使用上下文管理器的一个示例,带有 DistributedDataParallel

  • 使用上下文管理器的示例,同时包含 DistributedDataParallelZeroRedundancyOptimizer

  • 传递关键字参数给上下文管理器的一个示例。

  • 深入了解 Join 上下文管理器的工作原理。

  • 一个示例,展示如何使一个玩具类与上下文管理器兼容。

什么是 Join?

开始使用分布式数据并行 - 基本用例中,您看到了使用DistributedDataParallel进行数据并行训练的一般框架。这会在每次反向传播时隐式调度所有归约操作,以同步各进程间的梯度。此类集体通信需要进程中所有进程的参与,因此如果某个进程的输入较少,则其他进程将挂起或报错(取决于后端)。更广泛地说,对于任何执行每迭代同步集体通信的类,这个问题都会存在。

Join 是一个上下文管理器,用于在每个进程的训练循环周围,以支持输入不均匀的训练。该上下文管理器允许那些提前耗尽输入的进程(即提前 join)影子化尚未加入的进程执行的集体通信。通信如何被影子化由钩子指定。

使用 JoinDistributedDataParallel

PyTorch 的 DistributedDataParallel 可与 Join 上下文管理器无缝配合使用。以下是一个使用示例:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(其中 print() 来自 rank 0 和 rank 1 的内容可能任意排序):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

DistributedDataParallel 提供了自己的 join() 上下文管理器 在引入此通用 Join 上下文管理器之前。在上面的例子中,使用 with Join([model]): 等效于使用 with model.join():。现有 DistributedDataParallel.join() 的一个限制是它不允许多个 参与的类,例如 DistributedDataParallelZeroRedundancyOptimizer 一起。

使用 JoinDistributedDataParallelZeroRedundancyOptimizer

The Join context manager works not only with a single class but also with multiple classes together. PyTorch’s ZeroRedundancyOptimizer is also compatible with the context manager, so here, we examine how to modify the previous example to use both DistributedDataParallel and ZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与之前相同的输出。值得注意的变化是 另外将 ZeroRedundancyOptimizer 实例传递到 Join()

传递关键字参数

类可能提供关键字参数,这些参数会在上下文管理器运行时修改其行为。例如,DistributedDataParallel 提供了一个参数 divide_by_initial_world_size,它决定了梯度是除以初始世界大小还是有效世界大小(即非连接排名的数量)。此类关键字参数可以直接传递到上下文管理器中。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递给上下文管理器的关键字参数会在所有参与类之间共享。这不应该是一个限制,因为我们不期望出现多个 Joinable 需要相同参数的不同设置的情况。尽管如此,这一点仍需注意。

如何运作 Join

现在我们已经看到了一些使用 Join 上下文管理器的初步示例,让我们更深入地了解它是如何工作的。这将为您提供更深入的理解,了解它所提供的全部功能,并为您创建自己的自定义类做好准备。在这里,我们将介绍 Join 类以及支持类 JoinableJoinHook

Joinable

首先,与 Join 上下文管理器兼容的类必须继承 自抽象基类 Joinable。特别是,Joinable 必须 实现:

  • join_hook(self, **kwargs) -> JoinHook

这将返回 JoinHook 实例用于 Joinable,确定如何 加入的进程应影子 Joinable 执行的每次迭代集体通信

  • join_device(self) -> torch.device

这将返回一个设备,供Join上下文管理器使用以执行 集体通信,例如torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这将返回由 Join 上下文管理器用于执行集体通信的过程组。

特别是,join_devicejoin_process_group是必需的属性,以确保上下文管理器可以在加入和未加入进程之间安排集体通信。一种用法是在每次迭代中使用all-reduce来计算每个进程上的未加入进程数量。另一种用法是实现throw_on_early_termination=True所需的机制,我们将在下面进一步解释。

DistributedDataParallelZeroRedundancyOptimizer 已经继承 自 Joinable 并实现了上述方法,这就是为什么我们可以在之前的示例中直接使用它们。

Joinable 类应该确保调用 Joinable 构造函数 因为它初始化了一个 JoinConfig 实例,该实例由 上下文管理器内部使用以确保正确性。这将被保存在每个 Joinable 中作为一个字段 _join_config

JoinHook

接下来,让我们分解 JoinHook 类。一个 JoinHook 提供了两个 进入上下文管理器的入口点:

  • main_hook(self) -> None

此钩子会在每个加入的秩存在未加入的秩时被重复调用。它的目的是模拟 Joinable 在每个训练迭代中执行的集体通信(例如在一个前向传递、反向传递和优化器步骤中)。

  • post_hook(self, is_last_joiner: bool) -> None

此钩子在所有进程都已加入后调用。它会传递一个额外的 bool 参数 is_last_joiner,该参数表示该进程是否是最后一个加入的。此参数可能对同步有帮助。

为了给出这些钩子可能是什么样子的具体示例,提供的 ZeroRedundancyOptimizer 主钩子在每次正常训练步骤中执行一次优化器步骤, 因为加入的秩仍然负责更新和同步其参数的分片,而提供的 DistributedDataParallel 后钩子 将最终更新的模型从最后一个加入的秩之一广播出去,以确保所有秩之间的一致性。

Join

最后,让我们看看这些如何融入 Join 类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的例子中看到的,构造函数接收一个参与训练循环的 Joinable 的列表。这些应该是每个迭代中执行集体通信的 类。

enable 是一个 bool,如果你知道不会有奇数输入,可以将其设置为 False,此时上下文管理器变得空洞,类似于 contextlib.nullcontext()。这也会在参与的 Joinable 中禁用与 join 相关的计算。

throw_on_early_termination 是一个 bool,可以设置为 True,以便在检测到不均匀输入时,每个等级立即引发异常。 这对于不符合上下文管理器要求的情况很有用,这在大多数情况下是当有来自不同类别的集体通信时发生,例如在使用 DistributedDataParallel 与具有 SyncBatchNorm 层的模型时。在这种情况下,应将此参数设置为 True,以便应用程序逻辑可以捕获异常并确定如何继续。

  • 核心逻辑发生在 __exit__() 方法中,该方法在存在未加入的rank时循环,调用每个 Joinable 的主钩子,然后在所有rank都加入后调用它们的后钩子。主钩子和后钩子按照 Joinable 的传入顺序进行迭代。

  • 上下文管理器需要未加入进程的心跳信号。因此, 每个 Joinable 类应在每次迭代的集体通信之前调用 Join.notify_join_context() 。上下文管理器将确保只有第一个 Joinable 实际发送心跳信号。

警告

如上文提到的 throw_on_early_terminationJoin 上下文管理器与某些类组合不兼容。JoinableJoinHook 必须可序列化,因为每个钩子在执行完后才会继续下一个。换句话说,两个钩子不能重叠。此外,目前主钩子和后钩子都是以确定性的顺序进行迭代的。如果这似乎是一个主要限制,我们可以修改API以允许自定义顺序。

让一个玩具类与 Join 一起工作

由于上一节介绍了几个概念,让我们通过一个玩具示例来实际看看。在这里,我们将实现一个类,该类统计在它的秩加入之前所有秩所看到的输入数量。这应该为您提供如何使您自己的类与 Join 上下文管理器兼容的基本思路。

具体来说,以下代码中,每个进程会打印出(1)它加入之前所有进程中看到的输入数量,以及(2)所有进程中的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于 rank 0 看到 5 个输入,而 rank 1 看到 6 个,因此输出结果为:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些需要强调的关键点:

  • 一个 Counter 实例每次迭代执行一次全部归约操作,因此 主钩子也执行一次全部归约操作以与其保持同步。

  • Counter 类在其 __call__() 方法的开头调用 Join.notify_join_context(), 因为这是在其每次迭代的集体通信(即 all-reduce)之前的一个位置。

  • The is_last_joiner argument is used to determine the broadcast source in the post-hooks.

  • 我们传递 sync_max_count 关键字参数给上下文管理器, 然后传递给 Counter 的 join 钩子。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源