
Knowledge Distillation Tutorial

Created On: Aug 22, 2023 | Last Updated: Jul 30, 2024 | Last Verified: Nov 05, 2024

Author: Alexandros Chariton

Knowledge distillation is a technique that enables knowledge transfer from large, computationally expensive models to smaller ones without losing validity. This allows for deployment on less powerful hardware, making evaluation faster and more efficient.

In this tutorial, we will run a number of experiments focused on improving the accuracy of a lightweight neural network, using a more powerful network as a teacher. The computational cost and the speed of the lightweight network will remain unaffected; our intervention only targets its weights, not its forward pass. Applications of this technology can be found in devices such as drones or mobile phones. In this tutorial, we do not use any external packages, as everything we need is available in torch and torchvision.

In this tutorial, you will learn:

  • How to modify model classes to extract hidden representations and use them for further calculations

  • How to modify regular train loops in PyTorch to include additional losses on top of, for example, cross-entropy for classification

  • How to improve the performance of lightweight models by using more complex models as teachers

Prerequisites

  • 1 GPU, 4GB of memory

  • PyTorch v2.0 or later

  • CIFAR-10 dataset (downloaded by the script and saved in a directory called /data)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading CIFAR-10

CIFAR-10 is a popular image dataset with ten classes. Our objective is to predict one of the following classes for each input image.

../_static/img/cifar10.png

Example of CIFAR-10 images

The input images are RGB, so they have 3 channels and are 32x32 pixels. Basically, each image is described by 3 x 32 x 32 = 3072 numbers ranging from 0 to 255. A common practice in neural networks is to normalize the input, which is done for multiple reasons, including avoiding saturation in commonly used activation functions and increasing numerical stability. Our normalization process consists of subtracting the mean and dividing by the standard deviation along each channel. The tensors mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] have already been computed, and they represent the per-channel mean and standard deviation of the predefined subset of CIFAR-10 intended to be the training set. Notice that we use these values for the test set as well, without recomputing the mean and standard deviation. This is because the network was trained on features produced by subtracting and dividing by the numbers above, and we want to maintain consistency. Furthermore, in real life, we would not be able to compute the mean and standard deviation of the test set since, under our assumptions, this data would not be accessible at that point.
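For illustration, if you did need to derive such statistics yourself, a minimal sketch like the one below could estimate the per-channel mean and standard deviation of a training set (this is purely an aside; the tutorial simply reuses the precomputed values):

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Load the raw training images as tensors in [0, 1], without normalization
raw_train = datasets.CIFAR10(root='./data', train=True, download=True,
                             transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(raw_train, batch_size=512, num_workers=2)

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
num_pixels = 0
for images, _ in loader:
    # images: (batch, 3, 32, 32); accumulate sums over batch and spatial dims
    channel_sum += images.sum(dim=(0, 2, 3))
    channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
    num_pixels += images.numel() // 3

mean = channel_sum / num_pixels
std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
print("mean:", mean, "std:", std)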

As a closing point, we often refer to this held-out set as the validation set, and we use a separate set, called the test set, after optimizing a model's performance on the validation set. This is done to avoid selecting a model based on the greedy and biased optimization of a single metric.

# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

100%|##########| 170M/170M [00:01<00:00, 108MB/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified

Note

This section is for CPU users only who are interested in quick results. Use this option only if you're interested in a small-scale experiment. Keep in mind that the code should run fairly quickly using any GPU. Select only the first num_images_to_keep images from the train/test dataset:

#from torch.utils.data import Subset
#num_images_to_keep = 2000
#train_dataset = Subset(train_dataset, range(min(num_images_to_keep, 50_000)))
#test_dataset = Subset(test_dataset, range(min(num_images_to_keep, 10_000)))
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

Defining model classes and utility functions

Next, we need to define our model classes. Several user-defined parameters need to be set here. We use two different architectures, keeping the number of filters fixed across our experiments to ensure fair comparisons. Both architectures are Convolutional Neural Networks (CNNs) with a different number of convolutional layers that serve as feature extractors, followed by a classifier with 10 classes. The number of filters and neurons is smaller for the student.

# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
    def __init__(self, num_classes=10):
        super(DeepNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
    def __init__(self, num_classes=10):
        super(LightNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

We employ 2 functions to help us produce and evaluate the results on our original classification task. One function is called train and takes the following arguments:

  • model: A model instance to train (update its weights) via this function.

  • train_loader: We defined our train_loader above, and its job is to feed the data into the model.

  • epochs: The number of times we loop over the dataset.

  • learning_rate: The learning rate determines how large our steps towards convergence should be. Steps that are too large or too small can be detrimental.

  • device: Determines the device to run the workload on. Can be either CPU or GPU depending on availability.

Our test function is similar, but it will be invoked with test_loader to load images from the test set.

../_static/img/knowledge_distillation/ce_only.png

Train both networks with Cross-Entropy. The student will be used as a baseline:

def train(model, train_loader, epochs, learning_rate, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # inputs: A collection of batch_size images
            # labels: A vector of dimensionality batch_size with integers denoting class of each image
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
            # labels: The actual labels of the images. Vector of dimensionality batch_size
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

def test(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

Cross-entropy runs

For reproducibility, we need to set the torch manual seed. We train networks using different methods, so to compare them fairly, it makes sense to initialize the networks with the same weights. Start by training the teacher network using cross-entropy:

torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)

# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
Epoch 1/10, Loss: 1.3366997722164748
Epoch 2/10, Loss: 0.8720864758772009
Epoch 3/10, Loss: 0.6820237918583023
Epoch 4/10, Loss: 0.5375950956893394
Epoch 5/10, Loss: 0.41490358377204223
Epoch 6/10, Loss: 0.3123219789141584
Epoch 7/10, Loss: 0.22101545549185989
Epoch 8/10, Loss: 0.17098309014878615
Epoch 9/10, Loss: 0.13455525941460791
Epoch 10/10, Loss: 0.12078842208208636
Test Accuracy: 75.45%

We instantiate one more lightweight network model to compare their performances. Backpropagation is sensitive to weight initialization, so we need to make sure these two networks have the exact same initialization.

torch.manual_seed(42)
new_nn_light = LightNN(num_classes=10).to(device)

To ensure we have created a copy of the first network, we inspect the norm of its first layer. If it matches, then we are safe to conclude that the networks are indeed the same.

# Print the norm of the first layer of the initial lightweight model
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
# Print the norm of the first layer of the new lightweight model
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())
Norm of 1st layer of nn_light: 2.327361822128296
Norm of 1st layer of new_nn_light: 2.327361822128296

Print the total number of parameters in each model:

total_params_deep = "{:,}".format(sum(p.numel() for p in nn_deep.parameters()))
print(f"DeepNN parameters: {total_params_deep}")
total_params_light = "{:,}".format(sum(p.numel() for p in nn_light.parameters()))
print(f"LightNN parameters: {total_params_light}")
DeepNN parameters: 1,186,986
LightNN parameters: 267,738

Train and test the lightweight network with cross-entropy loss:

train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_loader, device)
Epoch 1/10, Loss: 1.466049101346594
Epoch 2/10, Loss: 1.1519653670623173
Epoch 3/10, Loss: 1.0232561651398153
Epoch 4/10, Loss: 0.9235453337354733
Epoch 5/10, Loss: 0.8479179534156
Epoch 6/10, Loss: 0.7824301378196462
Epoch 7/10, Loss: 0.7184310383199121
Epoch 8/10, Loss: 0.6588469929707325
Epoch 9/10, Loss: 0.6075488568266945
Epoch 10/10, Loss: 0.556159371533967
Test Accuracy: 70.15%

As we can see, based on test accuracy, we can now compare the deeper network that is to be used as a teacher with our lightweight network, the supposed student. So far, the student has had no intervention from the teacher, so this performance is achieved by the student on its own. The metrics so far can be seen with the following lines:

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")
Teacher accuracy: 75.45%
Student accuracy: 70.15%

Knowledge distillation run

Now let's try to improve the test accuracy of the student network by incorporating the teacher. Knowledge distillation is a straightforward technique to achieve this, based on the fact that both networks output a probability distribution over our classes. Therefore, the two networks share the same number of output neurons. The method works by incorporating an additional loss into the traditional cross-entropy loss, based on the softmax output of the teacher network. The assumption is that the output activations of a properly trained teacher network carry additional information that can be leveraged by the student network during training. The original work suggests that utilizing the ratios of smaller probabilities in the soft targets can help achieve the underlying objective of deep neural networks, which is to create a similarity structure over the data where similar objects are mapped closer together. For example, in CIFAR-10, a truck could be mistaken for an automobile or an airplane if its wheels are present, but it is less likely to be mistaken for a dog. Therefore, it makes sense to assume that valuable information resides not only in the top prediction of a properly trained model but in the entire output distribution. However, cross-entropy alone does not sufficiently exploit this information, as the activations for non-predicted classes tend to be so small that the propagated gradients do not meaningfully change the weights needed to construct this desirable vector space.

As we continue defining our first helper function, which introduces a teacher-student dynamic, we need to include a few extra parameters:

  • T: Temperature controls the smoothness of the output distributions. A larger T leads to a smoother distribution, so smaller probabilities get a larger boost (see the short demonstration after this list).

  • soft_target_loss_weight: A weight assigned to the extra objective we're about to include.

  • ce_loss_weight: A weight assigned to cross-entropy. Tuning these weights pushes the network towards optimizing for either objective.
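To make the effect of T concrete, here is a tiny, self-contained demonstration with made-up logits (not part of the tutorial's pipeline); raising the temperature pulls the probabilities closer together:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])  # hypothetical logits for a 3-class example
for T in [1, 2, 5]:
    print(f"T={T}:", F.softmax(logits / T, dim=-1))
# T=1: ~[0.659, 0.242, 0.099]
# T=2: ~[0.502, 0.304, 0.194]
# T=5: ~[0.400, 0.327, 0.273]  (smoother; the small probabilities are boosted)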

../_static/img/knowledge_distillation/distillation_output_loss.png

The distillation loss is computed from the logits of the networks. It only returns gradients to the student:

def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Forward pass with the student model
            student_logits = student(inputs)

            #Soften the student logits by applying softmax first and log() second
            soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
            soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

            # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper "Distilling the knowledge in a neural network"
            soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.
train_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)

# Compare the student test accuracy with and without the teacher, after distillation
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
Epoch 1/10, Loss: 2.3962801237545355
Epoch 2/10, Loss: 1.87831118161721
Epoch 3/10, Loss: 1.6540942881113427
Epoch 4/10, Loss: 1.4959803764777415
Epoch 5/10, Loss: 1.367971143454237
Epoch 6/10, Loss: 1.2519247448048019
Epoch 7/10, Loss: 1.1570622474336258
Epoch 8/10, Loss: 1.0719402747995712
Epoch 9/10, Loss: 0.9970421949615869
Epoch 10/10, Loss: 0.9293939061177051
Test Accuracy: 70.75%
Teacher accuracy: 75.45%
Student accuracy without teacher: 70.15%
Student accuracy with CE + KD: 70.75%
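As a side note, the hand-written soft-target term inside train_knowledge_distillation is exactly the temperature-scaled KL divergence between the teacher's and the student's softened distributions. If you prefer a built-in, a sketch of an equivalent formulation with nn.KLDivLoss could look like this (reduction='batchmean' reproduces the manual division by batch size):

import torch.nn as nn
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, T):
    # KLDivLoss expects log-probabilities as input and probabilities as target
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_prob = F.log_softmax(student_logits / T, dim=-1)
    return nn.KLDivLoss(reduction='batchmean')(soft_prob, soft_targets) * (T ** 2)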

Cosine loss minimization run

Feel free to play around with the temperature parameter that controls the softness of the softmax function, and with the loss coefficients. In neural networks, it is easy to include additional loss functions alongside the main objective to achieve goals like better generalization. Let's try including an objective for the student, but now let's focus on the hidden states of the networks rather than their output layers. Our goal is to convey information from the teacher's representation to the student by including a naive loss function whose minimization implies that the flattened vectors subsequently passed to the classifiers become more similar as the loss decreases. Of course, the teacher does not update its weights, so the minimization depends only on the student's weights. The rationale behind this method is the assumption that the teacher model has a better internal representation, one the student is unlikely to reach without external intervention, so we artificially push the student to mimic the teacher's internal representation. Whether this ends up helping the student is not straightforward, though: pushing the lightweight network toward this point could be a good thing, assuming we have found an internal representation that leads to better test accuracy, but it could also be harmful because the networks have different architectures and the student does not have the same learning capacity as the teacher. In other words, there is no reason for these two vectors, the student's and the teacher's, to match per component. The student could reach an internal representation that is a permutation of the teacher's and be just as effective. Nonetheless, we can still run a quick experiment to figure out the impact of this method. We will be using the CosineEmbeddingLoss, which is given by the following formula:

../_static/img/knowledge_distillation/cosine_embedding_loss.png

Formula for CosineEmbeddingLoss: loss(x, y) = 1 - cos(x1, x2) if y = 1, and max(0, cos(x1, x2) - margin) if y = -1
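To get a feel for what this loss rewards, here is a tiny illustration with made-up vectors (not part of the training pipeline): with target y = 1, minimizing the loss maximizes the cosine similarity, regardless of vector magnitude.

import torch
import torch.nn as nn

cos_loss = nn.CosineEmbeddingLoss()
x1 = torch.tensor([[1.0, 0.0]])
x2_same = torch.tensor([[2.0, 0.0]])  # same direction, different magnitude
x2_orth = torch.tensor([[0.0, 1.0]])  # orthogonal direction
target = torch.ones(1)
print(cos_loss(x1, x2_same, target))  # tensor(0.)
print(cos_loss(x1, x2_orth, target))  # tensor(1.)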

Apparently, there is one thing we need to resolve first. When we applied distillation to the output layer, we mentioned that both networks have the same number of neurons, equal to the number of classes. However, this is not the case for the layer following our convolutional layers. Here, the teacher has more neurons than the student after the flattening of the final convolutional layer. Our loss function accepts two vectors of equal dimensionality as inputs, so we need to somehow match them. We will solve this by including an average pooling layer after the teacher's convolutional layers to reduce its dimensionality to match that of the student.

To proceed, we will modify our model classes, or create new ones. Now, the forward function returns not only the logits of the network but also the flattened hidden representation after the convolutional layers. We include the aforementioned pooling for the modified teacher.
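As a quick sanity check on the dimensionality reduction (illustrative only), avg_pool1d with a kernel size of 2 averages every two adjacent values and halves the flattened length:

import torch
import torch.nn.functional as F

flat = torch.randn(128, 2048)               # a batch of flattened teacher features
pooled = F.avg_pool1d(flat, kernel_size=2)  # averages each pair of adjacent values
print(pooled.shape)                         # torch.Size([128, 1024])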

class ModifiedDeepNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling

# Create a similar student class where we return a tuple. We do not apply pooling after flattening.
class ModifiedLightNNCosine(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNCosine, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)
modified_nn_deep.load_state_dict(nn_deep.state_dict())

# Once again ensure the norm of the first layer is the same for both networks
print("Norm of 1st layer for deep_nn:", torch.norm(nn_deep.features[0].weight).item())
print("Norm of 1st layer for modified_deep_nn:", torch.norm(modified_nn_deep.features[0].weight).item())

# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.
torch.manual_seed(42)
modified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)
print("Norm of 1st layer:", torch.norm(modified_nn_light.features[0].weight).item())
Norm of 1st layer for deep_nn: 7.5062713623046875
Norm of 1st layer for modified_deep_nn: 7.5062713623046875
Norm of 1st layer: 2.327361822128296

Naturally, we need to change the train loop because now the model returns a tuple (logits, hidden_representation). Using a sample input tensor, we can print their shapes.

# Create a sample input tensor
sample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Filters: 3, Image size: 32x32

# Pass the input through the student
logits, hidden_representation = modified_nn_light(sample_input)

# Print the shapes of the tensors
print("Student logits shape:", logits.shape) # batch_size x total_classes
print("Student hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size

# Pass the input through the teacher
logits, hidden_representation = modified_nn_deep(sample_input)

# Print the shapes of the tensors
print("Teacher logits shape:", logits.shape) # batch_size x total_classes
print("Teacher hidden representation shape:", hidden_representation.shape) # batch_size x hidden_representation_size
Student logits shape: torch.Size([128, 10])
Student hidden representation shape: torch.Size([128, 1024])
Teacher logits shape: torch.Size([128, 10])
Teacher hidden representation shape: torch.Size([128, 1024])

In our case, hidden_representation_size is 1024. This is the flattened feature map of the final convolutional layer of the student and, as you can see, it is the input for its classifier. It is 1024 for the teacher too, because we made it so with avg_pool1d from 2048. The loss applied here only affects the weights of the student prior to the loss calculation. In other words, it does not affect the classifier of the student.
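If you want to convince yourself of this, a small check like the following (a sketch using the classes and tensors defined above) shows that backpropagating the cosine loss alone leaves the student's classifier without gradients:

import torch
import torch.nn as nn

probe = ModifiedLightNNCosine(num_classes=10).to(device)
logits, hidden = probe(sample_input)
with torch.no_grad():
    _, teacher_hidden = modified_nn_deep(sample_input)

loss = nn.CosineEmbeddingLoss()(hidden, teacher_hidden,
                                torch.ones(sample_input.size(0)).to(device))
loss.backward()
print(probe.features[0].weight.grad is None)    # False: the feature extractor receives gradients
print(probe.classifier[0].weight.grad is None)  # True: the classifier is untouched by this loss

The modified training loop is the following: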

../_static/img/knowledge_distillation/cosine_loss_distillation.png

In cosine loss minimization, we want to maximize the cosine similarity of the two representations by returning gradients to the student:

def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model and keep only the hidden representation
            with torch.no_grad():
                _, teacher_hidden_representation = teacher(inputs)

            # Forward pass with the student model
            student_logits, student_hidden_representation = student(inputs)

            # Calculate the cosine loss. Target is a vector of ones. From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.
            hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

We need to modify our test function in the same way. Here we disregard the hidden representation returned by the model.

def test_multiple_outputs(model, test_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs, _ = model(inputs) # Disregard the second tensor of the tuple
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In this case, we could easily include both knowledge distillation and cosine loss minimization in the same function. It is common to combine methods to achieve better performance in teacher-student paradigms. For now, we can run a simple train-test session.

# Train and test the lightweight network with cross entropy + cosine loss
train_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)
Epoch 1/10, Loss: 1.3019573956804202
Epoch 2/10, Loss: 1.0648430796230541
Epoch 3/10, Loss: 0.9631246839033063
Epoch 4/10, Loss: 0.8873560082577073
Epoch 5/10, Loss: 0.8324901197877381
Epoch 6/10, Loss: 0.7894727920022462
Epoch 7/10, Loss: 0.7494128462298751
Epoch 8/10, Loss: 0.7148379474649649
Epoch 9/10, Loss: 0.6762727979199051
Epoch 10/10, Loss: 0.6496843107216194
Test Accuracy: 71.17%

Intermediate regressor run

Our naive minimization does not guarantee better results, for several reasons, one being the dimensionality of the vectors. Cosine similarity generally works better than Euclidean distance for vectors of higher dimensionality, but we are dealing with vectors of 1024 components each, so it is much harder to extract meaningful similarities. Furthermore, as we mentioned, pushing the hidden representations of the teacher and the student to match is not supported by theory; there is no good reason to aim for a 1:1 match of these vectors. We will provide a final example of training intervention by including an extra network, called a regressor. The objective is to first extract the feature map of the teacher after a convolutional layer, then extract the feature map of the student after a convolutional layer, and finally try to match these maps. This time, however, we will introduce a regressor between the networks to facilitate the matching process. The regressor will be trainable and will ideally do a better job than our naive cosine loss minimization scheme. Its main job is to match the dimensionality of these feature maps so that we can properly define a loss function between the teacher and the student. Defining such a loss function provides a teaching "path," which is basically a flow for backpropagating gradients that will change the student's weights. Focusing on the output of the convolutional layers right before each classifier in our original networks, we have the following shapes:

# Pass the sample input only from the convolutional feature extractor
convolutional_fe_output_student = nn_light.features(sample_input)
convolutional_fe_output_teacher = nn_deep.features(sample_input)

# Print their shapes
print("Student's feature extractor output shape: ", convolutional_fe_output_student.shape)
print("Teacher's feature extractor output shape: ", convolutional_fe_output_teacher.shape)
Student's feature extractor output shape:  torch.Size([128, 16, 8, 8])
Teacher's feature extractor output shape:  torch.Size([128, 32, 8, 8])

We have 32 filters for the teacher and 16 filters for the student. We will include a trainable layer that converts the feature map of the student to the shape of the feature map of the teacher. In practice, we modify the lightweight class to return the hidden state after an intermediate regressor that matches the sizes of the convolutional feature maps, and the teacher class to return the output of the final convolutional layer, without pooling or flattening.

../_static/img/knowledge_distillation/fitnets_knowledge_distill.png

The trainable layer matches the shapes of the intermediate tensors, and Mean Squared Error (MSE) is properly defined:

class ModifiedDeepNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedDeepNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        conv_feature_map = x
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, conv_feature_map

class ModifiedLightNNRegressor(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedLightNNRegressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Include an extra regressor (in our case linear)
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        regressor_output = self.regressor(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x, regressor_output

After that, we have to update our train loop again. This time, we extract the regressor output of the student and the feature map of the teacher, we calculate the MSE on these tensors (they have the exact same shape, so it is properly defined), and we backpropagate gradients based on that loss, in addition to the regular cross-entropy loss of the classification task.

def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    teacher.to(device)
    student.to(device)
    teacher.eval()  # Teacher set to evaluation mode
    student.train() # Student to train mode

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Again ignore teacher logits
            with torch.no_grad():
                _, teacher_feature_map = teacher(inputs)

            # Forward pass with the student model
            student_logits, regressor_feature_map = student(inputs)

            # Calculate the loss
            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

            # Calculate the true label loss
            label_loss = ce_loss(student_logits, labels)

            # Weighted sum of the two losses
            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

# Notice how our test function remains the same here with the one we used in our previous case. We only care about the actual outputs because we measure accuracy.

# Initialize a ModifiedLightNNRegressor
torch.manual_seed(42)
modified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)

# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance
modified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)
modified_nn_deep_reg.load_state_dict(nn_deep.state_dict())

# Train and test once again
train_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)
Epoch 1/10, Loss: 1.7312568727966464
Epoch 2/10, Loss: 1.3489013407236474
Epoch 3/10, Loss: 1.2052425062260055
Epoch 4/10, Loss: 1.108028480921255
Epoch 5/10, Loss: 1.028127753673612
Epoch 6/10, Loss: 0.9665506834264301
Epoch 7/10, Loss: 0.9110444729285472
Epoch 8/10, Loss: 0.861484837196672
Epoch 9/10, Loss: 0.8197113380712622
Epoch 10/10, Loss: 0.7801015764246206
Test Accuracy: 71.08%

It is expected that the final method will work better than CosineLoss, because now we have allowed a trainable layer between the teacher and the student, which gives the student some wiggle room when it comes to learning, rather than pushing the student to copy the teacher's representation. Including the extra network is the idea behind hint-based distillation.

print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%")
print(f"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%")
Teacher accuracy: 75.45%
Student accuracy without teacher: 70.15%
Student accuracy with CE + KD: 70.75%
Student accuracy with CE + CosineLoss: 71.17%
Student accuracy with CE + RegressorMSE: 71.08%

Conclusion

None of the methods above increases the number of parameters of the network or the inference time, so the performance gain comes at the small cost of calculating gradients during training. In ML applications, we mostly care about inference time, because training happens before the model is deployed. If our lightweight model is still too heavy for deployment, we can apply different methods, such as post-training quantization. Additional losses can be applied in many tasks, not just classification, and you can experiment with quantities like coefficients, temperature, or the number of neurons. Feel free to tune any of the numbers in the tutorial above, but keep in mind that if you change the number of neurons or filters, a shape mismatch is likely to occur.
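As a pointer in that direction, here is a minimal sketch of dynamic post-training quantization applied to the student's linear layers (illustrative only; the appropriate quantization workflow depends on your model and deployment target):

import torch

# Dynamic quantization converts the weights of the selected module types to int8;
# it runs on CPU, and the convolutional layers are left in floating point here.
quantized_student = torch.quantization.quantize_dynamic(
    nn_light.to("cpu"), {torch.nn.Linear}, dtype=torch.qint8
)
test(quantized_student, test_loader, device=torch.device("cpu"))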

For more information, see:

  • Hinton, G., Vinyals, O., Dean, J.: Distilling the Knowledge in a Neural Network. NIPS Deep Learning Workshop (2015). https://arxiv.org/abs/1503.02531

  • Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: FitNets: Hints for Thin Deep Nets. ICLR (2015). https://arxiv.org/abs/1412.6550

Total running time of the script: (7 minutes 45.085 seconds)
