
Exporting TorchRL Modules

Author: Vincent Moens

Note

To run this tutorial in a notebook, add an installation cell at the beginning containing:

!pip install tensordict
!pip install torchrl
!pip install "gymnasium[atari,accept-rom-license]<1.0.0"

Introduction

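Learning a policy has little value if that policy cannot be deployed in real-world settings. As shown in other tutorials, TorchRL has a strong focus on modularity and composability: thanks to tensordict, the components of the library can be written in the most generic way possible, by abstracting their signature to a mere set of operations on an input TensorDict. This may give the impression that the library is bound to be used only for training, since typical low-level execution hardware (edge devices, robots, Arduino, Raspberry Pi) does not execute Python code, let alone with PyTorch, tensordict or TorchRL installed.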

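Fortunately, PyTorch provides a full ecosystem of solutions to export code and trained models to devices and hardware, and TorchRL is fully equipped to interact with it. You can choose from a varied set of backends, including ONNX or AOTInductor, both illustrated in this tutorial. This tutorial gives a quick overview of how a trained model can be isolated and shipped as a standalone executable for deployment on hardware.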

Key learnings:

  • Exporting any TorchRL module after training;

  • Using various backends;

  • Testing your exported model.

Fast recap: a simple TorchRL training loop

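In this section, we reproduce the training loop from the last Getting Started tutorial, slightly adapted to work with Atari games as rendered by the gymnasium library. We will stick to the DQN example, and show later how a policy that outputs a distribution over values can be used instead.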

import time
from pathlib import Path

import numpy as np

import torch

from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictSequential,
    TensorDictSequential as Seq,
)

from torch.optim import Adam

from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

from torchrl.envs import (
    Compose,
    GrayScale,
    GymEnv,
    Resize,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

from torchrl.modules import ConvNet, EGreedyModule, QValueModule

from torchrl.objectives import DQNLoss, SoftUpdate

torch.manual_seed(0)

env = TransformedEnv(
    GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
    Compose(
        ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
    ),
)
env.set_seed(0)

value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters())
updater = SoftUpdate(loss, eps=0.99)

total_count = 0
total_episodes = 0
t0 = time.time()
for data in collector:
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
        # Count frames and episodes once per collected batch, not once per optim step
        total_count += data.numel()
        total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

Exporting TensorDictModule-based policies

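TensorDict allowed us to build a policy with great flexibility: on top of a regular Module that outputs action values from an observation, we added a QValueModule that reads these values and computes an action using some heuristic (e.g., an argmax call).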
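As a rough sketch of that heuristic, assuming a plain tensor of action values rather than a TensorDict, the computation boils down to an argmax followed by a gather that retrieves the value of the chosen action. This mirrors the argmax and gather operations that appear at the end of the exported graphs further down; the helper name `q_value_action` is hypothetical and not part of TorchRL.

```python
import torch


def q_value_action(action_values: torch.Tensor):
    """Hypothetical helper: greedy action selection from Q-values.

    Returns the index of the largest action value along the last dimension,
    and the value of that action (kept with a trailing singleton dimension).
    """
    action = action_values.argmax(-1)
    chosen_action_value = action_values.gather(-1, action.unsqueeze(-1))
    return action, chosen_action_value


# Six action values, one per Pong action
q = torch.tensor([0.1, -0.3, 2.0, 0.5, 0.0, 1.2])
action, value = q_value_action(q)
print(action, value)  # tensor(2) tensor([2.])

# The same helper works on a batch of action values
batch_q = torch.randn(4, 6)
batch_action, batch_value = q_value_action(batch_q)
print(batch_action.shape, batch_value.shape)
```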

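However, there is a small technical catch in our case: the environment (the actual Atari game) does not return grayscale 84x84 images, but color images at the original screen size. The transforms we appended to the environment make sure that the images can be read by the model. From the training perspective, the boundary between environment and model is blurry, but at execution time things are much clearer: the model should be responsible for transforming the input data (images) into a format that the CNN can process.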
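To make "a format that the CNN can process" concrete, here is a minimal plain-PyTorch sketch of that preprocessing, assuming a single uint8 frame of shape (210, 160, 3). The sequence of operations and the grayscale coefficients are the ones recorded in the exported graph printed later in this tutorial; the class name `AtariPreprocess` is hypothetical, and in the tutorial itself these steps are performed by the TorchRL transforms rather than written by hand.

```python
import torch
import torch.nn.functional as F
from torch import nn


class AtariPreprocess(nn.Module):
    """Hypothetical module reproducing ToTensorImage -> Resize -> GrayScale."""

    def __init__(self, size: int = 84):
        super().__init__()
        self.size = size

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        # HWC uint8 in [0, 255] -> CHW float32 in [0, 1]
        x = pixels.permute(-1, -3, -2).to(torch.float32) / 255
        # Nearest-neighbour resize; F.interpolate expects a batch dimension
        x = F.interpolate(x.unsqueeze(0), size=(self.size, self.size), mode="nearest")
        x = x.squeeze(0)
        # Weighted sum of the RGB channels, then restore a channel dimension
        r, g, b = x.unbind(-3)
        gray = 0.2989 * r + 0.587 * g + 0.114 * b
        return gray.unsqueeze(-3)


frame = torch.randint(0, 256, (210, 160, 3), dtype=torch.uint8)
out = AtariPreprocess()(frame)
print(out.shape, out.dtype)  # torch.Size([1, 84, 84]) torch.float32
```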

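Here again, the magic of tensordict will unblock us: it happens that most local (non-recursive) TorchRL transforms can be used both as environment transforms and as transforms within a Module instance. Let's see how we can add them to our policy: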

policy_transform = TensorDictSequential(
    # The last transform is a StepCounter, which we don't need for preprocessing
    env.transform[:-1],
    # Using the explorative version of the policy for didactic purposes, see below.
    policy_explore.requires_grad_(False),
)

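We create a fake input and pass it to torch.export.export() together with the policy. This gives us a "raw" program that reads our input tensor and outputs an action, without any reference to TorchRL or tensordict modules.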

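A good practice is to call select_out_keys() to let the model know that we only want a specific set of outputs (in case the policy returns more than one tensor).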

fake_td = env.base_env.fake_tensordict()
pixels = fake_td["pixels"]
with set_exploration_type("DETERMINISTIC"):
    exported_policy = torch.export.export(
        # Select only the "action" output key
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

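Printing the exported policy can be quite insightful: we can see that the first operations are a permute, a div, an unsqueeze and a resize, followed by the grayscale conversion and then the convolutional and MLP layers.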

print("Deterministic policy")
exported_policy.graph_module.print_readable()
Deterministic policy
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_module_0_module_0_module_0_0_weight: "f32[32, 1, 8, 8]", p_module_1_module_0_module_0_module_0_0_bias: "f32[32]", p_module_1_module_0_module_0_module_0_2_weight: "f32[64, 32, 4, 4]", p_module_1_module_0_module_0_module_0_2_bias: "f32[64]", p_module_1_module_0_module_0_module_0_4_weight: "f32[64, 64, 3, 3]", p_module_1_module_0_module_0_module_0_4_bias: "f32[64]", p_module_1_module_0_module_0_module_1_0_weight: "f32[512, 3136]", p_module_1_module_0_module_0_module_1_0_bias: "f32[512]", p_module_1_module_0_module_0_module_1_2_weight: "f32[6, 512]", p_module_1_module_0_module_0_module_1_2_bias: "f32[6]", b_module_1_module_1_eps_init: "f32[]", b_module_1_module_1_eps_end: "f32[]", b_module_1_module_1_eps: "f32[]", kwargs_pixels: "u8[210, 160, 3]"):
         # File: /pytorch/rl/torchrl/envs/transforms/transforms.py:308 in forward, code: data = self._apply_transform(data)
        permute: "u8[3, 210, 160]" = torch.ops.aten.permute.default(kwargs_pixels, [-1, -3, -2]);  kwargs_pixels = None
        div: "f32[3, 210, 160]" = torch.ops.aten.div.Tensor(permute, 255);  permute = None
        to: "f32[3, 210, 160]" = torch.ops.aten.to.dtype(div, torch.float32);  div = None
        unsqueeze: "f32[1, 3, 210, 160]" = torch.ops.aten.unsqueeze.default(to, 0);  to = None
        upsample_nearest2d: "f32[1, 3, 84, 84]" = torch.ops.aten.upsample_nearest2d.vec(unsqueeze, [84, 84], None);  unsqueeze = None
        squeeze: "f32[3, 84, 84]" = torch.ops.aten.squeeze.dim(upsample_nearest2d, 0);  upsample_nearest2d = None
        unbind = torch.ops.aten.unbind.int(squeeze, -3);  squeeze = None
        getitem: "f32[84, 84]" = unbind[0]
        getitem_1: "f32[84, 84]" = unbind[1]
        getitem_2: "f32[84, 84]" = unbind[2];  unbind = None
        mul: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem, 0.2989);  getitem = None
        mul_1: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_1, 0.587);  getitem_1 = None
        add: "f32[84, 84]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        mul_2: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_2, 0.114);  getitem_2 = None
        add_1: "f32[84, 84]" = torch.ops.aten.add.Tensor(add, mul_2);  add = mul_2 = None
        to_1: "f32[84, 84]" = torch.ops.aten.to.dtype(add_1, torch.float32);  add_1 = None
        unsqueeze_1: "f32[1, 84, 84]" = torch.ops.aten.unsqueeze.default(to_1, -3);  to_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[32, 20, 20]" = torch.ops.aten.conv2d.default(unsqueeze_1, p_module_1_module_0_module_0_module_0_0_weight, p_module_1_module_0_module_0_module_0_0_bias, [4, 4]);  unsqueeze_1 = p_module_1_module_0_module_0_module_0_0_weight = p_module_1_module_0_module_0_module_0_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[32, 20, 20]" = torch.ops.aten.relu.default(conv2d);  conv2d = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_1: "f32[64, 9, 9]" = torch.ops.aten.conv2d.default(relu, p_module_1_module_0_module_0_module_0_2_weight, p_module_1_module_0_module_0_module_0_2_bias, [2, 2]);  relu = p_module_1_module_0_module_0_module_0_2_weight = p_module_1_module_0_module_0_module_0_2_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_1: "f32[64, 9, 9]" = torch.ops.aten.relu.default(conv2d_1);  conv2d_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_2: "f32[64, 7, 7]" = torch.ops.aten.conv2d.default(relu_1, p_module_1_module_0_module_0_module_0_4_weight, p_module_1_module_0_module_0_module_0_4_bias);  relu_1 = p_module_1_module_0_module_0_module_0_4_weight = p_module_1_module_0_module_0_module_0_4_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_2: "f32[64, 7, 7]" = torch.ops.aten.relu.default(conv2d_2);  conv2d_2 = None

         # File: /pytorch/rl/torchrl/modules/models/utils.py:86 in forward, code: value = value.flatten(-self.ndims_in, -1)
        flatten: "f32[3136]" = torch.ops.aten.flatten.using_ints(relu_2, -3);  relu_2 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[512]" = torch.ops.aten.linear.default(flatten, p_module_1_module_0_module_0_module_1_0_weight, p_module_1_module_0_module_0_module_1_0_bias);  flatten = p_module_1_module_0_module_0_module_1_0_weight = p_module_1_module_0_module_0_module_1_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_3: "f32[512]" = torch.ops.aten.relu.default(linear);  linear = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_1: "f32[6]" = torch.ops.aten.linear.default(relu_3, p_module_1_module_0_module_0_module_1_2_weight, p_module_1_module_0_module_0_module_1_2_bias);  relu_3 = p_module_1_module_0_module_0_module_1_2_weight = p_module_1_module_0_module_0_module_1_2_bias = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:616 in forward, code: action = self.action_func_mapping[self.action_space](action_values)
        argmax: "i64[]" = torch.ops.aten.argmax.default(linear_1, -1)
        to_2: "i64[]" = torch.ops.aten.to.dtype(argmax, torch.int64);  argmax = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:621 in forward, code: chosen_action_value = action_value_func(action_values, action)
        unsqueeze_2: "i64[1]" = torch.ops.aten.unsqueeze.default(to_2, -1)
        gather: "f32[1]" = torch.ops.aten.gather.default(linear_1, -1, unsqueeze_2);  linear_1 = unsqueeze_2 = gather = None
        return (to_2,)



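As a final check, we can execute the policy with the dummy input. The output (for a single image) should be an integer between 0 and 5 (Pong has 6 actions), representing the action to be executed in the game.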

output = exported_policy.module()(pixels=pixels)
print("Exported module output", output)
Exported module output tensor(1)

Further details on exporting TensorDictModule instances can be found in the tensordict documentation.

Note

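Exporting modules that take and output nested keys is perfectly fine. The corresponding kwargs will be the "_".join(key) version of the key, i.e., the ("group0", "agent0", "obs") key will correspond to the "group0_agent0_obs" keyword argument. Colliding keys (e.g., ("group0_agent0", "obs") and ("group0", "agent0_obs")) may lead to undefined behavior and should be avoided at all costs. Obviously, key names should also always produce valid keyword arguments, i.e., they should not contain special characters such as spaces or commas.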
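This naming convention can be checked ahead of export with a few lines of plain Python. The helper below is hypothetical (it is not a tensordict or TorchRL function): it applies the "_".join(key) rule described above to a list of nested keys and raises an error if two keys map to the same keyword argument, or if a resulting name is not a valid Python identifier.

```python
def nested_keys_to_kwargs(keys):
    """Hypothetical helper: map nested keys to keyword-argument names.

    String keys are kept as-is, tuple keys are joined with "_". Raises
    ValueError on collisions or on names that are not valid identifiers.
    """
    mapping = {}
    for key in keys:
        name = key if isinstance(key, str) else "_".join(key)
        if not name.isidentifier():
            raise ValueError(f"{key!r} -> {name!r} is not a valid keyword argument")
        if name in mapping:
            raise ValueError(f"{key!r} collides with {mapping[name]!r} as {name!r}")
        mapping[name] = key
    return mapping


print(nested_keys_to_kwargs([("group0", "agent0", "obs"), "pixels"]))
# {'group0_agent0_obs': ('group0', 'agent0', 'obs'), 'pixels': 'pixels'}

try:
    nested_keys_to_kwargs([("group0_agent0", "obs"), ("group0", "agent0_obs")])
except ValueError as err:
    print("Collision detected:", err)
```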

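torch.export has many other features, which we will explore further below. Before that, let us make a small digression on exploration and stochastic policies in the context of test-time inference, as well as on recurrent policies.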

Working with stochastic policies

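As you have probably noticed, above we used the set_exploration_type context manager to control the behavior of the policy. If the policy is stochastic (e.g., it outputs a distribution over the action space, as is the case in PPO or other similar on-policy algorithms) or explorative (with an exploration module appended, such as E-Greedy, additive Gaussian or Ornstein-Uhlenbeck), we may or may not want to use that exploration strategy in the exported version. Fortunately, the export utilities understand that context manager, and as long as the export happens within the right context manager, the behavior of the policy should match what is indicated. To demonstrate this, let us try another exploration type: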

with set_exploration_type("RANDOM"):
    exported_stochastic_policy = torch.export.export(
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"pixels": pixels},
        strict=False,
    )

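Unlike the previous version, our exported policy should now have a random module at the end of the call stack. Indeed, the last operations are: drawing a uniform random number and comparing it to eps to build a random mask, generating a random integer in [0, 6) (i.e., between 0 and 5), and selecting either the network output or the random action based on the value of the mask.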

print("Stochastic policy")
exported_stochastic_policy.graph_module.print_readable()
Stochastic policy
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_module_0_module_0_module_0_0_weight: "f32[32, 1, 8, 8]", p_module_1_module_0_module_0_module_0_0_bias: "f32[32]", p_module_1_module_0_module_0_module_0_2_weight: "f32[64, 32, 4, 4]", p_module_1_module_0_module_0_module_0_2_bias: "f32[64]", p_module_1_module_0_module_0_module_0_4_weight: "f32[64, 64, 3, 3]", p_module_1_module_0_module_0_module_0_4_bias: "f32[64]", p_module_1_module_0_module_0_module_1_0_weight: "f32[512, 3136]", p_module_1_module_0_module_0_module_1_0_bias: "f32[512]", p_module_1_module_0_module_0_module_1_2_weight: "f32[6, 512]", p_module_1_module_0_module_0_module_1_2_bias: "f32[6]", b_module_1_module_1_eps_init: "f32[]", b_module_1_module_1_eps_end: "f32[]", b_module_1_module_1_eps: "f32[]", kwargs_pixels: "u8[210, 160, 3]"):
         # File: /pytorch/rl/torchrl/envs/transforms/transforms.py:308 in forward, code: data = self._apply_transform(data)
        permute: "u8[3, 210, 160]" = torch.ops.aten.permute.default(kwargs_pixels, [-1, -3, -2]);  kwargs_pixels = None
        div: "f32[3, 210, 160]" = torch.ops.aten.div.Tensor(permute, 255);  permute = None
        to: "f32[3, 210, 160]" = torch.ops.aten.to.dtype(div, torch.float32);  div = None
        unsqueeze: "f32[1, 3, 210, 160]" = torch.ops.aten.unsqueeze.default(to, 0);  to = None
        upsample_nearest2d: "f32[1, 3, 84, 84]" = torch.ops.aten.upsample_nearest2d.vec(unsqueeze, [84, 84], None);  unsqueeze = None
        squeeze: "f32[3, 84, 84]" = torch.ops.aten.squeeze.dim(upsample_nearest2d, 0);  upsample_nearest2d = None
        unbind = torch.ops.aten.unbind.int(squeeze, -3);  squeeze = None
        getitem: "f32[84, 84]" = unbind[0]
        getitem_1: "f32[84, 84]" = unbind[1]
        getitem_2: "f32[84, 84]" = unbind[2];  unbind = None
        mul: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem, 0.2989);  getitem = None
        mul_1: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_1, 0.587);  getitem_1 = None
        add: "f32[84, 84]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        mul_2: "f32[84, 84]" = torch.ops.aten.mul.Tensor(getitem_2, 0.114);  getitem_2 = None
        add_1: "f32[84, 84]" = torch.ops.aten.add.Tensor(add, mul_2);  add = mul_2 = None
        to_1: "f32[84, 84]" = torch.ops.aten.to.dtype(add_1, torch.float32);  add_1 = None
        unsqueeze_1: "f32[1, 84, 84]" = torch.ops.aten.unsqueeze.default(to_1, -3);  to_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[32, 20, 20]" = torch.ops.aten.conv2d.default(unsqueeze_1, p_module_1_module_0_module_0_module_0_0_weight, p_module_1_module_0_module_0_module_0_0_bias, [4, 4]);  unsqueeze_1 = p_module_1_module_0_module_0_module_0_0_weight = p_module_1_module_0_module_0_module_0_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[32, 20, 20]" = torch.ops.aten.relu.default(conv2d);  conv2d = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_1: "f32[64, 9, 9]" = torch.ops.aten.conv2d.default(relu, p_module_1_module_0_module_0_module_0_2_weight, p_module_1_module_0_module_0_module_0_2_bias, [2, 2]);  relu = p_module_1_module_0_module_0_module_0_2_weight = p_module_1_module_0_module_0_module_0_2_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_1: "f32[64, 9, 9]" = torch.ops.aten.relu.default(conv2d_1);  conv2d_1 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/conv.py:554 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d_2: "f32[64, 7, 7]" = torch.ops.aten.conv2d.default(relu_1, p_module_1_module_0_module_0_module_0_4_weight, p_module_1_module_0_module_0_module_0_4_bias);  relu_1 = p_module_1_module_0_module_0_module_0_4_weight = p_module_1_module_0_module_0_module_0_4_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_2: "f32[64, 7, 7]" = torch.ops.aten.relu.default(conv2d_2);  conv2d_2 = None

         # File: /pytorch/rl/torchrl/modules/models/utils.py:86 in forward, code: value = value.flatten(-self.ndims_in, -1)
        flatten: "f32[3136]" = torch.ops.aten.flatten.using_ints(relu_2, -3);  relu_2 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[512]" = torch.ops.aten.linear.default(flatten, p_module_1_module_0_module_0_module_1_0_weight, p_module_1_module_0_module_0_module_1_0_bias);  flatten = p_module_1_module_0_module_0_module_1_0_weight = p_module_1_module_0_module_0_module_1_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:133 in forward, code: return F.relu(input, inplace=self.inplace)
        relu_3: "f32[512]" = torch.ops.aten.relu.default(linear);  linear = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_1: "f32[6]" = torch.ops.aten.linear.default(relu_3, p_module_1_module_0_module_0_module_1_2_weight, p_module_1_module_0_module_0_module_1_2_bias);  relu_3 = p_module_1_module_0_module_0_module_1_2_weight = p_module_1_module_0_module_0_module_1_2_bias = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:616 in forward, code: action = self.action_func_mapping[self.action_space](action_values)
        argmax: "i64[]" = torch.ops.aten.argmax.default(linear_1, -1)
        to_2: "i64[]" = torch.ops.aten.to.dtype(argmax, torch.int64);  argmax = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/actors.py:621 in forward, code: chosen_action_value = action_value_func(action_values, action)
        unsqueeze_2: "i64[1]" = torch.ops.aten.unsqueeze.default(to_2, -1)
        gather: "f32[1]" = torch.ops.aten.gather.default(linear_1, -1, unsqueeze_2);  linear_1 = unsqueeze_2 = gather = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:167 in forward, code: cond = torch.rand(action_tensordict.shape, device=out.device) < eps
        rand: "f32[]" = torch.ops.aten.rand.default([], device = device(type='cpu'), pin_memory = False)
        lt: "b8[]" = torch.ops.aten.lt.Tensor(rand, b_module_1_module_1_eps);  rand = b_module_1_module_1_eps = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:169 in forward, code: cond = expand_as_right(cond, out)
        expand: "b8[]" = torch.ops.aten.expand.default(lt, []);  lt = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:193 in forward, code: r = spec.rand()
        randint: "i64[]" = torch.ops.aten.randint.low(0, 6, [], device = device(type='cpu'), pin_memory = False)

         # File: /pytorch/rl/torchrl/modules/tensordict_module/exploration.py:196 in forward, code: out = torch.where(cond, r, out)
        where: "i64[]" = torch.ops.aten.where.self(expand, randint, to_2);  expand = randint = to_2 = None
        return (where,)



Working with recurrent policies

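Another typical use case is a recurrent policy, which outputs an action as well as one or more recurrent states. LSTM and GRU are CuDNN-based modules, which means that they behave differently from regular Module instances (the export utilities may not trace them well). Fortunately, TorchRL provides a Python implementation of these modules that can be swapped in for the CuDNN version when needed.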

To illustrate this, let us write a prototypical policy that relies on an RNN:

from tensordict.nn import TensorDictModule
from torchrl.envs import BatchSizeTransform
from torchrl.modules import LSTMModule, MLP

lstm = LSTMModule(
    input_size=32,
    num_layers=2,
    hidden_size=256,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", "hidden0", "hidden1"],
)

If the LSTM module is not Python-based but relies on CuDNN (LSTM), the make_python_based() method can be used to switch to the Python version.

lstm = lstm.make_python_based()
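For intuition about what the Python-based implementation computes at each step, here is a minimal NumPy sketch of one time step of a stacked LSTM. It assumes PyTorch's gate ordering (input, forget, cell candidate, output) and a (num_layers, batch, hidden_size) state layout, with hidden0 playing the role of h and hidden1 of c; the function names are hypothetical and this is not TorchRL's code.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_step(x, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    """One time step of a single LSTM layer.

    x: (batch, in_features); h_prev, c_prev: (batch, hidden_size)
    w_ih: (4 * hidden_size, in_features); w_hh: (4 * hidden_size, hidden_size)
    """
    # Input-to-hidden and hidden-to-hidden affine maps, summed
    gates = x @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh
    # Split into input, forget, cell-candidate and output gates (PyTorch ordering)
    i, f, g, o = np.split(gates, 4, axis=-1)
    i, f, g, o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)
    c = f * c_prev + i * g  # new cell state
    h = o * np.tanh(c)  # new hidden state
    return h, c


def lstm_step(x, hidden, cell, params):
    """One time step of a stacked LSTM.

    hidden, cell: (num_layers, batch, hidden_size)
    params: one (w_ih, w_hh, b_ih, b_hh) tuple per layer
    """
    new_h, new_c = [], []
    layer_input = x
    for layer, (w_ih, w_hh, b_ih, b_hh) in enumerate(params):
        h, c = lstm_cell_step(layer_input, hidden[layer], cell[layer], w_ih, w_hh, b_ih, b_hh)
        new_h.append(h)
        new_c.append(c)
        layer_input = h  # each layer feeds its hidden state to the next one
    return layer_input, np.stack(new_h), np.stack(new_c)
```

Because every operation here is an ordinary matrix product or elementwise function, a tracer can follow it step by step, which is the property that makes the Python-based module easier to export than a fused CuDNN kernel.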

Let us now create the policy. We combine two layers that modify the shape of the input (unsqueeze/squeeze operations) with the LSTM and an MLP.

recurrent_policy = TensorDictSequential(
    # Unsqueeze the first dim of all tensors to make LSTMCell happy
    BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
    lstm,
    TensorDictModule(
        MLP(in_features=256, out_features=5, num_cells=[64, 64]),
        in_keys=["intermediate"],
        out_keys=["action"],
    ),
    # Squeeze the first dim of all tensors to get the original shape back
    BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
)
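The two BatchSizeTransform layers let the LSTM see a leading batch dimension while the policy itself accepts unbatched inputs. Conceptually they behave like the hypothetical NumPy helper below, which adds a size-1 dimension to every entry of a dictionary, calls a batched function, and removes the dimension again. This is a sketch of the idea, not the BatchSizeTransform implementation.

```python
from typing import Callable, Dict

import numpy as np

Arrays = Dict[str, np.ndarray]


def call_with_batch_dim(fn: Callable[[Arrays], Arrays], inputs: Arrays) -> Arrays:
    # Like the first BatchSizeTransform: prepend a batch dimension of size 1
    batched = {key: np.expand_dims(value, 0) for key, value in inputs.items()}
    outputs = fn(batched)
    # Like the last BatchSizeTransform: drop that leading dimension again
    return {key: np.squeeze(value, 0) for key, value in outputs.items()}


def toy_batched_module(inputs: Arrays) -> Arrays:
    # Hypothetical module that only accepts (batch, features) inputs
    obs = inputs["observation"]
    assert obs.ndim == 2, "expected a batch dimension"
    return {"action": obs[:, :5] * 2.0}


out = call_with_batch_dim(toy_batched_module, {"observation": np.arange(32.0)})
print(out["action"].shape)  # (5,)
```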

As before, we select the relevant keys:

recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
print("recurrent policy input keys:", recurrent_policy.in_keys)
print("recurrent policy output keys:", recurrent_policy.out_keys)
recurrent policy input keys: ['observation', 'hidden0', 'hidden1', 'is_init']
recurrent policy output keys: ['action', 'hidden0', 'hidden1']
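Note that is_init appears among the input keys even though we never listed it: the LSTM module uses it to reset the recurrent state at the start of a trajectory, and in the exported graph this appears as torch.where(is_init_expand, 0, hidden0). A minimal NumPy sketch of that reset, assuming the state has shape batch_shape + (num_layers, hidden_size); reset_hidden is a hypothetical name:

```python
import numpy as np


def reset_hidden(is_init: np.ndarray, hidden: np.ndarray) -> np.ndarray:
    """Zero the recurrent state wherever a new trajectory begins.

    is_init: boolean array with the batch shape (possibly 0-d)
    hidden: array of shape batch_shape + (num_layers, hidden_size)
    """
    is_init = np.asarray(is_init, dtype=bool)
    # Append singleton dims on the right so the flag broadcasts over the state
    mask = is_init.reshape(is_init.shape + (1,) * (hidden.ndim - is_init.ndim))
    return np.where(mask, 0.0, hidden)
```

With the unbatched fake inputs used below, is_init is a 0-d boolean and hidden0 has shape (2, 256), so the whole state is either kept or zeroed.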

We are now ready to export. To do this, we build fake inputs and pass them to export():

fake_obs = torch.randn(32)
fake_hidden0 = torch.randn(2, 256)
fake_hidden1 = torch.randn(2, 256)

# Tensor indicating whether the state is the first of a sequence
fake_is_init = torch.zeros((), dtype=torch.bool)

exported_recurrent_policy = torch.export.export(
    recurrent_policy,
    args=(),
    kwargs={
        "observation": fake_obs,
        "hidden0": fake_hidden0,
        "hidden1": fake_hidden1,
        "is_init": fake_is_init,
    },
    strict=False,
)
print("Recurrent policy graph:")
exported_recurrent_policy.graph_module.print_readable()
Recurrent policy graph:
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_lstm_weight_ih_l0: "f32[1024, 32]", p_module_1_lstm_weight_hh_l0: "f32[1024, 256]", p_module_1_lstm_bias_ih_l0: "f32[1024]", p_module_1_lstm_bias_hh_l0: "f32[1024]", p_module_1_lstm_weight_ih_l1: "f32[1024, 256]", p_module_1_lstm_weight_hh_l1: "f32[1024, 256]", p_module_1_lstm_bias_ih_l1: "f32[1024]", p_module_1_lstm_bias_hh_l1: "f32[1024]", p_module_2_module_0_weight: "f32[64, 256]", p_module_2_module_0_bias: "f32[64]", p_module_2_module_2_weight: "f32[64, 64]", p_module_2_module_2_bias: "f32[64]", p_module_2_module_4_weight: "f32[5, 64]", p_module_2_module_4_bias: "f32[5]", kwargs_observation: "f32[32]", kwargs_hidden0: "f32[2, 256]", kwargs_hidden1: "f32[2, 256]", kwargs_is_init: "b8[]"):
         # File: /pytorch/rl/env/lib/python3.10/site-packages/tensordict/nn/sequence.py:540 in forward, code: tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
        unsqueeze: "f32[1, 32]" = torch.ops.aten.unsqueeze.default(kwargs_observation, 0);  kwargs_observation = None
        unsqueeze_1: "f32[1, 2, 256]" = torch.ops.aten.unsqueeze.default(kwargs_hidden0, 0);  kwargs_hidden0 = None
        unsqueeze_2: "f32[1, 2, 256]" = torch.ops.aten.unsqueeze.default(kwargs_hidden1, 0);  kwargs_hidden1 = None
        unsqueeze_3: "b8[1]" = torch.ops.aten.unsqueeze.default(kwargs_is_init, 0);  kwargs_is_init = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:743 in forward, code: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)
        unsqueeze_4: "f32[1, 1, 32]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
        unsqueeze_5: "f32[1, 1, 2, 256]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
        unsqueeze_6: "f32[1, 1, 2, 256]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 1);  unsqueeze_2 = None
        unsqueeze_7: "b8[1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 1);  unsqueeze_3 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:745 in forward, code: is_init = tensordict_shaped["is_init"].squeeze(-1)
        squeeze: "b8[1]" = torch.ops.aten.squeeze.dim(unsqueeze_7, -1)

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:772 in forward, code: is_init_expand = expand_as_right(is_init, hidden0)
        unsqueeze_8: "b8[1, 1]" = torch.ops.aten.unsqueeze.default(squeeze, -1);  squeeze = None
        unsqueeze_9: "b8[1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_8, -1);  unsqueeze_8 = None
        unsqueeze_10: "b8[1, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_9, -1);  unsqueeze_9 = None
        expand: "b8[1, 1, 2, 256]" = torch.ops.aten.expand.default(unsqueeze_10, [1, 1, 2, 256]);  unsqueeze_10 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:773 in forward, code: hidden0 = torch.where(is_init_expand, 0, hidden0)
        where: "f32[1, 1, 2, 256]" = torch.ops.aten.where.ScalarSelf(expand, 0, unsqueeze_5);  unsqueeze_5 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:774 in forward, code: hidden1 = torch.where(is_init_expand, 0, hidden1)
        where_1: "f32[1, 1, 2, 256]" = torch.ops.aten.where.ScalarSelf(expand, 0, unsqueeze_6);  expand = unsqueeze_6 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:775 in forward, code: val, hidden0, hidden1 = self._lstm(
        slice_1: "f32[1, 1, 2, 256]" = torch.ops.aten.slice.Tensor(where, 0, 0, 9223372036854775807);  where = None
        select: "f32[1, 2, 256]" = torch.ops.aten.select.int(slice_1, 1, 0);  slice_1 = None
        slice_2: "f32[1, 1, 2, 256]" = torch.ops.aten.slice.Tensor(where_1, 0, 0, 9223372036854775807);  where_1 = None
        select_1: "f32[1, 2, 256]" = torch.ops.aten.select.int(slice_2, 1, 0);  slice_2 = None
        transpose: "f32[2, 1, 256]" = torch.ops.aten.transpose.int(select, -3, -2);  select = None
        transpose_1: "f32[2, 1, 256]" = torch.ops.aten.transpose.int(select_1, -3, -2);  select_1 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:317 in forward, code: return self._lstm(input, hx)
        unbind = torch.ops.aten.unbind.int(transpose);  transpose = None
        getitem: "f32[1, 256]" = unbind[0]
        getitem_1: "f32[1, 256]" = unbind[1];  unbind = None
        unbind_1 = torch.ops.aten.unbind.int(transpose_1);  transpose_1 = None
        getitem_2: "f32[1, 256]" = unbind_1[0]
        getitem_3: "f32[1, 256]" = unbind_1[1];  unbind_1 = None
        unbind_2 = torch.ops.aten.unbind.int(unsqueeze_4, 1)
        getitem_4: "f32[1, 32]" = unbind_2[0];  unbind_2 = None
        linear: "f32[1, 1024]" = torch.ops.aten.linear.default(getitem_4, p_module_1_lstm_weight_ih_l0, p_module_1_lstm_bias_ih_l0);  getitem_4 = p_module_1_lstm_weight_ih_l0 = p_module_1_lstm_bias_ih_l0 = None
        linear_1: "f32[1, 1024]" = torch.ops.aten.linear.default(getitem, p_module_1_lstm_weight_hh_l0, p_module_1_lstm_bias_hh_l0);  getitem = p_module_1_lstm_weight_hh_l0 = p_module_1_lstm_bias_hh_l0 = None
        add: "f32[1, 1024]" = torch.ops.aten.add.Tensor(linear, linear_1);  linear = linear_1 = None
        chunk = torch.ops.aten.chunk.default(add, 4, 1);  add = None
        getitem_5: "f32[1, 256]" = chunk[0]
        getitem_6: "f32[1, 256]" = chunk[1]
        getitem_7: "f32[1, 256]" = chunk[2]
        getitem_8: "f32[1, 256]" = chunk[3];  chunk = None
        sigmoid: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_5);  getitem_5 = None
        sigmoid_1: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_6);  getitem_6 = None
        tanh: "f32[1, 256]" = torch.ops.aten.tanh.default(getitem_7);  getitem_7 = None
        sigmoid_2: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_8);  getitem_8 = None
        mul: "f32[1, 256]" = torch.ops.aten.mul.Tensor(getitem_2, sigmoid_1);  getitem_2 = sigmoid_1 = None
        mul_1: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid, tanh);  sigmoid = tanh = None
        add_1: "f32[1, 256]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        tanh_1: "f32[1, 256]" = torch.ops.aten.tanh.default(add_1)
        mul_2: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid_2, tanh_1);  sigmoid_2 = tanh_1 = None
        linear_2: "f32[1, 1024]" = torch.ops.aten.linear.default(mul_2, p_module_1_lstm_weight_ih_l1, p_module_1_lstm_bias_ih_l1);  p_module_1_lstm_weight_ih_l1 = p_module_1_lstm_bias_ih_l1 = None
        linear_3: "f32[1, 1024]" = torch.ops.aten.linear.default(getitem_1, p_module_1_lstm_weight_hh_l1, p_module_1_lstm_bias_hh_l1);  getitem_1 = p_module_1_lstm_weight_hh_l1 = p_module_1_lstm_bias_hh_l1 = None
        add_2: "f32[1, 1024]" = torch.ops.aten.add.Tensor(linear_2, linear_3);  linear_2 = linear_3 = None
        chunk_1 = torch.ops.aten.chunk.default(add_2, 4, 1);  add_2 = None
        getitem_9: "f32[1, 256]" = chunk_1[0]
        getitem_10: "f32[1, 256]" = chunk_1[1]
        getitem_11: "f32[1, 256]" = chunk_1[2]
        getitem_12: "f32[1, 256]" = chunk_1[3];  chunk_1 = None
        sigmoid_3: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_9);  getitem_9 = None
        sigmoid_4: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_10);  getitem_10 = None
        tanh_2: "f32[1, 256]" = torch.ops.aten.tanh.default(getitem_11);  getitem_11 = None
        sigmoid_5: "f32[1, 256]" = torch.ops.aten.sigmoid.default(getitem_12);  getitem_12 = None
        mul_3: "f32[1, 256]" = torch.ops.aten.mul.Tensor(getitem_3, sigmoid_4);  getitem_3 = sigmoid_4 = None
        mul_4: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid_3, tanh_2);  sigmoid_3 = tanh_2 = None
        add_3: "f32[1, 256]" = torch.ops.aten.add.Tensor(mul_3, mul_4);  mul_3 = mul_4 = None
        tanh_3: "f32[1, 256]" = torch.ops.aten.tanh.default(add_3)
        mul_5: "f32[1, 256]" = torch.ops.aten.mul.Tensor(sigmoid_5, tanh_3);  sigmoid_5 = tanh_3 = None
        stack: "f32[1, 1, 256]" = torch.ops.aten.stack.default([mul_5], 1)
        stack_1: "f32[2, 1, 256]" = torch.ops.aten.stack.default([mul_2, mul_5]);  mul_2 = mul_5 = None
        stack_2: "f32[2, 1, 256]" = torch.ops.aten.stack.default([add_1, add_3]);  add_1 = add_3 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:775 in forward, code: val, hidden0, hidden1 = self._lstm(
        transpose_2: "f32[1, 2, 256]" = torch.ops.aten.transpose.int(stack_1, 0, 1);  stack_1 = None
        transpose_3: "f32[1, 2, 256]" = torch.ops.aten.transpose.int(stack_2, 0, 1);  stack_2 = None
        stack_3: "f32[1, 1, 2, 256]" = torch.ops.aten.stack.default([transpose_2], 1);  transpose_2 = None
        stack_4: "f32[1, 1, 2, 256]" = torch.ops.aten.stack.default([transpose_3], 1);  transpose_3 = None

         # File: /pytorch/rl/torchrl/modules/tensordict_module/rnn.py:788 in forward, code: tensordict.update(tensordict_shaped.reshape(shape))
        reshape: "f32[1, 32]" = torch.ops.aten.reshape.default(unsqueeze_4, [1, 32]);  unsqueeze_4 = None
        reshape_1: "f32[1, 2, 256]" = torch.ops.aten.reshape.default(stack_3, [1, 2, 256]);  stack_3 = None
        reshape_2: "f32[1, 2, 256]" = torch.ops.aten.reshape.default(stack_4, [1, 2, 256]);  stack_4 = None
        reshape_3: "b8[1]" = torch.ops.aten.reshape.default(unsqueeze_7, [1]);  unsqueeze_7 = None
        reshape_4: "f32[1, 256]" = torch.ops.aten.reshape.default(stack, [1, 256]);  stack = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_4: "f32[1, 64]" = torch.ops.aten.linear.default(reshape_4, p_module_2_module_0_weight, p_module_2_module_0_bias);  p_module_2_module_0_weight = p_module_2_module_0_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:392 in forward, code: return torch.tanh(input)
        tanh_4: "f32[1, 64]" = torch.ops.aten.tanh.default(linear_4);  linear_4 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_5: "f32[1, 64]" = torch.ops.aten.linear.default(tanh_4, p_module_2_module_2_weight, p_module_2_module_2_bias);  tanh_4 = p_module_2_module_2_weight = p_module_2_module_2_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:392 in forward, code: return torch.tanh(input)
        tanh_5: "f32[1, 64]" = torch.ops.aten.tanh.default(linear_5);  linear_5 = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear_6: "f32[1, 5]" = torch.ops.aten.linear.default(tanh_5, p_module_2_module_4_weight, p_module_2_module_4_bias);  tanh_5 = p_module_2_module_4_weight = p_module_2_module_4_bias = None

         # File: /pytorch/rl/env/lib/python3.10/site-packages/tensordict/nn/sequence.py:540 in forward, code: tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
        squeeze_1: "f32[32]" = torch.ops.aten.squeeze.dim(reshape, 0);  reshape = squeeze_1 = None
        squeeze_2: "f32[2, 256]" = torch.ops.aten.squeeze.dim(reshape_1, 0);  reshape_1 = None
        squeeze_3: "f32[2, 256]" = torch.ops.aten.squeeze.dim(reshape_2, 0);  reshape_2 = None
        squeeze_4: "b8[]" = torch.ops.aten.squeeze.dim(reshape_3, 0);  reshape_3 = squeeze_4 = None
        squeeze_5: "f32[256]" = torch.ops.aten.squeeze.dim(reshape_4, 0);  reshape_4 = squeeze_5 = None
        squeeze_6: "f32[5]" = torch.ops.aten.squeeze.dim(linear_6, 0);  linear_6 = None
        return (squeeze_6, squeeze_2, squeeze_3)
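The parameter shapes in the graph signature (for example f32[1024, 32] for p_module_1_lstm_weight_ih_l0) follow directly from the LSTM configuration: each weight stacks the four gates, hence 4 * hidden_size = 1024 rows, and only the first layer consumes the 32-dimensional observation. A small pure-Python helper (hypothetical, only for checking the arithmetic) reproduces them:

```python
def lstm_param_shapes(input_size: int, hidden_size: int, num_layers: int) -> dict:
    """Expected weight/bias shapes of a stacked LSTM (PyTorch naming convention)."""
    shapes = {}
    for layer in range(num_layers):
        # Layer 0 consumes the observation; deeper layers consume the previous layer's h
        in_features = input_size if layer == 0 else hidden_size
        shapes[f"weight_ih_l{layer}"] = (4 * hidden_size, in_features)
        shapes[f"weight_hh_l{layer}"] = (4 * hidden_size, hidden_size)
        shapes[f"bias_ih_l{layer}"] = (4 * hidden_size,)
        shapes[f"bias_hh_l{layer}"] = (4 * hidden_size,)
    return shapes


shapes = lstm_param_shapes(input_size=32, hidden_size=256, num_layers=2)
for name, shape in shapes.items():
    print(name, shape)
```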


AOTInductor: Export your policy to PyTorch-free C++ binaries

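AOTInductor is a PyTorch module that allows you to export your model (policy or other) to PyTorch-free C++ binaries. This is particularly useful when you need to deploy your model on devices or platforms where PyTorch is not available.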

Here is an example of how you can use AOTInductor to export your policy, inspired by the AOTI documentation:

from tempfile import TemporaryDirectory

from torch._inductor import aoti_compile_and_package, aoti_load_package

with TemporaryDirectory() as tmpdir:
    path = str(Path(tmpdir) / "model.pt2")
    with torch.no_grad():
        pkg_path = aoti_compile_and_package(
            exported_policy,
            # Specify the generated shared library path
            package_path=path,
        )
    print("pkg_path", pkg_path)

    compiled_module = aoti_load_package(pkg_path)

print(compiled_module(pixels=pixels))

Exporting TorchRL models with ONNX

Note

To execute this part of the script, make sure the PyTorch ONNX dependencies are installed:

!pip install onnx-pytorch
!pip install onnxruntime

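You can also find more information about using ONNX in the PyTorch ecosystem here. The following example is based on this documentation.

In this section, we show how to export our model so that it can be executed in a PyTorch-free setting.

There are plenty of resources on the web explaining how ONNX can be used to deploy models on various hardware and devices, including Raspberry Pi, NVIDIA TensorRT, iOS and Android.

The Atari game we trained on can be run in isolation, without TorchRL or gymnasium, using the ALE library alone. It therefore gives us a good example of what we can achieve with ONNX.

Let us see what this API looks like: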

from ale_py import ALEInterface, roms

# Create the interface
ale = ALEInterface()
# Load the pong environment
ale.loadROM(roms.Pong)
ale.reset_game()

# Make a step in the simulator
action = 0
reward = ale.act(action)
screen_obs = ale.getScreenRGB()
print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)

from matplotlib import pyplot as plt

plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.imshow(screen_obs)
plt.title("Screen rendering of Pong game.")
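The policy we are about to export is fed raw screen_obs frames (uint8, height x width x 3), so presumably the image transforms from the environment (ToTensorImage, Resize(84, interpolation="nearest"), GrayScale) are part of the exported program. For intuition, this NumPy sketch approximates what those transforms do to one frame. It is an illustration under those assumptions, not the TorchRL implementation; the nearest-neighbour indexing rule and the standard 0.299/0.587/0.114 luma weights are the assumed conventions, and preprocess_frame is a hypothetical name.

```python
import numpy as np


def preprocess_frame(screen: np.ndarray, size: int = 84) -> np.ndarray:
    """Approximate ToTensorImage -> Resize(size, nearest) -> GrayScale on one frame.

    screen: uint8 array of shape (H, W, 3). Returns float32 (1, size, size) in [0, 1].
    """
    # ToTensorImage: HWC uint8 -> CHW float in [0, 1]
    img = screen.astype(np.float32).transpose(2, 0, 1) / 255.0
    height, width = img.shape[1:]
    # Nearest-neighbour resize: source index = floor(dst_index * in_size / out_size)
    rows = np.floor(np.arange(size) * (height / size)).astype(np.int64)
    cols = np.floor(np.arange(size) * (width / size)).astype(np.int64)
    img = img[:, rows][:, :, cols]
    # GrayScale: weighted sum of the R, G and B channels (standard luma weights)
    gray = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
    return gray[None].astype(np.float32)


frame = np.random.default_rng(0).integers(0, 256, size=(210, 160, 3), dtype=np.uint8)
print(preprocess_frame(frame).shape)  # (1, 84, 84)
```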

Exporting to ONNX is quite similar to the Export/AOTI workflow above:

import onnxruntime

with set_exploration_type("DETERMINISTIC"):
    # Use torch.onnx.dynamo_export to capture the computation graph of our policy_transform module
    pixels = torch.as_tensor(screen_obs)
    onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)
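Wrapping the export in set_exploration_type("DETERMINISTIC") matters: under that setting TorchRL's EGreedyModule does not alter the action, so if the exported policy contains an exploration module, the graph should end with a plain argmax over the action values rather than a random branch (draw a uniform number, compare it with epsilon, and possibly substitute a random action). The difference can be sketched in NumPy as follows; both helpers are hypothetical, and eps stands in for the exploration module's current epsilon.

```python
import numpy as np


def greedy_action(action_values: np.ndarray) -> int:
    # DETERMINISTIC exploration: pick the action with the highest value
    return int(np.argmax(action_values, axis=-1))


def epsilon_greedy_action(action_values: np.ndarray, eps: float, rng: np.random.Generator) -> int:
    # With probability eps, replace the greedy action by a uniformly random one
    if rng.random() < eps:
        return int(rng.integers(0, action_values.shape[-1]))
    return greedy_action(action_values)


q_values = np.array([0.1, 2.0, -1.0, 0.5])
print(greedy_action(q_values))  # 1
```

Exporting the deterministic variant keeps the deployed graph free of random-number generation, which also makes the ONNX outputs reproducible.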

We can now save the program to disk and load it:

with TemporaryDirectory() as tmpdir:
    onnx_file_path = str(Path(tmpdir) / "policy.onnx")
    onnx_policy_export.save(onnx_file_path)

    ort_session = onnxruntime.InferenceSession(
        onnx_file_path, providers=["CPUExecutionProvider"]
    )

onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
onnx_policy = ort_session.run(None, onnxruntime_input)

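Running a rollout with ONNX

We now have an ONNX model that runs our policy. Let us compare it with the original TorchRL instance: because it is more lightweight, the ONNX version should be faster than the TorchRL one.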

def onnx_policy(screen_obs: np.ndarray) -> int:
    onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
    action = int(onnxruntime_outputs[0])
    return action


with timeit("ONNX rollout"):
    num_steps = 1000
    ale.reset_game()
    for _ in range(num_steps):
        screen_obs = ale.getScreenRGB()
        action = onnx_policy(screen_obs)
        reward = ale.act(action)

with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
    env.rollout(num_steps, policy_explore)

print(timeit.print())
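One caveat when reading these timers: env.rollout may stop early when an episode ends, whereas the ALE loop always runs num_steps iterations, so comparing throughput (steps per second) is fairer than comparing raw totals. A stdlib-only sketch of such a timer follows; timed and steps_per_second are hypothetical names and this is not torchrl._utils.timeit.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, registry: dict):
    # Accumulate wall-clock seconds spent inside the block under `label`
    start = time.perf_counter()
    try:
        yield
    finally:
        registry[label] = registry.get(label, 0.0) + time.perf_counter() - start


def steps_per_second(registry: dict, label: str, num_steps: int) -> float:
    # Normalize by the number of steps actually executed
    return num_steps / registry[label]


timings = {}
num_steps = 10_000
with timed("toy rollout", timings):
    total = sum(i * i for i in range(num_steps))
print(f"{steps_per_second(timings, 'toy rollout', num_steps):.0f} steps/s")
```

For the TorchRL side, the number of steps actually taken can be read from the rollout's batch size instead of assuming num_steps.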

Note that ONNX also offers the possibility of optimizing models directly, but this is beyond the scope of this tutorial.

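Conclusion

In this tutorial, we learned how to export TorchRL modules using various backends, such as PyTorch's built-in export functionality, AOTInductor and ONNX. We demonstrated how to export a policy trained on an Atari game and run it in a PyTorch-free setting using the ALE library. We also compared the performance of the original TorchRL instance with that of the exported ONNX model.

Key takeaways:

  • Exporting TorchRL modules allows deployment on devices without PyTorch installed.

  • AOTInductor and ONNX provide alternative backends for exporting models.

  • Optimizing ONNX models can improve performance.

Further reading and learning steps:

  • Check out the official documentation of PyTorch's export functionality, AOTInductor and ONNX for more information.

  • Experiment with deploying exported models on different devices.

  • Explore optimization techniques for ONNX models to improve performance.

Total running time of the script: (0 minutes 58.085 seconds)

Estimated memory usage: 4641 MB

Gallery generated by Sphinx-Gallery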
