目录

TorchRec简介

创建时间:2024年10月02日 | 最后更新时间:2024年10月10日 | 最后验证时间:2024年10月02日

TorchRec 是一个专为使用嵌入构建可扩展且高效的推荐系统而设计的 PyTorch 库。 本教程将引导您完成安装过程,介绍嵌入的概念,并强调其在推荐系统中的重要性。 它提供了使用 PyTorch 和 TorchRec 实现嵌入的实践演示,重点在于通过分布式训练和高级优化处理大型嵌入表。

你将学到什么
  • 嵌入的基本概念及其在推荐系统中的作用

  • 如何设置 TorchRec 以在 PyTorch 环境中管理并实现嵌入

  • 探索在多个 GPU 上分布大型嵌入表的高级技术

先决条件
  • PyTorch v2.5 或更高版本,配合 CUDA 11.8 或更高版本

  • Python 3.9 或更高版本

  • FBGEMM

安装依赖项

在 Google Colab 或其他环境中运行此教程之前,请安装以下依赖项:

!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

注意

如果你在 Google Colab 中运行此代码,请确保切换到 GPU 运行时类型。 更多信息,请参阅 启用 CUDA

嵌入

在构建推荐系统时,分类特征通常具有巨大的基数,如帖子、用户、广告等。

为了表示这些实体并建模这些关系, 嵌入 被使用。在机器学习中,嵌入是用于表示复杂数据(如单词、图像或用户)含义的高维空间中的实数向量

嵌入在推荐系统中

现在你可能会想,这些嵌入是如何生成的呢?好吧,嵌入被表示为嵌入表中的单独行,也称为嵌入权重。原因在于,嵌入或嵌入表权重是通过梯度下降像模型的其他权重一样进行训练的!

Embedding tables 是一个用于存储嵌入的大型矩阵,具有两个维度 (B, N),其中:

  • B 是该表存储的嵌入向量数量

  • N 是每个嵌入的维度数(N 维嵌入)。

嵌入表的输入表示嵌入查找,用于检索特定索引或行的嵌入。在推荐系统中,如许多大型系统所使用的那样,唯一ID不仅用于特定用户,还用于帖子和广告等实体,作为查找相应嵌入表的索引!

嵌入向量是通过以下过程在推荐系统中进行训练的:

  • 输入/查找索引被输入到模型中,作为唯一的ID。ID会被哈希到嵌入表的总大小,以防止当ID大于行数时出现问题

  • 嵌入向量随后被检索并进行池化,例如取嵌入向量的总和或平均值。这是必需的,因为每个示例可能包含不同数量的嵌入向量,而模型期望一致的形状。

  • 这些嵌入与模型的其余部分结合使用以生成预测,例如广告的点击率 (CTR)

  • 损失是通过预测和标签计算得出的 例如,并且 模型的所有权重都通过梯度下降和反向传播进行更新,包括与该示例相关联的嵌入权重

这些嵌入对于表示分类特征(如用户、帖子和广告)以捕捉关系并做出良好的推荐至关重要。 深度学习推荐模型 (DLRM) 论文详细讨论了在推荐系统中使用嵌入表的技术细节。

本教程介绍了嵌入的概念,展示 TorchRec 特定的模块和数据类型,并说明 TorchRec 中分布式训练的工作方式。

import torch

PyTorch中的嵌入

在 PyTorch 中,我们有以下类型的嵌入:

  • torch.nn.Embedding: 一个嵌入表,前向传递返回嵌入本身。

  • torch.nn.EmbeddingBag: 嵌入表,其中前向传递返回嵌入,然后进行池化,例如求和或取平均值,也称为 池化嵌入

在本节中,我们将简要介绍如何通过将索引传递到表中来执行嵌入查找。

num_embeddings, embedding_dim = 10, 4

# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)

# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
    num_embeddings, embedding_dim, _weight=weights
)

# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)

embeddings = embedding_collection(ids)

# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)

# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)

print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936],
        [0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294],
        [0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516, 0.0753],
        [0.8860, 0.5832, 0.3376, 0.8090],
        [0.5779, 0.9040, 0.5547, 0.3423]])
