[GRPO] Adds an option to scale the loss by a constant factor #3231


Closed
edbeeching wants to merge 2 commits

Conversation

@edbeeching edbeeching (Collaborator) commented Apr 4, 2025

What does this PR do?

This PR adds an option to scale the loss by a constant factor equal to the maximum possible number of tokens in a batch.

The reasoning behind this PR is that I believe the current normalization scheme, implemented in #2881, is not invariant to the ordering of samples across devices / gradient accumulation steps, which may cause instabilities in training.

Toy Example

Consider a DDP=2 setting with a per_device_train_batch_size=4. For this example, assume that the loss per token is 1.
With global normalization:
[figure: per-device loss computation with the current normalization (DDP=2, per_device_train_batch_size=4)]

Here each token contributes equally to the loss, but only within its own device. The loss is not comparable to a setting where the whole batch is on a single device; for example, consider a DDP=1 setting with per_device_train_batch_size=4:

[figure: loss computation in the DDP=1 setting]
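
To make the discrepancy concrete, here is a small numeric sketch (my own illustration; the token counts are made up). With the current per-device normalization, the weight a token's gradient receives depends on how many unmasked tokens happen to land on the same device:

```python
# Made-up unmasked completion-token counts for the 4 samples on each device.
tokens_dev0 = 40    # device 0 happens to receive short completions
tokens_dev1 = 400   # device 1 happens to receive long completions

# Current scheme: each device divides its summed per-token loss by its own
# token count; DDP then averages gradients, so each device contributes 1/2.
w_dev0 = 1 / tokens_dev0 / 2     # 0.0125   effective weight per token on device 0
w_dev1 = 1 / tokens_dev1 / 2     # 0.00125  effective weight per token on device 1

# If the same 8 samples were processed as one batch on a single device,
# every token would receive the same weight instead.
w_single = 1 / (tokens_dev0 + tokens_dev1)   # ~0.00227 per token

# Reshuffling samples between devices changes w_dev0 and w_dev1, so the
# effective gradient is not invariant to the sample ordering across devices.
print(w_dev0, w_dev1, w_single)
```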

One potential solution would be to gather the number of unmasked tokens across all devices and use this for normalization, but the same issue would still occur across gradient accumulation steps.

Proposed solution

Calculate a constant factor max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length) and always normalize the loss by this constant.

[figure: loss computation with the proposed constant max_tokens_norm normalization]
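
A minimal sketch of the proposed change, mirroring the lines added in this PR (the dummy tensors and hyperparameter values are only for illustration):

```python
import torch

per_device_train_batch_size = 4
max_prompt_length = 8
max_completion_length = 8

# Constant normalizer: the maximum possible number of tokens in a per-device batch.
max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length)

# Dummy per-token loss and completion mask of shape (batch, max_completion_length).
per_token_loss = torch.ones(per_device_train_batch_size, max_completion_length)
completion_mask = torch.randint(0, 2, per_token_loss.shape)

# Current scheme: normalize by the number of unmasked tokens seen on this device.
loss_current = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Proposed scheme: always divide by the same constant, so the scale of the loss
# does not depend on how tokens are split across devices / accumulation steps.
loss_proposed = (per_token_loss * completion_mask).sum() / max_tokens_norm
```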

The learning rate will probably need to be increased to get results comparable with our other baselines.

@edbeeching edbeeching requested review from qgallouedec and lewtun April 4, 2025 10:16
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lewtun lewtun (Member) left a comment


LGTM with a nit and suggestion on whether we should unit test the scaling

@@ -101,6 +101,8 @@ class GRPOConfig(TrainingArguments):
speed, but may be numerically unstable for long training runs.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
use_max_tokens_norm (`bool`, *optional*, defaults to `False`):
Whether to use the max tokens norm. If `True`, the loss is normalized by a consant, the maximum possible number of tokens
Member

Suggestion to clarify what we mean by "maximum possible"

Suggested change
Whether to use the max tokens norm. If `True`, the loss is normalized by a consant, the maximum possible number of tokens
Whether to use the max tokens norm. If `True`, the loss is normalized by a constant factor that is determined by the total number of prompt and completions tokens in a batch.

use_max_tokens_norm: bool = field(
default=False,
metadata={
"help": "Whether to use the max tokens norm. If `True`, the loss is normalized by a constant, the maximum "
Member

Ditto here if you agree with the change above

loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

if self.use_max_tokens_norm:
    loss = (per_token_loss * completion_mask).sum() / self.max_tokens_norm
Member

I'm not sure how easy it is to unit test this, but would it make sense to do it so that we're sure the loss is being computed as your diagrams show?

E.g. an integration test would be to check that specifying the config params gives the expected scaling for some dummy inputs
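
For example, a check along these lines (names, shapes, and values are hypothetical, not code from the PR) would verify the constant-factor scaling on dummy tensors without spinning up the full trainer:

```python
import torch

def test_max_tokens_norm_scaling():
    # Hypothetical dummy inputs: 2 completions of length 4, some tokens masked out.
    per_token_loss = torch.full((2, 4), 0.5)
    completion_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])

    per_device_train_batch_size = 2
    max_prompt_length, max_completion_length = 2, 4
    max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length)

    loss = (per_token_loss * completion_mask).sum() / max_tokens_norm

    # 6 unmasked tokens * 0.5 per-token loss / (2 * (2 + 4)) = 3.0 / 12 = 0.25
    assert torch.isclose(loss, torch.tensor(0.25))
```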

@@ -101,6 +101,8 @@ class GRPOConfig(TrainingArguments):
speed, but may be numerically unstable for long training runs.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
use_max_tokens_norm (`bool`, *optional*, defaults to `False`):
Member

So this is the loss proposed in Dr GRPO, correct?
If so, I think it should be explicitly mentioned in the doc.

Collaborator Author

I am not sure actually, I thought that was our current implementation. I will take another look.

Member

Currently we use a modified version of DAPO where we normalize per local batch (and not per group).

[figure: comparison of loss normalization schemes, including BNPO and DAPO]

In the above figure, we use something between BNPO (hard to implement with grad accum) and DAPO.
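
For reference, a rough sketch of how these normalizations differ on a single local batch; this is my reading of the schemes rather than code from the trainer, and the constant in the last variant stands in for the one proposed in this PR / Dr GRPO:

```python
import torch

per_token_loss = torch.rand(4, 8)                 # (batch, max_completion_length)
completion_mask = torch.randint(0, 2, (4, 8))
batch_size, max_completion_length = per_token_loss.shape

# GRPO-style: average over each sequence's own tokens, then over sequences,
# so short and long completions get the same per-sequence weight.
per_seq = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1)
loss_grpo = per_seq.mean()

# DAPO-style, normalized per local batch (the current TRL behavior): one
# token-level average over all unmasked tokens seen on this device.
loss_dapo_local = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Constant-normalizer variant (Dr GRPO / this PR): divide by a fixed constant so
# the scale does not depend on the observed token count. The PR additionally
# folds max_prompt_length into the constant via max_tokens_norm.
loss_const = (per_token_loss * completion_mask).sum() / (batch_size * max_completion_length)
```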

@qgallouedec qgallouedec mentioned this pull request Apr 7, 2025
@edbeeching edbeeching (Collaborator, Author)

closing in favor of #3256

@edbeeching edbeeching closed this Apr 8, 2025