Message Transforms
Message transforms convert raw sample dictionaries from your dataset into torchtune’s
Message structure. Once your data is represented as Messages, torchtune handles
tokenization and prepares it for the model.
Configuring message transforms
Most of our built-in message transforms contain parameters for controlling input masking (train_on_input),
adding a system prompt (new_system_prompt), and changing the expected column names (column_map).
These are exposed in our dataset builders instruct_dataset() and chat_dataset()
so you don’t have to worry about the message transform itself and can configure this directly from the config.
You can see Example instruct dataset or Example chat dataset for more details.
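To make the three parameters concrete, here is a minimal sketch of what they control. It does not use torchtune: the function to_messages is a hypothetical stand-in, and plain dictionaries take the place of Message objects. The assumed behavior is that column_map maps the expected names ("input", "output") to your dataset's actual column names, that train_on_input decides whether the user turn is masked, and that new_system_prompt prepends a masked system message.

```python
from typing import Any, Dict, List, Mapping, Optional


def to_messages(
    sample: Mapping[str, Any],
    train_on_input: bool = False,
    column_map: Optional[Dict[str, str]] = None,
    new_system_prompt: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Toy stand-in for a built-in message transform (not torchtune code)."""
    # column_map maps the expected names ("input", "output") to the
    # column names actually present in the dataset.
    column_map = column_map or {"input": "input", "output": "output"}
    messages: List[Dict[str, Any]] = []
    if new_system_prompt is not None:
        # Assumption: the system prompt is prepended and always masked.
        messages.append(
            {"role": "system", "content": new_system_prompt, "masked": True}
        )
    # The user turn is masked unless train_on_input is set.
    messages.append(
        {
            "role": "user",
            "content": sample[column_map["input"]],
            "masked": not train_on_input,
        }
    )
    messages.append(
        {
            "role": "assistant",
            "content": sample[column_map["output"]],
            "masked": False,
        }
    )
    return messages


sample = {"question": "What is 2 + 2?", "answer": "4"}
messages = to_messages(
    sample,
    column_map={"input": "question", "output": "answer"},
    new_system_prompt="You are a helpful assistant.",
)
for m in messages:
    print(m["role"], m["masked"])
# system True
# user True
# assistant False
```

With the dataset builders, you would pass the same three parameters as arguments or config fields rather than writing a function like this yourself.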
Custom message transforms
If our built-in message transforms do not fit your particular dataset,
you can create your own class with full flexibility. Simply inherit from the Transform
class and add your code in the __call__ method.
A simple contrived example would be to take one column from the dataset as the user message and another
column as the model response. Indeed, this is quite similar to InputOutputToMessages.
from pprint import pprint
from typing import Any, Mapping

from torchtune.data import Message
from torchtune.modules.transforms import Transform


class MessageTransform(Transform):
    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        messages = [
            # The "input" column becomes the user turn; it is masked.
            Message(
                role="user",
                content=sample["input"],
                masked=True,
                eot=True,
            ),
            # The "output" column becomes the model response; it is not masked.
            Message(
                role="assistant",
                content=sample["output"],
                masked=False,
                eot=True,
            ),
        ]
        return {"messages": messages}


# Apply the transform to a single raw sample.
input_sample = {"input": "hello world", "output": "bye world"}
transform = MessageTransform()
output_sample = transform(input_sample)
pprint(output_sample)
# {'messages': [Message(role='user', content=['hello world']),
#               Message(role='assistant', content=['bye world'])]}
See Creating Messages for more details on how to manipulate Message objects.
To use this for your dataset, you must create a custom dataset builder that uses the underlying
dataset class, SFTDataset.
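# In data/dataset.py
from torchtune.datasets import SFTDataset


def custom_dataset(tokenizer, **load_dataset_kwargs) -> SFTDataset:
    # MessageTransform is the class defined above; it is assumed to be
    # defined in (or imported into) this file.
    message_transform = MessageTransform()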
    return SFTDataset(
        source="json",
        data_files="data/my_data.json",
        split="train",
        message_transform=message_transform,
        model_transform=tokenizer,
        **load_dataset_kwargs,
    )
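The builder above hands two callables to the dataset: the message transform and the tokenizer (as model_transform). As a rough illustration of how those two stages could be chained per sample, here is a torchtune-free sketch. ToySFTDataset, toy_tokenizer, and IGNORE_INDEX are hypothetical names for this example only, token ids are just word lengths, and -100 is assumed as the ignore label following the usual cross-entropy convention. It is not the SFTDataset implementation.

```python
from typing import Any, Callable, Dict, List, Mapping

IGNORE_INDEX = -100  # assumed ignore label (common cross-entropy convention)


def message_transform(sample: Mapping[str, Any]) -> Dict[str, Any]:
    # Stage 1: raw columns -> messages (plain dicts instead of Message objects).
    return {
        "messages": [
            {"role": "user", "content": sample["input"], "masked": True},
            {"role": "assistant", "content": sample["output"], "masked": False},
        ]
    }


def toy_tokenizer(sample: Mapping[str, Any]) -> Dict[str, Any]:
    # Stage 2: messages -> token ids plus a per-token mask.
    # Token ids here are word lengths, purely for illustration.
    tokens: List[int] = []
    mask: List[bool] = []
    for message in sample["messages"]:
        for word in message["content"].split():
            tokens.append(len(word))
            mask.append(message["masked"])
    return {"tokens": tokens, "mask": mask}


class ToySFTDataset:
    """Applies message_transform, then model_transform, to each row."""

    def __init__(
        self,
        rows: List[Mapping[str, Any]],
        message_transform: Callable[[Mapping[str, Any]], Dict[str, Any]],
        model_transform: Callable[[Mapping[str, Any]], Dict[str, Any]],
    ) -> None:
        self._rows = rows
        self._message_transform = message_transform
        self._model_transform = model_transform

    def __len__(self) -> int:
        return len(self._rows)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        tokenized = self._model_transform(
            self._message_transform(self._rows[index])
        )
        # Tokens from masked messages get the ignore label.
        labels = [
            IGNORE_INDEX if masked else token
            for token, masked in zip(tokenized["tokens"], tokenized["mask"])
        ]
        return {"tokens": tokenized["tokens"], "labels": labels}


dataset = ToySFTDataset(
    rows=[{"input": "hello world", "output": "bye world"}],
    message_transform=message_transform,
    model_transform=toy_tokenizer,
)
item = dataset[0]
print(item)
# {'tokens': [5, 5, 3, 5], 'labels': [-100, -100, 3, 5]}
```

The point of the sketch is the ordering: the message transform only decides roles, content, and masking, while everything token-related happens afterwards in the model transform.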
The custom_dataset builder can then be referenced directly from the config.
dataset:
  _component_: data.dataset.custom_dataset
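The _component_ value is a dotted Python path to the builder function. As a sketch of how such a path can be turned into a callable, the snippet below uses importlib from the standard library. resolve_component is a hypothetical helper written for this example, not torchtune's config machinery, and json.dumps stands in for data.dataset.custom_dataset so the snippet runs anywhere.

```python
import importlib
from typing import Any, Callable


def resolve_component(dotted_path: str) -> Callable[..., Any]:
    """Hypothetical helper: import 'pkg.module.name' and return `name`."""
    # Split "data.dataset.custom_dataset" into module path and attribute.
    module_path, _, attr_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected 'module.attribute', got {dotted_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# A standard-library function stands in for the custom builder here.
dumps = resolve_component("json.dumps")
print(dumps({"a": 1}))
# {"a": 1}
```

Under this assumption, data/dataset.py has to be importable as data.dataset from wherever you launch the run, so the working directory matters.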
Example message transforms
- Instruct
- Preference