Embedding Collection Table:  Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936],
        [0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294],
        [0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516, 0.0753],
        [0.8860, 0.5832, 0.3376, 0.8090],
        [0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Embedding Bag Collection Table:  Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936],
        [0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294],
        [0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516, 0.0753],
        [0.8860, 0.5832, 0.3376, 0.8090],
        [0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Input row IDS:  tensor([[1, 3]])
Embedding Collection Results:
tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
         [0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
Shape:  torch.Size([1, 2, 4])
Embedding Bag Collection Results:
tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
Shape:  torch.Size([1, 4])
Mean:  tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)

恭喜!现在你已经基本了解了如何使用嵌入表 —— 现代推荐系统的基础之一!这些表表示实体及其关系。例如,某个用户与他们喜欢的页面和帖子之间的关系。

TorchRec 特性概览

在上面的部分中,我们已经学习了如何使用嵌入表,这是现代推荐系统的基础之一!这些表表示实体和关系,例如用户、页面、帖子等。鉴于这些实体总是不断增加,通常会应用一个哈希函数,以确保ID在某个嵌入表的范围内。然而,为了表示大量实体并减少哈希冲突,这些表可能会变得非常庞大(例如考虑广告的数量)。事实上,这些表可能会变得如此庞大,即使有80G的内存,也无法装入1块GPU中。

为了训练具有大规模嵌入表的模型,需要将这些表分片到多个GPU上,这随后在并行性和优化方面引入了一整套新的问题和机遇。幸运的是,我们有TorchRec库,它已经遇到了、整合了并解决了许多这些问题。TorchRec作为一个提供大规模分布式嵌入原语的库

接下来,我们将探索TorchRec库的主要功能。我们将从torch.nn.Embedding开始,并将其扩展到自定义TorchRec模块,探讨使用生成嵌入分片计划的分布式训练环境,查看内在的TorchRec优化,并将模型扩展以准备在C++中进行推理。以下是本节内容的快速概览:

  • TorchRec 模块和数据类型

  • 分布式训练、分片和优化

  • 推理

让我们从导入 TorchRec 开始:

import torchrec

此部分介绍了TorchRec模块和数据类型,包括诸如EmbeddingCollectionEmbeddingBagCollectionJaggedTensorKeyedJaggedTensorKeyedTensor等实体。

EmbeddingBagEmbeddingBagCollection

我们已经探讨了 torch.nn.Embeddingtorch.nn.EmbeddingBag。 TorchRec 通过创建嵌入集合来扩展这些模块,换句话说,就是可以拥有多个嵌入表的模块, EmbeddingCollectionEmbeddingBagCollection 我们将使用 EmbeddingBagCollection 来表示一组嵌入袋。

在下面的示例代码中,我们创建了一个 EmbeddingBagCollection (EBC) 包含两个嵌入包,1表示 产品,1表示 用户。 每个表格,product_tableuser_table,由一个 64 维 嵌入组成,大小为 4096。

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        )
    ]
)
print(ebc.embedding_bags)
ModuleDict(
  (product_table): EmbeddingBag(4096, 64, mode='sum')
  (user_table): EmbeddingBag(4096, 64, mode='sum')
)

让我们检查 EmbeddingBagCollection 的 forward 方法以及模块的输入和输出:

import inspect

# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
    """
    Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
    and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.

    Args:
        features (KeyedJaggedTensor): Input KJT
    Returns:
        KeyedTensor
    """
    flat_feature_names: List[str] = []
    for names in self._feature_names:
        flat_feature_names.extend(names)
    inverse_indices = reorder_inverse_indices(
        inverse_indices=features.inverse_indices_or_none(),
        feature_names=flat_feature_names,
    )
    pooled_embeddings: List[torch.Tensor] = []
    feature_dict = features.to_dict()
    for i, embedding_bag in enumerate(self.embedding_bags.values()):
        for feature_name in self._feature_names[i]:
            f = feature_dict[feature_name]
            res = embedding_bag(
                input=f.values(),
                offsets=f.offsets(),
                per_sample_weights=f.weights() if self._is_weighted else None,
            ).float()
            pooled_embeddings.append(res)
    return KeyedTensor(
        keys=self._embedding_names,
        values=process_pooled_embeddings(
            pooled_embeddings=pooled_embeddings,
            inverse_indices=inverse_indices,
        ),
        length_per_key=self._lengths_per_embedding,
    )

