
Fix length bias for Dr GRPO #3138

Closed
idoru wants to merge 5 commits

Conversation

@idoru commented Mar 23, 2025

This part of Section 3.1 from https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf appears to be missing.

In the paper, under "Length Bias Also Exists in Open-Source PPO Implementation", they note that the masked-mean calculation inherited from PPO codebases introduces a length bias. Their Listing 1 suggests this is what they fix in their implementation.

It seems grpo_trainer.py still has this inherited bias?

See also: https://x.com/zzlccc/status/1903291175420961241
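
For context, a minimal sketch of the two normalizations being compared, assuming a PyTorch setting (the function names and shapes here are illustrative, not the exact code from Listing 1 or grpo_trainer.py):

import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # PPO-style masked mean: divides by the number of valid tokens per sequence,
    # so each token in a longer sequence gets a smaller weight (the length bias).
    return (values * mask).sum(dim) / mask.sum(dim)

def masked_mean_const(values: torch.Tensor, mask: torch.Tensor, max_len: int, dim: int = -1) -> torch.Tensor:
    # Dr. GRPO-style: divide by a constant such as the generation budget,
    # so a token's weight no longer depends on how long its sequence is.
    return (values * mask).sum(dim) / max_len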

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@idoru (Author) commented Mar 23, 2025

Hi @qgallouedec, I'm new to this codebase, but I saw your earlier Dr. GRPO merge and happened to be experimenting with this code today, and couldn't help but wonder whether this bit was missing.

@qgallouedec (Member) commented:

Thanks, but I don't understand: why divide the loss by the max sequence length?

@idoru (Author) commented Mar 23, 2025

In the paper, under "Length Bias Also Exists in Open-Source PPO Implementation", they note that the masked-mean calculation inherited from PPO codebases introduces a length bias. Their Listing 1 suggests this is what they fix in their implementation.

It seems grpo_trainer.py still has this inherited bias?

https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf

@qgallouedec (Member) commented:

Actually, I don't think so. Since #2881, we’ve been using global normalization instead of per-sequence normalization. I believe this is equivalent to what is described in the paper. Am I right?

The only thing I'm not 100% sure about is whether we should divide by mask.sum(). In the paper, they recommend using a constant.

"we could replace the mask.sum(axis=dim) with a constant value (e.g., generation budget) in the masked mean function in Listing 1, as highlighted by the line in green"

Do they mean constant for the batch? In that case, mask.sum() would fit. Or do they mean constant for the whole training? In that case, mask.sum() wouldn't.
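
To make the two readings concrete, here is a rough sketch with made-up tensors (shapes and values are illustrative only):

import torch

per_token_loss = torch.randn(4, 8)                   # (num_sequences, max_len), dummy values
completion_mask = (torch.rand(4, 8) > 0.3).float()   # dummy valid-token mask
max_completion_length = 8                            # generation budget, fixed for the run

# "Constant for the batch": mask.sum() is fixed within one batch but still
# changes from batch to batch with the sampled completion lengths.
loss_batch_const = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# "Constant for the whole training": the denominator never depends on what was
# sampled, which is the property that removes the length bias.
loss_global_const = (per_token_loss * completion_mask).sum() / max_completion_length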

@idoru (Author) commented Mar 23, 2025

Re-reading the part you quoted, they do mention the generation budget as the constant, which would be constant across the whole training.

But then maybe my change should only divide by max_completion_length.

They also say "e.g."; does that mean it could be any constant? 🤷

@idoru (Author) commented Mar 23, 2025

Updated to use just the max generation length now.

https://x.com/zzlccc/status/1903674024150081998

BTW thanks for reviewing my changes!

idoru added 2 commits March 23, 2025 18:51:

  • This part of S3.1 from https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf appears to be missing.
  • Is the max_completion_length used here right?

@idoru force-pushed the dr-grpo-fix-length-bias branch from 0b1abfa to 4329f9c on March 23, 2025 22:51
@qgallouedec (Member) commented:

My current opinion is that there isn't really any data to show how these different ways of normalizing compare. To decide on the relevance of adding this option, more data would be needed. If someone in the community would like to carry out an experiment, that would be very useful.

@lkevinzc commented Mar 25, 2025

We mean any constant that is fixed for the whole training (max_gen_len is a good choice), so this wouldn't be the per-batch constant that .sum() computes. Using batch-level length summation to normalize can only alleviate, not remove, the length bias.

I am writing a summary of this, which hopefully clarifies the length bias so the community can fix it.
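
As a rough illustration of why batch-level summation only alleviates the bias (all values below are made up):

import torch

max_gen_len = 8
short = torch.tensor([1., 1., 0., 0., 0., 0., 0., 0.])   # sequence with 2 valid tokens
long_ = torch.tensor([1., 1., 1., 1., 1., 1., 1., 1.])   # sequence with 8 valid tokens
mask = torch.stack([short, long_])

# Batch-level normalization: each token's weight is 1 / mask.sum(), so it still
# shrinks when other sequences in the batch happen to be long (1/10 here, but
# 1/4 if both sequences were short).
w_batch = 1.0 / mask.sum()

# Constant normalization: each token's weight is 1 / max_gen_len regardless of
# what the rest of the batch looks like (always 1/8).
w_const = 1.0 / max_gen_len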

@qgallouedec (Member) commented:

Got it. Let's use the max generation length. Or maybe even max_completion_length * gradient_accumulation_steps, to keep consistency when accumulating the gradient.

@lkevinzc commented:

Sure! Any constant will do to remove the length bias. The scale of the constant will affect the gradient magnitude, so max_completion_length * gradient_accumulation_steps is a good suggestion!

In fact, DAPO's loss is something in between: it still uses a form of length normalization (at the question level), which is biased as well. May I suggest that we remove it to avoid future confusion?

@qgallouedec (Member) commented:

Yes, I think your approach should be the default. Maybe even the only option.

@qgallouedec (Member) commented:

I'll run some experiments tomorrow.

@lkevinzc commented:

Cool @qgallouedec, thank you so much for the effort in pushing open-source LLM RL forward!

By the way, maybe we could reconsider the name of the flag. I guess scale_rewards can't describe both the std scaling and the length bias.

If your experiment results turn out well, maybe we could consider:

  1. Replacing length normalization with constant normalization for all algorithms (including GRPO and PPO) to make them unbiased (and faithfully follow the policy-gradient math).
  2. If 1) is the default, then using that flag to give an option to scale the reward is okay, though I think this might also become the default in the future, as suggested by Tobi Lutke.

@minosvasilias commented:

> My current opinion is that there isn't really any data to show how these different ways of normalizing compare. To decide on the relevance of adding this option, more data would be needed. If someone in the community would like to carry out an experiment, that would be very useful.

Sharing some experiments in case they're helpful.

Context: I have been using GRPO on a dataset of string transformations ("Replace word X with Y", "Remove all instances of character X", "Remove all punctuation", etc.) and have experienced consistent collapse across different base models, where the model sharply devolves into generating very long sequences of gibberish tokens. Long completions with low advantage being penalised less, due to the length bias explained in Dr. GRPO, could potentially explain or contribute to this.

Below are three sets of runs:

  • Blue: trl==0.16.0 release state:
    • loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
  • Red: trl==0.16.0 state with modified loss from this PR:
    • loss = (per_token_loss * completion_mask).sum() / self.max_completion_length
  • Green: trl==0.16.0 state with old (pre 0.16.0) loss:
    • loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
[Screenshot: training curves for the three sets of runs (2025-03-26)]
Hyperparams (same for all runs)
max_prompt_length: 1024
max_completion_length: 2048
per_device_train_batch_size: 4
num_generations: 4
num_iterations: 4
beta: 0.0
gradient_accumulation_steps: 20
learning_rate: 4e-6
max_grad_norm: 0.2
lr_scheduler_type: constant_with_warmup
warmup_steps: 20
adam_beta1: 0.9
adam_beta2: 0.99
num_train_epochs: 1
torch_dtype: bfloat16
bf16: True
scale_rewards: False

There's quite a spread here due to the nature of the data and hyperparams, so please take this with a grain of salt, but the trends seem to indicate that 0.16.0 collapses most quickly, constant normalization is somewhat more stable, and the old loss, while very volatile, delays format collapse the longest.
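
For reference, the three variants above can be compared side by side on dummy tensors; this is just a sketch of the formulas under the hyperparams listed here, not the trainer code:

import torch

per_token_loss = torch.randn(4, 2048)                   # (per_device_train_batch_size, max_completion_length)
completion_mask = (torch.rand(4, 2048) > 0.5).float()   # dummy valid-token mask
max_completion_length = 2048

# Blue (trl==0.16.0): global token-count normalization
loss_blue = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Red (this PR): constant normalization by the generation budget
loss_red = (per_token_loss * completion_mask).sum() / max_completion_length

# Green (pre-0.16.0): per-sequence normalization, then mean over sequences
loss_green = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()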

@qgallouedec (Member) commented:

Thanks a lot for sharing this. That's very helpful.

That's rather surprising. What surprises me most is the red curve: why would the model start generating long sequences if the objective isn't biased toward long sequences?

@lkevinzc commented:
Hey @minosvasilias, thanks a lot for the results!

Instead of this one you tried

loss = (per_token_loss * completion_mask).sum() / self.max_completion_length

Could you please try this:

loss = ((per_token_loss * completion_mask).sum(-1) / self.max_completion_length).mean()

The difference is that the latter does a constant normalization over the sequence length and then averages over samples. This should give a smaller loss magnitude (depending on loss.shape[0]) compared to your previous trial.
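
To spell out the difference in scale between the two variants, a hedged sketch with dummy tensors (loss.shape[0] is the number of sampled completions on the device):

import torch

per_token_loss = torch.randn(4, 2048)                   # dummy (batch, max_len) losses
completion_mask = (torch.rand(4, 2048) > 0.5).float()   # dummy valid-token mask
max_completion_length = 2048

# Earlier trial: one global sum divided by the budget.
loss_a = (per_token_loss * completion_mask).sum() / max_completion_length

# Suggested variant: normalize each sequence by the budget, then average over
# sequences; numerically loss_b == loss_a / per_token_loss.shape[0].
loss_b = ((per_token_loss * completion_mask).sum(-1) / max_completion_length).mean()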

@qgallouedec requested a review from Copilot on March 26, 2025 18:35
@Copilot (Copilot AI) left a comment:

Pull Request Overview

This PR addresses the length bias issue in the GRPO trainer by modifying the loss normalization when using a completion mask.

  • Updated loss calculation to replace division by the sum of the mask with division based on a fixed maximum completion length.
  • Adjusted the aggregation and averaging of the per-token loss to mitigate length bias.

@minosvasilias commented:

@lkevinzc thank you for the suggestion! I should have compute available to do at least one run tonight with that config. Will report back once I've got results.

@qgallouedec (Member) commented:

I've gotten some pretty surprising results. Here is the first part, comparing the different loss types:

if self.args.loss_type == "drgrpo":
    loss = ((per_token_loss * completion_mask).sum(dim=-1)).mean() / self.max_completion_length
elif self.args.loss_type == "dapo":
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
elif self.args.loss_type == "grpo":
    loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

[Charts: reward, reward_std, completion_length, loss]

@qgallouedec (Member) commented:

And regarding scaling the reward:

if self.args.scale_rewards:
    advantages = advantages / (std_grouped_rewards + 1e-4)

[Charts: reward, reward_std, completion_length, and a W&B chart from 26/03/2025]

Maybe the absolute value of the rewards is more important when you don't scale the rewards?

@lkevinzc commented:

Thanks for the results! Agreed that the reward scale matters when we don't scale it for the advantage. I guess that to get better stability we could do global reward normalization (a common trick in RL), which is different from GRPO's intra-group normalization.

BTW, may I know your task and reward function?

A few observations:

1/ For the first part, it seems the base policy may already have converged to a fixed behavior (not exploratory enough), so the response length hardly changes? If that's the case, it would make exploitation of the length bias less likely, which is why the response lengths for all runs are about the same.

2/ Also for the first part, it's surprising that drgrpo has larger losses than grpo. The denominator of the drgrpo loss should always be the largest (max completion length >= active tokens), so the average loss should be smaller?
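
A rough sketch of the difference between intra-group and global reward normalization mentioned above (the shapes and the 1e-4 epsilon are illustrative, not the trainer's exact code):

import torch

num_prompts, num_generations = 8, 4
rewards = torch.randn(num_prompts, num_generations)   # dummy rewards, one row per prompt

# GRPO-style intra-group normalization: center and scale within each prompt's group.
group_mean = rewards.mean(dim=1, keepdim=True)
group_std = rewards.std(dim=1, keepdim=True)
advantages_group = (rewards - group_mean) / (group_std + 1e-4)

# Global reward normalization (a common RL trick): center and scale over the whole batch.
advantages_global = (rewards - rewards.mean()) / (rewards.std() + 1e-4)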

@minosvasilias commented:

Apologies for the delayed update.
Results from runs with loss = ((per_token_loss * completion_mask).sum(-1) / self.max_completion_length).mean() in orange:
[Screenshot: training curves including the new orange runs (2025-04-01)]

Very similar behavior to the other runs, except for the (green) pre-0.16.0 loss.

Looking at the loss graph, we can see that it spikes at a somewhat lower magnitude, as you predicted @lkevinzc.

[Screenshot: loss curves (2025-04-01)]

@qgallouedec mentioned this pull request on Apr 7, 2025.
@qgallouedec (Member) commented:

Closed in favour of #3256

@qgallouedec closed this on Apr 9, 2025.