目录

使用完全分片数据并行(FSDP)进行高级模型训练

创建时间:2024年10月31日 | 最后更新时间:2024年10月31日 | 最后验证时间:2024年11月5日

作者: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

你将学到什么
  • PyTorch 的完全分片数据并行模块:一个用于在多个设备上分片模块参数的包装器

数据并行工作进程。

先决条件
  • PyTorch 1.12 或更高版本

  • 了解 FSDP API

本教程介绍了作为 PyTorch 1.12 版本一部分的完全分片数据并行 (FSDP) 的更多高级功能。要熟悉 FSDP,请参阅 FSDP 入门教程

在这个教程中,我们以文本摘要为例,使用 FSDP 对 HuggingFace(HF)的 T5 模型进行微调。

这个示例使用了Wikihow,并且为了简单起见,我们将在单个节点的P4dn实例上进行训练,该实例配备了8个A100 GPU。我们现在有几篇博客文章( (链接1), (链接2)) 和一篇关于在多节点集群上进行大规模FSDP训练的论文

FSDP 是一个可投入生产的包,注重易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 的内存占用。这使得与 DDP 相比,可以使用更低的总内存训练更大的模型,并利用计算与通信的重叠来高效地训练模型。 这种降低的内存压力可以用于训练更大的模型或增加批量大小,从而潜在地提高整体训练吞吐量。您可以在此阅读更多关于 PyTorch FSDP here

本教程中的FSDP特性

  • Transformer 自动包装策略

  • 混合精度

  • 在设备上初始化 FSDP 模型

  • 分片策略

  • 反向预取

  • 通过流式传输到CPU进行模型检查点保存

FSDP运作回顾

在高层次上,FDSP 的工作方式如下:

在构造函数中

  • 分片模型参数,每个进程仅保留自己的分片

在前向传播过程中

  • 运行 all_gather 以从所有rank收集所有碎片,以恢复此FSDP单元的完整参数并运行前向计算

  • 丢弃刚刚收集的非自有参数分片以释放内存

反向传播过程中

  • 运行 all_gather 以从所有rank收集所有碎片,以恢复此FSDP单元中的完整参数并运行反向计算

  • 丢弃未拥有的参数以释放内存。

  • 运行 reduce_scatter 来同步梯度

微调HF T5

HF T5 预训练模型有四种不同的规模,参数量从 6000 万到 110 亿不等。在本教程中,我们将演示如何使用 WikiHow 数据集对一个具有 30 亿参数的 T5 模型进行微调,并使用 FSDP 进行文本摘要。本教程的主要目的是突出 FSDP 中可用于训练超过 30 亿参数的大规模模型的不同功能。此外,我们还将介绍针对 Transformer 模型的具体功能。本教程的代码可以在 Pytorch 示例 中找到。

设置

1.1 安装最新版 PyTorch

pip3 install torch torchvision torchaudio

1.2 数据集设置

请创建一个 data 文件夹,从 wikihowAll.csv 下载 WikiHow 数据集 和 wikihowSep.cs, 并将它们放入 data 文件夹。我们将使用来自 summarization_dataset 的 wikihow 数据集。

接下来,我们将以下代码片段添加到 Python 脚本 “T5_training.py” 中。

注意

本教程的完整源代码可以在 PyTorch 示例 中找到。

1.3 导入必要的包:

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分布式训练设置。 在这里,我们使用两个辅助函数来初始化分布式训练的进程,然后在训练完成后进行清理。在本教程中,我们将使用 torchrun ,它将自动设置工作进程 RANKWORLD_SIZE

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 设置 HuggingFace T5 模型:

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

我们还在这里添加几个辅助函数,用于日期处理和格式化内存指标。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定义训练函数:

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定义验证函数:

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定义一个分布式训练函数,该函数将模型包装在 FSDP 中:

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

2.5 解析参数并设置主函数:

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

使用 torchrun 运行训练:

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 包装策略

如前一篇教程中所述, auto_wrap_policy 是 FSDP 功能之一,它使得自动分片给定模型并将模型、优化器和梯度分片放入 不同的 FSDP 单元变得容易。

对于一些架构,例如 Transformer 编码器-解码器,在这种情况下,模型的一些部分(如嵌入表)会被同时共享给编码器和解码器。在这种情况下,我们需要将嵌入表放置在外部的 FSDP 单元中,以便可以从编码器和解码器访问。此外,通过为 Transformer 注册层类,可以使分片计划更加通信高效。在 PyTorch 1.12 中,FSDP 增加了对这一功能的支持,现在我们有了针对 Transformer 的包装策略。