TorchRec 输入/输出数据类型

TorchRec 为模块的输入和输出定义了不同的数据类型: JaggedTensor, KeyedJaggedTensor, 和 KeyedTensor。现在你 可能会问,为什么创建新的数据类型来表示稀疏特征?要 回答这个问题,我们必须了解稀疏特征在代码中的表示方式。

稀疏特征也被称为 id_list_featureid_score_list_feature,它们是将用作嵌入表索引的 ID,以检索该 ID 的嵌入。为了给出一个非常简单的例子,想象一个单一的稀疏特征是用户互动过的广告。输入本身是一组用户互动过的广告 ID,检索到的嵌入是对这些广告的语义表示。在代码中表示这些特征的难点在于,每个输入示例中 ID 的数量是可变的。有一天用户可能只与一个广告互动,而第二天他们可能与三个广告互动。

如下所示是一个简单的表示,其中我们有一个 lengths 张量,表示一个批次中一个示例包含多少个索引,以及一个 values 张量,包含这些索引本身。

# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])

# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])

接下来,我们再来看偏移量以及每个批次中包含的内容

# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
    "Second Batch: ",
    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)

from torchrec import JaggedTensor

# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())

# Convert to list of values
print("List of Values: ", jt.to_dense())

# ``__str__`` representation
print(jt)

from torchrec import KeyedJaggedTensor

# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())

# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())

# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())

# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())

# ``KeyedJaggedTensor`` string representation
print(kjt)

# Q2: What are the offsets for the ``KeyedJaggedTensor``?

# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result

# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())

# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)

# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)
Offsets:  tensor([1, 3])
First Batch:  tensor([5])
Second Batch:  tensor([7, 1])
Offsets:  tensor([0, 1, 3])
List of Values:  [tensor([5]), tensor([7, 1])]
JaggedTensor({
    [[5], [7, 1]]
})

Keys:  ['product', 'user']
Lengths:  tensor([3, 1, 2, 2])
Values:  tensor([1, 2, 1, 5, 2, 3, 4, 1])
to_dict:  {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f4edb57c940>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f4edb57c9a0>}
KeyedJaggedTensor({
    "product": [[1, 2, 1], [5]],
    "user": [[2, 3], [4, 1]]
})

