torchtune.training
Checkpointing
torchtune offers checkpointers that allow seamless transitions between checkpoint formats, both for training and for interoperability with the rest of the ecosystem. For a comprehensive overview of checkpointing, see the checkpointing deep-dive.
| Checkpointer which reads and writes checkpoints in HF's format. |
| Checkpointer which reads and writes checkpoints in Meta's format. |
| Checkpointer which reads and writes checkpoints in a format compatible with torchtune. |
| ModelType is used by the checkpointer to distinguish between different model architectures. |
| This class gives a more concise way to represent a list of filenames of the format |
| Validates the state dict for checkpoint loading for a classifier model. |
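The filename-format row above is cut off before the format itself, so the following is only a sketch of the general idea under a stated assumption: a Hugging Face-style shard pattern with two {} placeholders, expanded from the pattern plus the highest shard number. expand_shard_filenames is a hypothetical helper written for illustration; it is not the torchtune class, and its signature is not torchtune's.

```python
from typing import List


# Hypothetical helper (not the torchtune API).
def expand_shard_filenames(filename_format: str, max_filename: str) -> List[str]:
    """Expand a two-placeholder pattern into every shard filename.

    The first "{}" receives the shard index and the second receives the total
    count; indices are zero-padded to len(max_filename).
    """
    if filename_format.count("{}") != 2:
        raise ValueError("filename_format needs exactly two '{}' placeholders")
    if not max_filename.isdigit():
        raise ValueError("max_filename must be all digits, e.g. '00004'")
    width = len(max_filename)
    total = int(max_filename)
    return [
        filename_format.format(str(index).zfill(width), max_filename)
        for index in range(1, total + 1)
    ]


files = expand_shard_filenames("model-{}-of-{}.safetensors", "00004")
print(files[0], files[-1])  # model-00001-of-00004.safetensors model-00004-of-00004.safetensors
```

Storing only the pattern and the last shard number keeps a config short even when a checkpoint is split across dozens of files.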
Reduced Precision
Utilities for working in a reduced precision setting.
| Get the torch.dtype corresponding to the given precision string. |
| Context manager to set torch's default dtype. |
| Validates that all input parameters have the expected dtype. |
| Given a quantizer object, returns a string that specifies the type of quantization. |
Distributed
Utilities for enabling and working with distributed training.
| A datatype for a function that can be used as an FSDP wrapping policy. |
| Initialize process group required for |
| Checks whether all environment variables required to initialize torch.distributed are set and whether distributed is properly installed. |
| Returns the current world size (i.e., the total number of ranks) and the rank of the current process in the default process group. |
| Retrieves an FSDP wrapping policy based on the specified flags |
| A default policy for wrapping models trained with LoRA using FSDP. |
| Converts a sharded state dict into a full state dict on CPU, returning a non-empty result only on rank 0 to avoid peaking CPU memory. |
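The environment check and the world-size/rank lookup can be sketched with the standard library alone. This assumes the env:// initialization convention used by torchrun, which sets MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE for each process. The helper names are hypothetical, and the torchtune functions may additionally query torch.distributed (for example, whether it is available or already initialized), which this sketch does not do.

```python
import os
from typing import Mapping, Optional, Tuple

# Hypothetical helpers (not the torchtune API). These are the variables the
# env:// init method of torch.distributed reads; torchrun sets them per process.
REQUIRED_ENV_VARS = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")


def env_is_distributed(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if every variable needed for env:// initialization is set."""
    env = os.environ if env is None else env
    return all(name in env for name in REQUIRED_ENV_VARS)


def env_world_size_and_rank(env: Optional[Mapping[str, str]] = None) -> Tuple[int, int]:
    """Return (world_size, rank), falling back to a single process (1, 0)."""
    env = os.environ if env is None else env
    if not env_is_distributed(env):
        return 1, 0
    world_size = int(env["WORLD_SIZE"])
    rank = int(env["RANK"])
    if not 0 <= rank < world_size:
        raise ValueError(f"RANK={rank} is out of range for WORLD_SIZE={world_size}")
    return world_size, rank


fake_env = {"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500", "RANK": "1", "WORLD_SIZE": "4"}
print(env_world_size_and_rank(fake_env))  # (4, 1)
print(env_world_size_and_rank({}))        # (1, 0)
```

Falling back to (1, 0) lets the same training script run unchanged on a single device.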
Memory Management
Utilities to reduce memory consumption during training.
| Utility to set up activation checkpointing and wrap the model for checkpointing. |
| Utility to apply activation checkpointing to the passed-in model. |
| A bare-bones class meant for checkpoint save and load for optimizers running in backward. |
| Creates a wrapper for the optimizer step running in backward. |
| Registers hooks for the optimizer step running in backward. |
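The idea behind running the optimizer step in backward can be shown with a toy simulation: each parameter gets its own optimizer, and a hook fires as soon as that parameter's gradient is ready, applies the update, and frees the gradient, so gradients for the whole model never need to be held at once. Everything here (ToyParam, ToySGD, register_step_in_backward) is hypothetical plain Python, not torchtune code; in real PyTorch a comparable hook point is Tensor.register_post_accumulate_grad_hook, available in recent releases.

```python
from typing import Callable, Dict, List, Optional


class ToyParam:
    """Hypothetical stand-in for a tensor: a value, a gradient slot, and hooks."""

    def __init__(self, name: str, value: float) -> None:
        self.name = name
        self.value = value
        self.grad: Optional[float] = None
        self._hooks: List[Callable[["ToyParam"], None]] = []

    def add_grad_ready_hook(self, hook: Callable[["ToyParam"], None]) -> None:
        self._hooks.append(hook)

    def set_grad(self, grad: float) -> None:
        # Plays the role of gradient accumulation finishing for this parameter.
        self.grad = grad
        for hook in self._hooks:
            hook(self)


class ToySGD:
    """Plain SGD that owns exactly one parameter."""

    def __init__(self, param: ToyParam, lr: float) -> None:
        self.param = param
        self.lr = lr

    def step(self) -> None:
        self.param.value -= self.lr * self.param.grad

    def zero_grad(self) -> None:
        self.param.grad = None


def register_step_in_backward(params: List[ToyParam], lr: float) -> Dict[str, int]:
    """Give each parameter its own optimizer and step it as soon as its grad is ready."""
    optimizers = {p.name: ToySGD(p, lr) for p in params}
    stats = {"peak_live_grads": 0}

    def hook(param: ToyParam) -> None:
        live = sum(q.grad is not None for q in params)
        stats["peak_live_grads"] = max(stats["peak_live_grads"], live)
        optimizers[param.name].step()
        optimizers[param.name].zero_grad()  # free this gradient immediately

    for p in params:
        p.add_grad_ready_hook(hook)
    return stats


params = [ToyParam("w1", 1.0), ToyParam("w2", 2.0)]
stats = register_step_in_backward(params, lr=0.1)
# Backward visits parameters roughly in reverse order of the forward pass.
for param, grad in reversed(list(zip(params, [0.5, -1.0]))):
    param.set_grad(grad)
print([p.value for p in params], stats["peak_live_grads"])
```

The trade-off is that no single optimizer step sees all gradients together, so techniques that need a global view, such as clipping by total gradient norm or accumulating gradients across micro-batches, do not combine naturally with this approach.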
Schedulers
Utilities to control the learning rate during training.
| Create a learning rate schedule that linearly increases the learning rate from 0.0 to lr over |
| Validates that all learning rates are the same and returns that learning rate, since the full_finetune_distributed and full_finetune_single_device recipes assume all optimizers share one learning rate. |
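The warmup schedule and the learning-rate check can be sketched in plain Python. The schedule row above is cut off after the warmup phase, so the decay shape is an assumption here: a half-cosine decay to 0 over the remaining steps, the convention used by the cosine-with-warmup schedule in Hugging Face Transformers. Both function names are hypothetical, and param_groups follows the list-of-dicts shape of a PyTorch optimizer's param_groups attribute.

```python
import math
from typing import Any, Dict, List


# Hypothetical helper (not the torchtune API).
def cosine_with_warmup_factor(
    step: int,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
) -> float:
    """Multiplier on the base LR: linear 0 -> 1 over warmup, then cosine decay.

    num_cycles=0.5 traces half a cosine wave, reaching 0 at num_training_steps.
    """
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


# Hypothetical helper (not the torchtune API).
def shared_lr(param_groups: List[Dict[str, Any]]) -> float:
    """Return the single learning rate used by every param group, or raise."""
    lrs = {group["lr"] for group in param_groups}
    if len(lrs) != 1:
        raise RuntimeError(f"Expected exactly one learning rate, found {sorted(lrs)}")
    return lrs.pop()


base_lr = 3e-4
for step in (0, 50, 100, 550, 1000):
    print(step, base_lr * cosine_with_warmup_factor(step, 100, 1000))
print(shared_lr([{"lr": base_lr}, {"lr": base_lr}]))
```

With PyTorch, a factor function of this shape can be passed to torch.optim.lr_scheduler.LambdaLR (for example through functools.partial), which multiplies each group's initial learning rate by the returned value.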
Metric Logging
Various logging utilities.
| Logger for use with Comet (https://www.comet.com/site/). |
| Logger for use with the Weights & Biases application (https://wandb.ai/). |
| Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html). |
| Logger to standard output. |
| Logger to disk. |
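The loggers above share the same basic job: accept named scalar metrics at a given step and send them somewhere. Below is a minimal sketch of two such backends, one writing to a text stream and one appending JSON lines to a file. The class and method names (log, log_dict, close) are illustrative assumptions and are not guaranteed to match the torchtune classes; JSON Lines is a format choice made only for this sketch.

```python
import io
import json
import os
import sys
import tempfile
from typing import Mapping, Optional, TextIO


class StreamLogger:
    """Hypothetical logger: one human-readable line per call to a text stream."""

    def __init__(self, stream: Optional[TextIO] = None) -> None:
        self.stream = stream if stream is not None else sys.stdout

    def log(self, name: str, data: float, step: int) -> None:
        self.log_dict({name: data}, step)

    def log_dict(self, payload: Mapping[str, float], step: int) -> None:
        fields = " ".join(f"{key}:{value}" for key, value in payload.items())
        self.stream.write(f"Step {step} | {fields}\n")

    def close(self) -> None:
        self.stream.flush()


class JsonlDiskLogger:
    """Hypothetical logger: appends one JSON object per call to a file on disk."""

    def __init__(self, path: str) -> None:
        self._file = open(path, "a", encoding="utf-8")

    def log(self, name: str, data: float, step: int) -> None:
        self.log_dict({name: data}, step)

    def log_dict(self, payload: Mapping[str, float], step: int) -> None:
        self._file.write(json.dumps({"step": step, **payload}) + "\n")
        self._file.flush()  # keep the file readable while training runs

    def close(self) -> None:
        self._file.close()


buf = io.StringIO()
console = StreamLogger(buf)
console.log_dict({"loss": 1.25, "lr": 0.0003}, step=10)
console.close()
print(buf.getvalue(), end="")  # Step 10 | loss:1.25 lr:0.0003

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "metrics.jsonl")
    disk = JsonlDiskLogger(path)
    disk.log("loss", 1.25, step=10)
    disk.close()
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()
print(lines)  # ['{"step": 10, "loss": 1.25}']
```

Keeping every backend behind the same small set of methods lets a training loop switch between console, disk, or a hosted service by changing only which logger it constructs.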
Performance and Profiling
torchtune provides utilities to profile and debug the memory and performance of your finetuning job.
| Computes a memory summary for the passed-in device. |
| Logs a dict containing memory stats to the logger. |
| Sets up |
Miscellaneous
| Returns the sequence lengths for each batch element, excluding masked tokens. |
| Sets the seed for pseudo-random number generators across commonly used libraries. |
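Both miscellaneous utilities reduce to a few lines of plain Python. The sequence-length row does not say which boolean marks padding or whether the result is a count or the index of the last unmasked token, so this sketch assumes True marks masked positions and returns counts. The seeding helper only seeds Python's random module; a full utility would also seed libraries such as NumPy and PyTorch. Both function names are hypothetical.

```python
import random
from typing import List, Sequence


# Hypothetical helper (not the torchtune API). Assumes True = masked (padding).
def unmasked_lengths(mask: Sequence[Sequence[bool]]) -> List[int]:
    """Count the tokens in each batch element that are NOT masked."""
    return [sum(1 for is_masked in row if not is_masked) for row in mask]


# Hypothetical helper (not the torchtune API).
def seed_stdlib(seed: int) -> None:
    """Seed Python's built-in PRNG only; NumPy and PyTorch are not touched."""
    random.seed(seed)


mask = [
    [False, False, False, True],  # 3 real tokens, 1 padding token
    [False, True, True, True],    # 1 real token, 3 padding tokens
]
print(unmasked_lengths(mask))  # [3, 1]

seed_stdlib(1234)
first = [random.random() for _ in range(3)]
seed_stdlib(1234)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```

Re-seeding with the same value reproduces the same draws, which is what makes seeded runs repeatable across restarts.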