它可以按如下方式创建,其中 T5Block 表示 T5 transformer 层类(包含 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包装后的模型,你可以直接打印模型,并直观地检查分片和 FSDP 单元。

混合精度

FSDP 支持灵活的混合精度训练,允许使用任意降低精度类型(如 fp16 或 bfloat16)。目前 BFloat16 仅在 Ampere GPU 上可用,因此在使用之前需要确认是否原生支持。例如,在 V100 上,BFloat16 仍然可以运行,但由于不是原生支持,可能会导致显著的性能下降。

要检查 BFloat16 是否原生支持,您可以使用以下方法:

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

FSDP 中混合精度的一个优势是能够对参数、梯度和缓冲区的不同精度级别提供精细的控制,如下所示:

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

请注意,如果未指定某种类型(参数、缩减操作、缓冲区),它们将完全不会被转换。

这种灵活性使用户能够进行精细的控制,例如仅在降低精度的情况下进行梯度通信,而所有参数/缓冲区的计算则在全精度下进行。这在节点内通信是主要瓶颈的情况下可能非常有用,此时参数/缓冲区必须使用全精度以避免准确性问题。可以通过以下策略实现:

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在 2.4 中,我们只需将相关的混合精度策略添加到 FSDP 包装器中:

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我们的实验中,我们观察到使用 BFloat16 进行训练可以实现最多 4 倍的速度提升,并且在一些实验中内存减少约 30%,这可用于增加批量大小。

初始化设备上的FSDP模型

在1.12版本中,FSDP支持一个 device_id 参数,用于在由 device_id 指定的设备上初始化输入CPU模块。这在整个模型无法放入单个GPU但可以放入主机CPU内存时非常有用。当指定 device_id 时,FSDP将根据每个FSDP单元将模型移动到指定设备,从而避免GPU内存不足的问题,同时比基于CPU的初始化快数倍:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略

FSDP 默认的分片策略是完全分片模型参数、梯度和优化器状态,这些内容会在所有进程中进行分片。(也称为 Zero3 分片)。如果你对使用 Zero2 分片策略感兴趣,其中仅对优化器状态和梯度进行分片,FSDP 支持此功能,可以通过传递分片策略来实现,使用 "ShardingStrategy.SHARD_GRAD_OP",而不是 "ShardingStrategy.FULL_SHARD" 作为 FSDP 初始化参数,如下所示:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

这将减少 FSDP 中的通信开销,在这种情况下,它在前向传播和反向传播过程中保持完整的参数。

这在反向传播过程中节省了一次 all_gather 操作,从而减少了通信开销,但以更高的内存占用为代价。请注意,在反向传播结束时会释放完整的模型参数,并且 all_gather 操作将在下一次前向传播时发生。

反向预取

反向预取设置控制下一个FSDP单元的参数请求时机。通过将其设置为BACKWARD_PRE,下一个FSDP单元的参数可以在当前单元计算开始之前更早地被请求并到达。这会重叠all_gather的通信和梯度计算,以略微增加的内存消耗为代价提高训练速度。它可以在2.4版本的FSDP包装器中使用如下:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 有两种模式,BACKWARD_PREBACKWARD_POSTBACKWARD_POST 表示下一个 FSDP 单元的参数将在当前 FSDP 单元处理完成之后才会被请求,从而最小化内存开销。在某些情况下,使用 BACKWARD_PRE 可以提高模型训练速度高达 2-10%,对于更大的模型甚至有更高的加速效果。

模型检查点保存,通过流式传输到Rank0 CPU

使用FULL_STATE_DICT保存模型检查点,该方式以与本地模型相同的方式保存模型,PyTorch 1.12 提供了一些工具来支持更大模型的保存。

首先,可以指定一个 FullStateDictConfig,允许仅在 rank 0 上填充 state_dict,并将其卸载到 CPU。

当使用此配置时,FSDP 会收集所有模型参数,并仅在 rank 0 上逐个将它们卸载到 CPU。当最终保存 state_dict 时,它只会被 rank 0 填充,并且包含 CPU 张量。这可以避免对于超过单个 GPU 显存的模型可能出现的内存不足问题,并允许用户检查点大小大致等于其机器上可用 CPU 内存的模型。

此功能可以按以下方式运行:

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

摘要

在这个教程中,我们介绍了 Pytorch 1.12 中可用的许多 FSDP 新特性,并以 HF T5 作为示例。使用适当的包装策略,特别是针对 Transformer 模型,结合混合精度和反向预取应该可以加快您的训练过程。此外,诸如在设备上初始化模型以及通过流式传输到 CPU 进行检查点保存等功能,有助于避免处理大型模型时出现内存溢出错误。

我们正在积极努力为下一个版本的FSDP添加新功能。如果您有任何反馈、功能请求、问题或在使用FSDP时遇到任何问题,请随时通过在 PyTorch Github 仓库中创建一个问题来联系我们。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源