[GRPO] Adds an option to scale the loss by a constant factor #3231
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM with a nit and suggestion on whether we should unit test the scaling
@@ -101,6 +101,8 @@ class GRPOConfig(TrainingArguments):
speed, but may be numerically unstable for long training runs.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
use_max_tokens_norm (`bool`, *optional*, defaults to `False`):
Whether to use the max tokens norm. If `True`, the loss is normalized by a consant, the maximum possible number of tokens
Suggestion to clarify what we mean by "maximum possible"
Suggested change:
- Whether to use the max tokens norm. If `True`, the loss is normalized by a consant, the maximum possible number of tokens
+ Whether to use the max tokens norm. If `True`, the loss is normalized by a constant factor that is determined by the total number of prompt and completion tokens in a batch.
use_max_tokens_norm: bool = field(
    default=False,
    metadata={
        "help": "Whether to use the max tokens norm. If `True`, the loss is normalized by a constant, the maximum "
Ditto here if you agree with the change above
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

if self.use_max_tokens_norm:
    loss = (per_token_loss * completion_mask).sum() / self.max_tokens_norm
I'm not sure how easy it is to unit test this, but would it make sense to do it so that we're sure the loss is being computed as your diagrams show?
E.g. an integration test could check that specifying the config params gives the expected scaling for some dummy inputs.
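A rough sketch of the kind of check this suggests, with made-up shapes and values (it mirrors the two loss lines from the diff above, not the actual test suite):

```python
import torch

def test_max_tokens_norm_scaling():
    # Illustrative config values (assumptions, not from the PR).
    per_device_train_batch_size = 4
    max_prompt_length = 8
    max_completion_length = 16
    max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length)  # 96

    # Dummy inputs: every completion token has a loss of 1, half of the tokens are masked out.
    per_token_loss = torch.ones(per_device_train_batch_size, max_completion_length)
    completion_mask = torch.zeros_like(per_token_loss)
    completion_mask[:, : max_completion_length // 2] = 1.0  # 32 unmasked tokens in total

    # Current normalization: divide by the number of unmasked tokens.
    loss_token_norm = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    # Proposed normalization: divide by the constant max_tokens_norm.
    loss_max_norm = (per_token_loss * completion_mask).sum() / max_tokens_norm

    assert torch.isclose(loss_token_norm, torch.tensor(1.0))
    assert torch.isclose(loss_max_norm, torch.tensor(32.0 / max_tokens_norm))
```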
@@ -101,6 +101,8 @@ class GRPOConfig(TrainingArguments):
speed, but may be numerically unstable for long training runs.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
use_max_tokens_norm (`bool`, *optional*, defaults to `False`):
So this is the loss proposed in Dr GRPO, correct?
If so, I think it should be explicitly mentioned in the doc.
I am not sure actually, I thought that was our current implementation. I will take another look.
closing in favor of #3256
What does this PR do?
This PR adds an option to scale the loss by a constant factor, equal to the maximum possible number of tokens in a batch.
The reasoning behind this PR is that I believe the current normalization scheme, implemented in #2881, is not invariant to the ordering of samples across devices / gradient accumulation steps, which may cause instabilities in training.
Toy Example
Consider a `DDP=2` setting with a `per_device_train_batch_size=4`. For this example, assume that the loss per token is 1.

With global normalization:



Here each token has an equal contribution to the loss, but only inside the current device. The loss is not comparable to a setting where the whole batch is on a single device; for example, consider a `DDP=1` setting with `per_device_train_batch_size=4`.
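For concreteness, here is a small illustrative sketch (the token counts and DDP splits are made up, not from the PR) of how the weight each token receives in the device-averaged loss changes with the grouping under per-device normalization, but stays fixed under a constant norm:

```python
def per_token_weights(groups, norm=None):
    """Weight each unmasked token receives in the averaged (DDP) loss.

    With per-device normalization (norm=None) a token is divided by the number of
    unmasked tokens on *its* device; with the proposed scheme every device divides
    by the same constant `norm`.
    """
    n_devices = len(groups)
    weights = []
    for device_samples in groups:
        device_tokens = sum(device_samples)
        denom = device_tokens if norm is None else norm
        weights.append([1 / (n_devices * denom)] * device_tokens)
    return weights

# The same four samples (10, 10, 2, 2 unmasked tokens) split two ways across DDP=2 devices.
split_a = [[10, 10], [2, 2]]
split_b = [[10, 2], [10, 2]]

# Per-device norm: a token's weight depends on which device its sample landed on.
print(per_token_weights(split_a)[0][0], per_token_weights(split_a)[1][0])  # 0.025 vs 0.125
print(per_token_weights(split_b)[0][0], per_token_weights(split_b)[1][0])  # ~0.042 vs ~0.042

# Constant norm (e.g. max_tokens_norm=96): identical weight for every token in both splits.
print(per_token_weights(split_a, norm=96)[0][0], per_token_weights(split_b, norm=96)[0][0])
```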
One potential solution would be to gather the number of unmasked tokens across all devices and use this for normalization, but the same issue would still occur across gradient accumulation steps.
Proposed solution
Calculate a constant factor `max_tokens_norm = per_device_train_batch_size * (max_prompt_length + max_completion_length)` and always normalize the loss by this constant factor. The learning rate will probably need to be increased to get comparable results with our other baselines.
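For illustration, a hypothetical usage sketch: the `use_max_tokens_norm` field comes from this PR's diff and is not in released TRL (the PR was closed in favor of #3256), and the other values are made up:

```python
from trl import GRPOConfig

config = GRPOConfig(
    per_device_train_batch_size=4,
    max_prompt_length=512,
    max_completion_length=1024,
    use_max_tokens_norm=True,   # option added in this (closed) PR, not in upstream TRL
    learning_rate=2e-5,         # likely needs tuning upward, per the note above
)
# The trainer would then divide the summed per-token loss by the constant
# per_device_train_batch_size * (max_prompt_length + max_completion_length) = 6144.
```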