Introduction to TorchRL
This demo was presented at ICML 2022 at the industry demo day.
It gives a good overview of TorchRL's functionality. Feel free to reach out to vmoens@fb.com or submit issues if you have questions or comments about it.
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
The PyTorch ecosystem team (Meta) has decided to invest in that library to provide a leading platform to develop RL solutions in research settings.
It provides PyTorch- and Python-first abstractions for RL, at both low and high levels, that are intended to be efficient, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them or write new ones with little effort.
This repo attempts to align with the existing PyTorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), etc. TorchRL aims to have as few dependencies as possible (Python standard library, NumPy and PyTorch). Common environment libraries (e.g. OpenAI Gym) are optional.
Unlike other domains, RL is less about media than algorithms. As such, it is harder to make truly independent components.
What TorchRL is not:
- a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms; the algorithms we include serve only as examples of how to use the library.
- a research framework: modularity in TorchRL comes in two flavours. First, we try to build reusable components that can be easily swapped with each other. Second, we do our best to ensure that components can be used independently of the rest of the library.
TorchRL has very few core dependencies, predominantly PyTorch and numpy. All other dependencies (gym, torchvision, wandb / tensorboard) are optional.
Data
TensorDict
import torch
from tensordict import TensorDict
Let’s create a TensorDict.
batch_size = 5
tensordict = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 3),
        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
print(tensordict)
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)
You can index a TensorDict as well as query keys.
print(tensordict[2])
print(tensordict["key 1"] is tensordict.get("key 1"))
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
True
The following shows how to stack multiple TensorDicts.
tensordict1 = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 1),
        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
tensordict2 = TensorDict(
    source={
        "key 1": torch.ones(batch_size, 1),
        "key 2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
tensordict = torch.stack([tensordict1, tensordict2], 0)
tensordict.batch_size, tensordict["key 1"]
(torch.Size([2, 5]), tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.]],
        [[1.],
         [1.],
         [1.],
         [1.],
         [1.]]]))
Here are some other functionalities of TensorDict.
print(
    "view(-1): ",
    tensordict.view(-1).batch_size,
    tensordict.view(-1).get("key 1").shape,
)
print("to device: ", tensordict.to("cpu"))
# print("pin_memory: ", tensordict.pin_memory())
print("share memory: ", tensordict.share_memory_())
print(
    "permute(1, 0): ",
    tensordict.permute(1, 0).batch_size,
    tensordict.permute(1, 0).get("key 1").shape,
)
print(
    "expand: ",
    tensordict.expand(3, *tensordict.batch_size).batch_size,
    tensordict.expand(3, *tensordict.batch_size).get("key 1").shape,
)
view(-1):  torch.Size([10]) torch.Size([10, 1])
to device:  TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([2, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([2, 5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2, 5]),
    device=cpu,
    is_shared=False)
share memory:  TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([2, 5, 1]), device=cpu, dtype=torch.float32, is_shared=True),
        key 2: Tensor(shape=torch.Size([2, 5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([2, 5]),
    device=None,
    is_shared=True)
permute(1, 0):  torch.Size([5, 2]) torch.Size([5, 2, 1])
expand:  torch.Size([3, 2, 5]) torch.Size([3, 2, 5, 1])
You can create a nested TensorDict as well.
tensordict = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 3),
        "key 2": TensorDict(
            source={"sub-key 1": torch.zeros(batch_size, 2, 1)},
            batch_size=[batch_size, 2],
        ),
    },
    batch_size=[batch_size],
)
tensordict
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: TensorDict(
            fields={
                sub-key 1: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)
Replay buffers
from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer
rb = ReplayBuffer(collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
[1]
rb.extend([2, 3])
rb.sample(3)
[2, 1, 3]
rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)
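The `alpha` and `beta` arguments above come from proportional prioritized experience replay. The following plain-Python sketch is not TorchRL's implementation; it assumes the standard formulation, where item i is drawn with probability p_i^alpha / sum_k p_k^alpha and receives an importance weight (N * P(i))^(-beta), normalized by the largest weight. The function names `sampling_probs` and `importance_weights` are hypothetical.

```python
import random


def sampling_probs(priorities, alpha):
    """Turn raw priorities into sampling probabilities (alpha=0 gives uniform sampling)."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, beta):
    """Importance-sampling weights, normalized so that the largest weight is 1."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    largest = max(raw)
    return [w / largest for w in raw]


priorities = [1.0, 0.5, 2.0]
probs = sampling_probs(priorities, alpha=0.7)
weights = importance_weights(probs, beta=1.1)

# Draw a batch of indices according to the prioritized probabilities.
rng = random.Random(0)
batch = rng.choices(range(len(priorities)), weights=probs, k=4)
print(probs, weights, batch)
```

Items with a high priority are sampled more often and, to compensate, get a smaller weight in the loss.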
Here are examples of using a replay buffer with TensorDicts.
collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)
1
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())
3
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer
rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
tensordict_sample = rb.sample(2).contiguous()
tensordict_sample
TensorDict(
    fields={
        _weight: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        index: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
tensordict_sample["index"]
tensor([0, 0])
tensordict_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(tensordict_sample)
for i, val in enumerate(rb._sampler._sum_tree):
    print(i, val)
    if i == len(rb):
        break
try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym
0 0.28791671991348267
1 1.0
2 0.0
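The loop above inspects the sampler's private `_sum_tree`. As a rough illustration of why a sum tree suits prioritized sampling, here is a plain-Python sketch. It is not TorchRL's implementation, and the `SumTree` class below is hypothetical. Leaves hold priorities and internal nodes hold the sum of their children, so both updating a priority and locating the leaf that covers a given cumulative mass take O(log n) operations.

```python
class SumTree:
    """Binary sum tree stored in a flat list (capacity must be a power of two here)."""

    def __init__(self, capacity):
        self.capacity = capacity
        # nodes[1] is the root; leaves live at capacity .. 2 * capacity - 1.
        self.nodes = [0.0] * (2 * capacity)

    def update(self, index, priority):
        """Set the priority of one leaf and refresh the sums on the path to the root."""
        pos = index + self.capacity
        self.nodes[pos] = priority
        pos //= 2
        while pos >= 1:
            self.nodes[pos] = self.nodes[2 * pos] + self.nodes[2 * pos + 1]
            pos //= 2

    def total(self):
        return self.nodes[1]

    def find(self, mass):
        """Return the leaf whose cumulative-priority interval contains `mass`."""
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if mass < self.nodes[left]:
                pos = left
            else:
                mass -= self.nodes[left]
                pos = left + 1
        return pos - self.capacity


tree = SumTree(capacity=4)
for index, priority in enumerate([0.5, 1.0, 0.0, 2.5]):
    tree.update(index, priority)
print(tree.total(), tree.find(0.25), tree.find(1.5))
```

Sampling then amounts to drawing a uniform number in `[0, tree.total())` and calling `find` on it; a leaf with zero priority is never returned.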
Envs
from torchrl.envs.libs.gym import GymEnv, GymWrapper
gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")
tensordict = env.reset()
env.rand_step(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
Changing environments config
env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        pixels: Tensor(shape=torch.Size([500, 500, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
env.close()
del env
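The `frame_skip=3` argument used above asks the environment to repeat each action for several simulator steps. The sketch below shows the usual semantics under stated assumptions: the same action is applied `frame_skip` times, rewards are summed, the last observation is returned, and the repetition stops early if the episode ends. `CountingEnv` and `step_with_frame_skip` are hypothetical names, not TorchRL API.

```python
class CountingEnv:
    """Toy environment: the observation is the step count and each step yields reward 1."""

    def __init__(self, horizon):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done


def step_with_frame_skip(env, action, frame_skip):
    """Repeat `action` up to `frame_skip` times and accumulate the rewards."""
    total_reward = 0.0
    for _ in range(frame_skip):
        obs, reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    return obs, total_reward, done


toy_env = CountingEnv(horizon=7)
toy_env.reset()
first = step_with_frame_skip(toy_env, action=0, frame_skip=3)
second = step_with_frame_skip(toy_env, action=0, frame_skip=3)
third = step_with_frame_skip(toy_env, action=0, frame_skip=3)
print(first, second, third)
```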
from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    ToTensorImage,
    TransformedEnv,
)
base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['pixels']),
            ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels'])))
Transforms
from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['pixels']),
            ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels'])))
env.reset()
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        pixels: Tensor(shape=torch.Size([3, 500, 500]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
print("env: ", env)
print("last transform parent: ", env.transform[2].parent)
env:  TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['pixels']),
            ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels'])))
last transform parent:  TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['pixels'])))
Vectorized Environments
from torchrl.envs import ParallelEnv
base_env = ParallelEnv(
    4,
    lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False),
    mp_start_method="fork",  # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
    base_env, Compose(StepCounter(), ToTensorImage())
)  # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        pixels: Tensor(shape=torch.Size([4, 3, 500, 500]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)
print(env.action_spec)
env.close()
del env
BoundedTensorSpec(
    shape=torch.Size([4, 1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)
Modules
Models
Example of an MLP model:
from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims
net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)
MLP(
  (0): LazyLinear(in_features=0, out_features=32, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=32, out_features=64, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=64, out_features=4, bias=True)
)
torch.Size([10, 4])
Example of a CNN model:
cnn = ConvNet(
    num_cells=[32, 64],
    kernel_sizes=[8, 4],
    strides=[2, 1],
    aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape)  # last tensor is squashed
ConvNet(
  (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(2, 2))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): SquashDims()
)
torch.Size([10, 6400])
TensorDictModules
from tensordict.nn import TensorDictModule
tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key 1"], out_keys=["key 2"])
td_module(tensordict)
print(tensordict)
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
Sequences of Modules
from tensordict.nn import TensorDictSequential
backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
    backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])
sequence = TensorDictSequential(backbone, actor, value)
print(sequence)
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=Linear(in_features=5, out_features=3, bias=True),
          device=cpu,
          in_keys=['observation'],
          out_keys=['hidden'])
      (1): TensorDictModule(
          module=Linear(in_features=3, out_features=4, bias=True),
          device=cpu,
          in_keys=['hidden'],
          out_keys=['action'])
      (2): TensorDictModule(
          module=MLP(
            (0): LazyLinear(in_features=0, out_features=4, bias=True)
            (1): Tanh()
            (2): Linear(in_features=4, out_features=5, bias=True)
            (3): Tanh()
            (4): Linear(in_features=5, out_features=1, bias=True)
          ),
          device=cpu,
          in_keys=['hidden', 'action'],
          out_keys=['value'])
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['hidden', 'action', 'value'])
print(sequence.in_keys, sequence.out_keys)
['observation'] ['hidden', 'action', 'value']
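The printed keys can be understood as follows: a key is an input of the sequence only if some module reads it before any earlier module has written it, and every key written by a module is an output. The plain-Python sketch below illustrates that rule; it is not tensordict's implementation, and `sequence_keys` is a hypothetical helper.

```python
def sequence_keys(modules):
    """Derive the keys of a module chain.

    `modules` is a list of (in_keys, out_keys) pairs in execution order.
    """
    in_keys, out_keys = [], []
    for module_in, module_out in modules:
        for key in module_in:
            # A key already produced upstream is not an external input.
            if key not in out_keys and key not in in_keys:
                in_keys.append(key)
        for key in module_out:
            if key not in out_keys:
                out_keys.append(key)
    return in_keys, out_keys


modules = [
    (["observation"], ["hidden"]),        # backbone
    (["hidden"], ["action"]),             # actor
    (["hidden", "action"], ["value"]),    # value
]
in_keys, out_keys = sequence_keys(modules)
print(in_keys, out_keys)  # ['observation'] ['hidden', 'action', 'value']
```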
tensordict = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
backbone(tensordict)
actor(tensordict)
value(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
tensordict = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
sequence(tensordict)
print(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
Functional Programming (Ensembling / Meta-RL)
from tensordict import TensorDict
params = TensorDict.from_module(sequence)
print("extracted params", params)
extracted params TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                bias: Parameter(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                                weight: Parameter(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=None,
                            is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                1: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=None,
                            is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                2: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                0: TensorDict(
                                    fields={
                                        bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                                        weight: Parameter(shape=torch.Size([4, 7]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False),
                                2: TensorDict(
                                    fields={
                                        bias: Parameter(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                                        weight: Parameter(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False),
                                4: TensorDict(
                                    fields={
                                        bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                                        weight: Parameter(shape=torch.Size([1, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False)},
                            batch_size=torch.Size([]),
                            device=None,
                            is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
Functional call using tensordict:
with params.to_module(sequence):
    sequence(tensordict)
Using vectorized map for model ensembling
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([4, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([4, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        value: Tensor(shape=torch.Size([4, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
Specialized Classes
torch.manual_seed(0)
from torchrl.data import BoundedTensorSpec
from torchrl.modules import SafeModule
spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
    module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(tensordict)["action"]
tensor([-0.0137,  0.1524, -0.0641], grad_fn=<ViewBackward0>)
tensordict = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(tensordict)["action"]  # safe=True projects the result within the set
tensor([-1.,  1., -1.], grad_fn=<AsStridedBackward0>)
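For a bounded spec such as the one above, projecting an out-of-range output back into the set can be pictured as an element-wise clamp to the `[low, high]` interval. The plain-Python sketch below illustrates that idea under this assumption; `project_to_bounds` is a hypothetical helper, not the TorchRL API.

```python
def project_to_bounds(values, low, high):
    """Clamp each value to its [low, high] interval."""
    return [min(max(v, lo), hi) for v, lo, hi in zip(values, low, high)]


low = [-1.0, -1.0, -1.0]
high = [1.0, 1.0, 1.0]
projected = project_to_bounds([-37.2, 12.5, -0.3], low, high)
print(projected)  # [-1.0, 1.0, -0.3]
```

Values that already lie inside the bounds are left unchanged, which is why the first call above (with small inputs) returns the raw network output.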
from torchrl.modules import Actor
base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(tensordict)  # action is the default value
from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
# Probabilistic modules
from torchrl.modules import NormalParamExtractor, TanhNormal
td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
    nn.Linear(5, 4), NormalParamExtractor()
)  # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=False,
    ),
)
td_module(td)
print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)
td_module(td)
print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
# Sampling vs mode / mean
from torchrl.envs.utils import ExplorationType, set_exploration_type
td = TensorDict({"input": torch.randn(3, 5)}, [3])
torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
    td_module(td)
    print("random:", td["action"])
with set_exploration_type(ExplorationType.MODE):
    td_module(td)
    print("mode:", td["action"])
with set_exploration_type(ExplorationType.MODE):
    td_module(td)
    print("mean:", td["action"])
random: tensor([[ 0.8728, -0.1334],
        [-0.9833,  0.3494],
        [-0.6887, -0.6402]], grad_fn=<_SafeTanhBackward>)
mode: tensor([[-0.1132,  0.1762],
        [-0.3430, -0.2668],
        [ 0.2918,  0.6239]], grad_fn=<_SafeTanhBackward>)
mean: tensor([[-0.1132,  0.1762],
        [-0.3430, -0.2668],
        [ 0.2918,  0.6239]], grad_fn=<_SafeTanhBackward>)
Using Environments and Modules
from torchrl.envs.utils import step_mdp
env = GymEnv("Pendulum-v1")
action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
    actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
tensordict = env.reset()
tensordicts = TensorDict({}, [max_steps])
for i in range(max_steps):
    actor(tensordict)
    tensordicts[i] = env.step(tensordict)
    if tensordict["done"].any():
        break
    tensordict = step_mdp(tensordict)  # roughly equivalent to obs = next_obs
tensordicts_prealloc = tensordicts.clone()
print("total steps:", i)
print(tensordicts)
total steps: 99
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([100]),
    device=None,
    is_shared=False)
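The comment "roughly equivalent to obs = next_obs" can be made concrete with nested dictionaries. The sketch below assumes the default behaviour of `step_mdp` is to promote the entries stored under "next" to the root while dropping the action and the reward; `step_mdp_sketch` is a hypothetical stand-in that works on plain dicts, not on TensorDicts.

```python
def step_mdp_sketch(data):
    """Build the input of the next step from the output of the current step."""
    # Keep root entries other than the nested "next" entry and the consumed action.
    out = {key: value for key, value in data.items() if key not in ("next", "action")}
    # Promote the next-step entries, leaving the reward behind.
    out.update({key: value for key, value in data["next"].items() if key != "reward"})
    return out


step_output = {
    "observation": [0.0],
    "action": [1.0],
    "done": False,
    "next": {"observation": [0.5], "reward": -1.0, "done": False},
}
next_input = step_mdp_sketch(step_output)
print(next_input)  # {'observation': [0.5], 'done': False}
```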
# equivalent
torch.manual_seed(0)
env.set_seed(0)
max_steps = 100
tensordict = env.reset()
tensordicts = []
for i in range(max_steps):
    actor(tensordict)
    tensordicts.append(env.step(tensordict))
    if tensordict["done"].any():
        break
    tensordict = step_mdp(tensordict)  # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(tensordicts, 0)
print("total steps:", i)
print(tensordicts_stack)
total steps: 99
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([100]),
    device=None,
    is_shared=False)
(tensordicts_stack == tensordicts_prealloc).all()
True
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout
(tensordict_rollout == tensordicts_prealloc).all()
from tensordict.nn import TensorDictModule
Collectors
from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv
EnvCreator makes sure that we can send a lambda function from process to process. We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited.
parallel_env = SerialEnv(
    3,
    EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
Sync data collector
devices = ["cpu", "cpu"]
collector = MultiSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([2, 3, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 3, 10]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2, 3, 10]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2, 3, 10]),
    device=cpu,
    is_shared=False)
3
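The printed batch shape and the final value of `i` follow from simple arithmetic on the collector arguments. The sketch below reproduces that arithmetic in plain Python, assuming the frame counts divide evenly; `sync_batch_layout` is a hypothetical helper, not part of TorchRL.

```python
def sync_batch_layout(total_frames, frames_per_batch, num_sub_collectors, envs_per_sub_collector):
    """Return (number of batches, batch shape) when all sub-collectors contribute to every batch."""
    num_batches = total_frames // frames_per_batch
    # Every batch is shared equally among all environments of all sub-collectors.
    steps_per_env = frames_per_batch // (num_sub_collectors * envs_per_sub_collector)
    return num_batches, (num_sub_collectors, envs_per_sub_collector, steps_per_env)


num_batches, shape = sync_batch_layout(
    total_frames=240,
    frames_per_batch=60,
    num_sub_collectors=2,
    envs_per_sub_collector=3,
)
last_index = num_batches - 1  # the value of `i` after the loop over the collector
print(last_index, shape)
```

With 240 total frames and 60 frames per batch there are four batches, so the loop ends with index 3, and each batch carries 10 steps for each of the 2 x 3 environments.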
# async data collector: keeps working while you update your model
collector = MultiaSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # each batch comes from a single sub-collector: [3 envs x 20 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([3, 20]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([3, 20]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3, 20]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3, 20]),
    device=cpu,
    is_shared=False)
3
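In the asynchronous output above, each batch has shape [3, 20]: the 60 frames come from the 3 environments of a single sub-collector rather than from both. The toy sketch below mimics that delivery pattern with threads and a queue. It is a simplified illustration, not TorchRL's implementation: `run_worker` and `collect_async` are hypothetical names, and splitting the work evenly between workers is an assumption made to keep the example deterministic in size.

```python
import queue
import threading


def run_worker(worker_id, num_envs, steps_per_env, num_batches, out_queue):
    """Toy sub-collector: each batch holds num_envs x steps_per_env frames."""
    for batch_idx in range(num_batches):
        batch = [
            [(worker_id, env_id, batch_idx * steps_per_env + t) for t in range(steps_per_env)]
            for env_id in range(num_envs)
        ]
        out_queue.put(batch)


def collect_async(num_workers=2, num_envs=3, frames_per_batch=60, total_frames=240):
    steps_per_env = frames_per_batch // num_envs
    # Assumption: the total number of batches is shared evenly between workers.
    batches_per_worker = total_frames // frames_per_batch // num_workers
    out_queue = queue.Queue()
    threads = [
        threading.Thread(
            target=run_worker,
            args=(w, num_envs, steps_per_env, batches_per_worker, out_queue),
        )
        for w in range(num_workers)
    ]
    for thread in threads:
        thread.start()
    # Batches are consumed in arrival order, whichever worker produced them.
    batches = [out_queue.get() for _ in range(num_workers * batches_per_worker)]
    for thread in threads:
        thread.join()
    return batches


batches = collect_async()
shapes = {(len(batch), len(batch[0])) for batch in batches}
print(len(batches), shapes)
```

Because the consumer never waits for all workers at once, a slow worker does not block delivery of batches from the others.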
Objectives¶
# TorchRL delivers meta-RL compatible loss functions
# Disclaimer: This API may change in the future
from torchrl.objectives import DDPGLoss
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])
class ConcatModule(nn.Linear):
    def forward(self, obs, action):
        return super().forward(torch.cat([obs, action], -1))
value_module = ConcatModule(4, 1)
value = TensorDictModule(
    value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)
loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
tensordict = TensorDict(
    {
        "observation": torch.randn(10, 3),
        "next": {
            "observation": torch.randn(10, 3),
            "reward": torch.randn(10, 1),
            "done": torch.zeros(10, 1, dtype=torch.bool),
        },
        "action": torch.randn(10, 1),
    },
    batch_size=[10],
    device="cpu",
)
loss_td = loss_fn(tensordict)
print(loss_td)
TensorDict(
    fields={
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        td_error: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
print(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        td_error: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
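After the loss call, the input tensordict carries an extra `td_error` entry. As a rough guide to the quantities involved, the sketch below computes textbook DDPG losses with a one-step (TD(0)) target in plain Python. It assumes a squared-error value loss and hand-picked Q-values; it is not TorchRL's code, and the names `td0_target`, `ddpg_losses` and the dictionary keys are hypothetical.

```python
def td0_target(reward, next_value, done, gamma=0.99):
    """One-step bootstrapped target: r + gamma * Q(s', pi(s')), with no bootstrap after the episode ends."""
    return reward + (0.0 if done else gamma * next_value)


def ddpg_losses(batch, gamma=0.99):
    """Return (loss_actor, loss_value, residuals) for a list of transitions."""
    residuals = [
        tr["q_pred"] - td0_target(tr["reward"], tr["q_next"], tr["done"], gamma)
        for tr in batch
    ]
    # Critic: mean squared difference between prediction and target.
    loss_value = sum(r * r for r in residuals) / len(residuals)
    # Actor: maximize Q(s, pi(s)), i.e. minimize its negative mean.
    loss_actor = -sum(tr["q_pi"] for tr in batch) / len(batch)
    return loss_actor, loss_value, residuals


# Hand-picked values standing in for network outputs.
batch = [
    {"q_pred": 1.0, "q_pi": 1.5, "q_next": 2.0, "reward": 0.5, "done": False},
    {"q_pred": 0.0, "q_pi": 0.5, "q_next": 3.0, "reward": 1.0, "done": True},
]
loss_actor, loss_value, residuals = ddpg_losses(batch)
print(loss_actor, loss_value, residuals)
```

The second transition is terminal, so its target is the reward alone and the next-state value is ignored.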
State of the Library¶
TorchRL is currently an alpha release: there may be bugs, and there is no guarantee against BC-breaking changes. We should be able to move to a beta release by the end of the year. Our roadmap to get there comprises:
- Distributed solutions 
- Offline RL 
- Greater support for meta-RL 
- Multi-task and hierarchical RL 
Contributing¶
We are actively looking for contributors and early users. If you're working in RL (or are just curious), try it and give us feedback: the success of TorchRL will depend on how well it covers researchers' needs, and for that we need their input. Since the library is nascent, it is a great time for you to shape it the way you want!
Installing the Library¶
The library is on PyPI: pip install torchrl
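After running the install command, a quick way to confirm that the package is visible to the current interpreter is to query its metadata. This is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper name.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("torchrl")
print("torchrl is not installed" if version is None else f"torchrl {version}")
```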