['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])

恭喜!你现在已经了解了TorchRec模块和数据类型。 给自己一个拥抱,因为你已经走到了这一步。接下来,我们将 学习分布式训练和分片的相关内容。

分布式训练和分片

现在我们已经了解了TorchRec模块和数据类型,是时候将其提升到一个新的水平了。

请记住,TorchRec 的主要目的是提供分布式嵌入的原始组件。到目前为止,我们只在单个设备上处理嵌入表。这在嵌入表非常小的情况下是可行的,但在生产环境中通常并非如此。嵌入表往往变得非常庞大,以至于单个 GPU 无法容纳一张表,这就要求使用多个设备和分布式环境。

在本节中,我们将探讨如何设置分布式环境,深入了解实际生产训练的具体实现方式,并探索分片嵌入表,所有内容都将使用 TorchRec 实现。

本节也将仅使用 1 个 GPU,尽管它将以分布式方式处理。这仅是训练的限制,因为训练每个 GPU 都有一个独立的进程。推理不会遇到这个要求

在下面的示例代码中,我们设置了 PyTorch 分布式环境。

警告

如果你是在 Google Colab 中运行此代码,你只能调用这个单元格一次,再次调用会导致错误,因为进程组只能初始化一次。

import os

import torch.distributed as dist

# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")

print(f"Distributed environment initialized: {dist}")
Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>

分布式嵌入

我们已经与主要的TorchRec模块: EmbeddingBagCollection 进行了工作。我们已经检查了它的运作方式以及TorchRec中数据的表示方式。然而,我们尚未探索TorchRec的一个主要部分,即分布式嵌入

GPU是目前机器学习工作负载最受欢迎的选择,因为它们每秒可以执行的浮点运算(FLOPs)远多于CPU。然而,GPU存在快速内存(HBM,类似于CPU的RAM)稀缺的限制,通常只有几十GB。

一个推荐系统模型可以包含远超单个GPU内存限制的嵌入表,因此需要将嵌入表分布到多个GPU上,这种做法也称为模型并行。另一方面,数据并行是指将整个模型复制到每个GPU上,每个GPU接收不同的数据批次进行训练,在反向传播过程中同步梯度。

模型中需要较少计算但更多内存(嵌入)的部分采用模型并行分布,而需要更多计算和较少内存(密集层、MLP等)的部分则采用数据并行分布。

分片

为了分发一个嵌入表,我们将嵌入表拆分成多个部分,并将这些部分放置在不同的设备上,这种方法也被称为“分片”。

有很多方法可以分片嵌入表。最常见的方法有:

  • 表级:整个表被放置在一个设备上

  • 列级:嵌入表的列被分片

  • 行级:嵌入表的行被分片

分片模块

虽然所有这些看起来像是很多需要处理和实现的内容,但你很幸运。 TorchRec 提供了所有用于轻松分布式训练和推理的基础组件!实际上,TorchRec 模块为在分布式环境中使用任何 TorchRec 模块提供了两个对应的类:

  • 模块分片器: 该类暴露了一个 shard API 用于处理TorchRec模块的分片,生成一个分片后的模块。 * 对于 EmbeddingBagCollection,分片器是 EmbeddingBagCollectionSharder

  • 分片模块: 该类是TorchRec模块的分片变体。 它与常规TorchRec模块具有相同的输入/输出,但优化程度更高,并且可以在分布式环境中运行。 * 对于 EmbeddingBagCollection,分片变体是 ShardedEmbeddingBagCollection

每个 TorchRec 模块都有一个未分片和分片的变体。

  • 非分片版本旨在进行原型设计和实验。

  • 分片版本旨在用于分布式环境中进行分布式训练和推理。

TorchRec模块的分片版本,例如 EmbeddingBagCollection,将处理模型并行所需的一切内容,例如在GPU之间进行通信以将嵌入分布到正确的GPU上。

我们 EmbeddingBagCollection 模块的复习

ebc

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()

# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"

print(f"Process Group: {pg}")
Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f50a8c5cf30>

计划器

在展示分片如何工作之前,我们必须了解 规划器,它帮助我们确定最佳的分片配置。

给定多个嵌入表和多个秩,存在许多不同的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,你可以:

  • 每张 GPU 上放置 1 张表

  • 将两个表放在一个 GPU 上,另一个 GPU 上不放任何表

  • 将某些行和列放置在每个 GPU 上

考虑到所有这些可能性,我们通常希望有一个对性能最优化的分片配置。

这就是规划器的作用。规划器可以根据嵌入表的数量和 GPU 的数量,确定最优的配置。事实上,手动完成这项工作非常困难,工程师需要考虑大量因素以确保最佳的分片计划。幸运的是,当使用规划器时,TorchRec 提供了一个自动规划器。

TorchRec 计划器:

  • 评估硬件的内存限制

  • 估计基于内存获取进行计算,如嵌入查找

  • 处理数据特定因素

  • 考虑其他硬件特性,如带宽,以生成最优的分片计划

为了考虑所有这些变量,TorchRec规划器可以接收不同数量的数据用于嵌入表、约束条件、硬件信息和拓扑结构,以帮助生成模型的最佳分片计划,这些计划通常在各个堆栈中提供。

要了解更多关于分片的信息,请参阅我们的分片教程

# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device="cuda",
    )
)

# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)

print(f"Sharding Plan generated: {plan}")
Sharding Plan generated: module:

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise    | fused          | [0]
user_table    | table_wise    | fused          | [0]

    param     | shard offsets | shard sizes |   placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [4096, 64]  | rank:0/cuda:0
user_table    | [0, 0]        | [4096, 64]  | rank:0/cuda:0

计划结果

如你上面所见,在运行规划器时会有相当多的输出信息。 我们可以看到很多统计数据被计算出来,以及我们的 表最终被放置的位置。

