feat(trainer): Support multi-role & consecutive turns in DataCollatorForCompletionOnlyLM (#3223) #3224

Open. Kirili4ik wants to merge 1 commit into base: main.

@Kirili4ik Kirili4ik commented Apr 3, 2025

What does this PR do?

This PR refactors the masking logic in DataCollatorForCompletionOnlyLM to correctly handle conversations with multiple instruction roles (e.g., user, tool) and consecutive turns from the same role (e.g., several assistant messages in a row). This makes the collator usable for more complex dialogue formats such as agent trajectories.

Motivation & Context:

Previously, the collator assumed a strict alternation of a single instruction template and a response template (e.g., User -> Assistant). This limitation prevented its use for common scenarios in modern LLM fine-tuning:

  1. Multi-Role Dialogues: It failed on datasets involving more than just user/assistant, such as agent interactions with tool calls (<|im_start|>tool).
  2. Consecutive Turns: It couldn't correctly mask sequences where the assistant speaks multiple times in a row (e.g., Assistant -> Assistant).

This addresses the need for better support for agent tuning data and multi-turn formats, as mentioned in issues #1994 and #2545.
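
To make the failure mode concrete, here is a toy sketch of a strict-alternation masker. It is an illustration only, not the library's actual code: `pairing_mask`, the marker strings, and the word-level "tokens" are all hypothetical. It pairs the i-th instruction marker with the i-th response marker and masks only the span between them, so anything the pairing does not cover leaks into the loss.

```python
def pairing_mask(tokens, instruction_marker, response_marker):
    """Toy strict-alternation masker (hypothetical, for illustration only).

    Returns one flag per token: True means the token contributes to the loss.
    """
    instr_starts = [i for i, t in enumerate(tokens) if t == instruction_marker]
    # Index just past each response marker, i.e. where the response content begins.
    resp_ends = [i + 1 for i, t in enumerate(tokens) if t == response_marker]

    keep = [True] * len(tokens)
    # Pair the k-th instruction with the k-th response and mask that span.
    for k, (start, end) in enumerate(zip(instr_starts, resp_ends)):
        lo = 0 if k == 0 else start
        for j in range(lo, end):
            keep[j] = False
    return keep


# One user turn, an assistant turn, a tool turn, then another assistant turn.
tokens = ["<user>", "hi", "<assistant>", "call", "<tool>", "42", "<assistant>", "done"]
keep = pairing_mask(tokens, "<user>", "<assistant>")
leaked = [t for t, k in zip(tokens, keep) if k]
print(leaked)
```

With only one recognised instruction marker, the tool output and the second assistant marker stay unmasked, which is the kind of behaviour this PR aims to remove.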

Changes Implemented:

This commit addresses the limitations by:

  • Updating DataCollatorForCompletionOnlyLM.__init__ to accept a list of strings or pre-tokenized IDs for instruction_template, allowing multiple distinct instruction roles to be specified.
  • Rewriting the core masking logic in torch_call:
    • It now correctly identifies all occurrences of the response template and all specified instruction templates within the input sequence.
    • For each assistant response, it unmasks tokens starting from the end of its template up to the beginning of the next identified instruction template (or the sequence end if no further instruction follows).
    • It correctly handles consecutive assistant turns by ensuring the template tokens (e.g., <|im_start|>assistant\n) of subsequent responses remain masked, while their content is unmasked for loss calculation.
  • Adding unit tests (test_masking_* in test_data_collator_completion_only.py) covering:
    • Basic multi-turn conversations.
    • Multi-role scenarios using multiple instruction templates.
    • Correct masking for consecutive assistant messages.
    • Behavior with left-padding enabled.
    • Initialization using pre-tokenized template IDs.

This update significantly increases the flexibility of DataCollatorForCompletionOnlyLM, making it suitable for processing conversational data commonly found in ChatML formats and agent fine-tuning datasets.

My guess is that this could be done more efficiently (at the cost of being harder to understand), but the number of special tokens in a sequence does not seem high enough to justify optimising it further.

Related: #1994, #2545

Fixes #3223

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

…ForCompletionOnlyLM (huggingface#3223)

Refactors the masking logic in `DataCollatorForCompletionOnlyLM` to correctly handle conversations with multiple instruction roles (e.g., user, tool) and consecutive assistant turns, enabling its use for more complex dialogue formats like agent trajectories.

Previously, the collator assumed a strict alternation of a single instruction template and a response template (e.g., User -> Assistant). This failed for:
1.  Datasets with multiple instruction roles (e.g., user prompts and tool calls).
2.  Sequences with consecutive assistant messages (e.g., Assistant -> Assistant).

This commit addresses these limitations:
- Updates `__init__` to accept a list of strings or pre-tokenized IDs for `instruction_template`, allowing multiple distinct instruction roles.
- Rewrites the core masking logic in `torch_call`:
    - It now identifies all occurrences of response and all specified instruction templates.
    - For each assistant response, it unmasks tokens from the end of its template up to the beginning of the *next* instruction template or the sequence end.
    - Correctly handles consecutive assistant turns by masking the template tokens of subsequent responses while unmasking their content.
- Adds comprehensive unit tests (`test_masking_*`) covering multi-role scenarios, consecutive assistant messages, left-padding, and initialization with tokenized templates.

This allows `DataCollatorForCompletionOnlyLM` to process conversational data commonly found in ChatML formats and agent fine-tuning datasets.

Related: huggingface#1994, huggingface#2545

Successfully merging this pull request may close these issues.

Support masking of different and repeating roles DataCollatorForCompletionOnlyLM