Simplifying PyTorch Memory Management with TensorDict
Author: Tom Begley
In this tutorial you will learn how to control where the contents of a
TensorDict are stored in memory, either by sending those contents to a device,
or by utilizing memory maps.
Devices
When you create a TensorDict, you can specify a device with the device
keyword argument. If the device is set, then all entries of the
TensorDict will be placed on that device. If the device is not set, then
there is no requirement that entries in the TensorDict must be on the same
device.
In this example we instantiate a TensorDict with device="cuda:0". When
we print the contents we can see that they have been moved onto the device.
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)
If the device of the TensorDict is not None, new entries are also moved
onto the device.
>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)
You can check the current device of the TensorDict with the device
attribute.
>>> print(tensordict.device)
cuda:0
The contents of the TensorDict can be sent to a device like a PyTorch tensor
with TensorDict.cuda() or
TensorDict.to(device), with device
being the desired device.
>>> tensordict = tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict = tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)
The TensorDict.to method requires a valid
device to be passed as the argument. If you want to remove the device from the
TensorDict, to allow values to live on different devices, you should use the
TensorDict.clear_device_ method.
>>> tensordict.clear_device_()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
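Once the device is cleared, entries on different devices can coexist. A minimal sketch (the key "c" here is purely illustrative): since the TensorDict no longer has a device, new entries are left on whichever device they already occupy.
>>> tensordict["c"] = torch.rand(10)  # not moved: the TensorDict has no device
>>> print(tensordict["a"].device, tensordict["c"].device)
cuda:0 cpu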
Memory-mapped Tensors
tensordict provides a class MemoryMappedTensor
which allows us to store the contents of a tensor on disk, while still
supporting fast indexing and loading of the contents in batches.
See the ImageNet Tutorial for an
example of this in action.
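As a standalone sketch of the class itself (assuming a recent version of tensordict, in which MemoryMappedTensor.from_tensor is available):
from tensordict import MemoryMappedTensor

# Write the tensor's contents to a (temporary) file on disk and wrap it
# in a MemoryMappedTensor, which can be indexed like a regular tensor.
mmt = MemoryMappedTensor.from_tensor(torch.rand(10))
print(mmt[:3].shape)  # indexed loads read straight from disk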
To convert the TensorDict to a collection of memory-mapped tensors, use the
TensorDict.memmap_ method.
tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
tensordict.memmap_()
print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
Alternatively, one can use the
TensorDict.memmap_like method. This
creates a new TensorDict of the same structure with
MemoryMappedTensor values, but it does not copy the
contents of the original tensors into the
memory-mapped tensors. This allows you to create the memory-mapped
TensorDict and then populate it gradually (as sketched after the example
below), and hence it should generally be preferred over memmap_ when you do
not need to copy the existing contents.
tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
mm_tensordict = tensordict.memmap_like()
print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
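Because the entries already exist on disk, they can then be filled in place, for example batch by batch. A minimal sketch (the values written here are illustrative); indexed assignment writes through to the underlying memory-mapped files:
# Fill the first half of the batch dimension in place.
mm_tensordict[0:5] = TensorDict(
    {"a": torch.ones(5), "b": {"c": torch.ones(5)}}, [5]
)
print(mm_tensordict["a"].contiguous())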
By default the contents of the TensorDict are saved to a temporary
location on disk; if you would like to control where they are saved, you can
use the keyword argument prefix="/path/to/root".
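For example, reusing the path from the paragraph above (a sketch; any writable directory works):
tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
# Memory-map the contents under the given root directory.
tensordict.memmap_(prefix="/path/to/root")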
The contents of the TensorDict are saved in a directory structure that mimics
the structure of the TensorDict itself. The contents of each tensor are saved
in a NumPy memmap, and the metadata in an associated PyTorch save file. For example,
the above TensorDict is saved as follows:
├── a.memmap
├── a.meta.pt
├── b
│   ├── c.memmap
│   ├── c.meta.pt
│   └── meta.pt
└── meta.pt
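A TensorDict saved in this way can later be reloaded from the same prefix with TensorDict.load_memmap. A sketch, assuming the prefix used above:
# Rebuild the TensorDict from the files on disk, without reading
# the tensor contents into memory up front.
loaded = TensorDict.load_memmap("/path/to/root")
print(loaded["b", "c"].shape)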