运行规划器的结果是一个静态计划,可以用于分片!这允许生产模型的分片为静态,而不是每次确定一个新的分片计划。下面,我们使用分片计划最终生成我们的 ShardedEmbeddingBagCollection

# The static plan that was generated
plan

env = ShardingEnv.from_process_group(pg)

# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC Module: {sharded_ebc}")
Sharded EBC Module: ShardedEmbeddingBagCollection(
  (lookups):
   GroupedPooledEmbeddingsLookup(
      (_emb_modules): ModuleList(
        (0): BatchedFusedEmbeddingBag(
          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
        )
      )
    )
   (_output_dists):
   TwPooledEmbeddingDist()
  (embedding_bags): ModuleDict(
    (product_table): Module()
    (user_table): Module()
  )
)

GPU训练 with LazyAwaitable

请记住,TorchRec 是一个高度优化的分布式嵌入库。TorchRec 引入了一个概念,以在 GPU 上训练时实现更高的性能,即 LazyAwaitable。 您将看到 LazyAwaitable 种类型作为各种分片 TorchRec 模块的输出。所有 LazyAwaitable 类型所做的就是尽可能延迟计算某些结果,并通过像异步类型一样运作来实现这一点。

from typing import List

from torchrec.distributed.types import LazyAwaitable


# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
awaitable.wait()

kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)

kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))

print(kt.keys())

print(kt.values().shape)

# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)
<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7f4ed824a9e0>
<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])

PyTorch Rec 分片模块结构

我们现在已经成功地将一个EmbeddingBagCollection按照我们生成的分片计划进行了分片!分片后的模块具有来自TorchRec的通用API,这些API抽象了多块GPU之间的分布式通信/计算。实际上,这些API在训练和推理中都高度优化以提升性能。以下是TorchRec提供的三个用于分布式训练/推理的通用API

  • input_dist: 负责将输入从GPU分发到GPU。

  • lookups: 在优化的、批量方式下实际执行嵌入查找,使用FBGEMM TBE(后面会详细介绍)。

  • output_dist: 负责将GPU的输出分发到各个GPU。

输入和输出的分布是通过 NCCL Collectives, 即 All-to-Alls, 其中所有GPU互相发送和接收数据。 TorchRec 与 PyTorch 分布式进行集成,用于集合操作,并为最终用户提供清晰的抽象,消除了对底层细节的担忧。

反向传播执行所有这些集体操作,但顺序与梯度分布相反。 input_distlookupoutput_dist 都依赖于分片方案。由于我们以表格方式分片,这些API是由 TwPooledEmbeddingSharding 构建的模块。

sharded_ebc

# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
[TwPooledEmbeddingDist(
  (_dist): PooledEmbeddingsAllToAll()
)]

优化嵌入查找

在对一组嵌入表进行查找时,一个简单的解决方案是遍历所有的 nn.EmbeddingBags 并对每个表执行一次查找。这正是标准的、未分片的 EmbeddingBagCollection 所做的。然而,虽然这种解决方案简单,但速度极其缓慢。

FBGEMM 是一个 提供高度优化的 GPU 操作符(也称为内核)的库。 其中一个操作符被称为 表批量嵌入 (TBE),提供了两项主要优化:

  • 表批量处理,允许您通过一次内核调用来查找多个嵌入。

  • 优化器融合,允许模块根据标准的 PyTorch 优化器和参数进行自我更新。

The ShardedEmbeddingBagCollection 使用 FBGEMM TBE 作为查找 替代传统的 nn.EmbeddingBags 以优化嵌入查找。

sharded_ebc._lookups
[GroupedPooledEmbeddingsLookup(
  (_emb_modules): ModuleList(
    (0): BatchedFusedEmbeddingBag(
      (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
    )
  )
)]

DistributedModelParallel

我们现在已经探索了对单个 EmbeddingBagCollection 进行分片!我们能够将 EmbeddingBagCollectionSharder 与未分片的 EmbeddingBagCollection 结合,生成一个 ShardedEmbeddingBagCollection 模块。这个流程是可行的,但在实现模型并行时,通常会使用 DistributedModelParallel (DMP) 作为标准接口。当你用 DMP 包裹你的模型(在我们的情况下是 ebc)时,会发生以下情况:

  1. 决定如何分片模型。DMP 将收集可用的分片器,并制定最优的分片方案(例如,EmbeddingBagCollection

  2. 实际上对模型进行分片。这包括在适当的设备上为每个嵌入表分配内存。

DMP 包含了我们刚刚实验过的一切内容,例如静态分片计划、分片器列表等。然而,它还具有一些不错的默认设置,可以无缝地对 TorchRec 模型进行分片。在这个玩具示例中,由于我们有两个嵌入表和一个 GPU,TorchRec 会将它们都放在单个 GPU 上。

ebc

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))

out = model(kjt)
out.wait()

model
WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.

DistributedModelParallel(
  (_dmp_wrapped_module): ShardedEmbeddingBagCollection(
    (lookups):
     GroupedPooledEmbeddingsLookup(
        (_emb_modules): ModuleList(
          (0): BatchedFusedEmbeddingBag(
            (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
          )
        )
      )
     (_input_dists):
     TwSparseFeaturesDist(
        (_dist): KJTAllToAll()
      )
     (_output_dists):
     TwPooledEmbeddingDist(
        (_dist): PooledEmbeddingsAllToAll()
      )
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
  )
)

分片最佳实践

目前,我们的配置仅在1个GPU(或rank)上进行分片,这很简单:只需将所有表放在1个GPU的内存中。然而,在实际生产用例中,嵌入表通常会在数百个GPU上进行分片,使用不同的分片方法,如按表分片、按行分片和按列分片。确定适当的分片配置(以防止内存不足问题)非常重要,同时不仅要考虑内存平衡,还要考虑计算平衡,以实现最佳性能。

优化器添加

请记住,TorchRec 模块针对大规模分布式训练进行了高度优化。一个重要优化涉及优化器。

TorchRec 模块提供了一个无缝的 API,用于在训练中融合反向传播和优化步骤,从而显著提升性能并减少内存使用,同时允许对不同的模型参数分配不同的优化器,实现更细粒度的控制。

优化器类

TorchRec 使用 CombinedOptimizer,其中包含一组 KeyedOptimizers。一个 CombinedOptimizer 有效地使处理模型中各个子组的多个优化器变得容易。一个 KeyedOptimizer 扩展了 torch.optim.Optimizer,并通过参数字典进行初始化,暴露参数。 每个 TBE 模块在 EmbeddingBagCollection 中都会有它自己的 KeyedOptimizer,这些组合成一个 CombinedOptimizer

TorchRec中的融合优化器

使用 DistributedModelParallel优化器被融合,这意味着优化器更新在反向传播中进行。这是TorchRec和FBGEMM中的一个优化,其中优化器的嵌入梯度不会被显式生成并直接应用于参数。这带来了显著的内存节省,因为嵌入梯度的大小通常与参数本身相当。

您可以选择将优化器设置为dense,这不会应用此优化,并允许您检查嵌入梯度或根据需要对其进行计算。在这种情况下,密集优化器将是您的标准 PyTorch 模型训练循环与优化器。

一旦通过 DistributedModelParallel 创建了优化器,你 仍然需要为与TorchRec嵌入模块无关的其他参数管理一个优化器。要找到这些其他 参数, 使用 in_backward_optimizer_filter(model.named_parameters())。 像使用普通的Torch优化器一样将优化器应用于这些参数,并将此与 model.fused_optimizer 结合成一个 CombinedOptimizer,你可以在训练循环中使用它来 zero_gradstep

EmbeddingBagCollection 添加优化器

我们将通过两种方式来实现,这两种方式是等效的,但会根据您的偏好提供不同的选项:

  1. 通过 fused_params 在分片器中传递优化器 kwargs。

  2. 通过 apply_optimizer_in_backward,将优化器 参数转换为 fused_params 传递给 TBEEmbeddingBagCollectionEmbeddingCollection 中。

# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType


# We initialize the sharder with
fused_params = {
    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    "learning_rate": 0.02,
    "eps": 0.002,
}

# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))

# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")

print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it

# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}

for name, param in ebc_apply_opt.named_parameters():
    print(f"{name=}")
    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)

sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))

# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))

# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())

# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")

loss.backward()

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.01
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.02
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
/var/lib/workspace/intermediate_source/torchrec_intro_tutorial.py:876: DeprecationWarning:

`TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.

name='embedding_bags.product_table.weight'
name='embedding_bags.user_table.weight'
: EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.5
)
<class 'torchrec.optim.keyed.CombinedOptimizer'>
Non Fused Model Parameters:
dict_keys([])
First Iteration Loss: 255.66006469726562
Second Iteration Loss: 245.43795776367188

推理

现在我们能够训练分布式嵌入,如何将训练好的模型优化以用于推理?推理通常非常敏感于模型的性能和大小。仅仅在Python环境中运行训练好的模型效率极其低下。 推理与训练环境之间有两个关键区别:

  • 量化: 推理模型通常会被量化,其中模型参数会损失精度以降低预测的延迟和减少模型大小。例如,将训练模型中的FP32(4字节)转换为每个嵌入权重的INT8(1字节)。鉴于嵌入表的巨大规模,这也势在必行,因为我们希望尽可能少地使用设备进行推理以最小化延迟。

  • C++环境: 推理延迟非常重要,因此为了确保充分性能,模型通常在C++环境中运行, 以及在没有Python运行时的情况,例如在设备上。

TorchRec 提供了将 TorchRec 模型转换为推理就绪的 primitives,包括:

  • 模型量化API,介绍使用FBGEMM TBE自动引入优化的方法

  • 分布式推理中的切分嵌入

  • 将模型编译为 TorchScript (兼容C++)

在本节中,我们将回顾整个工作流程:

  • 模型量化

  • 对量化模型进行分片

  • 将分片量化模型编译成 TorchScript

ebc

class InferenceModule(torch.nn.Module):
    def __init__(self, ebc: torchrec.EmbeddingBagCollection):
        super().__init__()
        self.ebc_ = ebc

    def forward(self, kjt: KeyedJaggedTensor):
        return self.ebc_(kjt)

module = InferenceModule(ebc)
for name, param in module.named_parameters():
    # Here, the parameters should still be FP32, as we are using a standard EBC
    # FP32 is default, regularly used for training
    print(name, param.shape, param.dtype)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32

量化

如上所示,正常的 EBC 包含以 FP32 精度(每个权重 32 位)存储的嵌入表权重。在这里,我们将使用 TorchRec 推理库将模型的嵌入权重量化为 INT8

from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)


quant_dtype = torch.int8


qconfig = QuantConfig(
    # dtype of the result of the embedding lookup, post activation
    # torch.float generally for compatibility with rest of the model
    # as rest of the model here usually isn't quantized
    activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
    # quantized type for embedding weights, aka parameters to actually quantize
    weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
    # Map of module type to qconfig
    torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
    # Map of module type to quantized module type
    torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}


module = InferenceModule(ebc)

# Quantize the module
qebc = quant.quantize_dynamic(
    module,
    qconfig_spec=qconfig_spec,
    mapping=mapping,
    inplace=False,
)


print(f"Quantized EBC: {qebc}")

kjt = kjt.to("cpu")

qebc(kjt)

# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
    # The shapes of the tables should be the same but the dtype should be int8 now
    # post quantization
    print(name, buffer.shape, buffer.dtype)
Quantized EBC: InferenceModule(
  (ebc_): QuantizedEmbeddingBagCollection(
    (_kjt_to_jt_dict): ComputeKJTToJTDict()
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
  )
)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8

分片

在这里我们对TorchRec量化模型进行分片。这是为了确保我们使用通过FBGEMM TBE实现的高性能模块。这里我们使用一个设备,以与训练时保持一致(1个TBE)。

from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules


sharded_qebc = _shard_modules(
    module=qebc,
    device=torch.device("cpu"),
    env=trec_dist.ShardingEnv.from_local(
        1,
        0,
    ),
)


print(f"Sharded Quantized EBC: {sharded_qebc}")

sharded_qebc(kjt)
WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.
Sharded Quantized EBC: InferenceModule(
  (ebc_): ShardedQuantEmbeddingBagCollection(
    (lookups):
     InferGroupedPooledEmbeddingsLookup()
    (_output_dists): ModuleList()
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
    (_input_dist_module): ShardedQuantEbcInputDist()
  )
)

<torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7f4ed828c310>

编译

现在我们已经有了优化后的 eager TorchRec 推理模型。下一步是确保该模型可以在 C++ 中加载,因为目前它只能在 Python 运行时中运行。

Meta推荐的编译方法有两个方面:torch.fx 跟踪(生成模型的中间表示)和将结果转换为TorchScript,其中TorchScript与C++兼容。

from torchrec.fx import Tracer


tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])

graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)

print("Graph Module Created!")

print(gm.code)

scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")

print(scripted_gm.code)
Graph Module Created!

torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")

def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
    flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt);  kjt = None
    _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths);  _fx_marker = None
    split = flatten_feature_lengths.split([2])
    getitem = split[0];  split = None
    to = getitem.to(device(type='cuda', index=0), non_blocking = True);  getitem = None
    _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths);  flatten_feature_lengths = _fx_marker_1 = None
    _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to);  to = None
    getitem_1 = _unwrap_kjt[0]
    getitem_2 = _unwrap_kjt[1]
    getitem_3 = _unwrap_kjt[2];  _unwrap_kjt = getitem_3 = None
    _tensor_constant0 = self._tensor_constant0
    _tensor_constant1 = self._tensor_constant1
    bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None);  _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
    _tensor_constant2 = self._tensor_constant2
    _tensor_constant3 = self._tensor_constant3
    _tensor_constant4 = self._tensor_constant4
    _tensor_constant5 = self._tensor_constant5
    _tensor_constant6 = self._tensor_constant6
    _tensor_constant7 = self._tensor_constant7
    _tensor_constant8 = self._tensor_constant8
    _tensor_constant9 = self._tensor_constant9
    int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1);  _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None
    embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32);  int_nbit_split_embedding_codegen_lookup_function = None
    to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu'));  embeddings_cat_empty_rank_handle_inference = None
    keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1);  to_1 = None
    return keyed_tensor

/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning:

The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.

Scripted Graph Module Created!
def forward(self,
    kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
  _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
  _1 = __torch__.torchrec.fx.utils._fx_marker
  _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
  _3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
  flatten_feature_lengths = _0(kjt, )
  _fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
  split = (flatten_feature_lengths).split([2], )
  getitem = split[0]
  to = (getitem).to(torch.device("cuda", 0), True, None, )
  _fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
  _unwrap_kjt = _2(to, )
  getitem_1 = (_unwrap_kjt)[0]
  getitem_2 = (_unwrap_kjt)[1]
  _tensor_constant0 = self._tensor_constant0
  _tensor_constant1 = self._tensor_constant1
  ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)
  _tensor_constant2 = self._tensor_constant2
  _tensor_constant3 = self._tensor_constant3
  _tensor_constant4 = self._tensor_constant4
  _tensor_constant5 = self._tensor_constant5
  _tensor_constant6 = self._tensor_constant6
  _tensor_constant7 = self._tensor_constant7
  _tensor_constant8 = self._tensor_constant8
  _tensor_constant9 = self._tensor_constant9
  int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)
  _4 = [int_nbit_split_embedding_codegen_lookup_function]
  embeddings_cat_empty_rank_handle_inference = _3(_4, 1, "cuda:0", 6, )
  to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
  _5 = ["product", "user"]
  _6 = [64, 64]
  keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
  _7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )
  return keyed_tensor

结论

在本教程中,您已经从训练分布式推荐系统模型开始,一直到使其准备好进行推理。 TorchRec 仓库 中有一个完整的示例,说明如何将 TorchRec TorchScript 模型加载到 C++ 中进行推理。

如需了解更多信息,请参阅我们的 dlrm 示例,该示例包括使用在 Deep Learning Recommendation Model for Personalization and Recommendation Systems 中描述的方法,在 Criteo 1TB 数据集上进行多节点训练。

脚本总运行时间: ( 0 分钟 0.767 秒)

通过 Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源