Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🩺 Dr. GRPO loss #3256

Merged
merged 18 commits into from
Apr 9, 2025
Merged

🩺 Dr. GRPO loss #3256

merged 18 commits into from
Apr 9, 2025

Conversation

qgallouedec
Copy link
Member

@qgallouedec qgallouedec commented Apr 7, 2025

What does this PR do?

This PR supersedes #3231 #3138
Closes #3178

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec qgallouedec marked this pull request as ready for review April 7, 2025 17:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Copy link
Member Author

qgallouedec commented Apr 7, 2025

Same effective batch size (256)

GPUs  Grad accum steps Per device batch size
1 2 128
1 4 64
1 8 32
1 16 16
1 32 8
2 1 128
2 2 64
2 4 32
2 8 16
2 16 8
4 1 64
4 2 32
4 4 16
4 8 8
8 1 32
8 2 16
8 4 8
Screenshot 2025-04-07 at 12 01 43
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

dataset = load_dataset("trl-lib/tldr", split="train[:500]")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

ga = 2
bs = 16

args = GRPOConfig(
    output_dir=f"DrGRPO_bs{bs}_ga_{ga}_4GPU",
    per_device_train_batch_size=bs,
    gradient_accumulation_steps=ga,
    num_train_epochs=1,
    logging_steps=1,
    max_prompt_length=64,
    max_completion_length=64,
    loss_type="drgrpo",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    args=args,
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()

@qgallouedec qgallouedec changed the title Dr. GRPO loss 🩺 Dr. GRPO loss Apr 7, 2025
Copy link
Member

@lewtun lewtun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with some nits and a question about what BNPO refers to

difficulty bias.
applied. The [Dr. GRPO paper](https://huggingface.co/papers/2503.14476) recommends not scaling the rewards,
as scaling by the standard deviation introduces a question-level difficulty bias.
loss_type (`str`, *optional*, defaults to `"bnpo"`):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is bnpo? Would be good to have a reference to where it's defined (I thought we had DAPO as the default loss)

Copy link
Member Author

@qgallouedec qgallouedec Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I realized while doing this PR that it wasn't exactly DAPO that was being used, but a variant of BNPO as defined here :

Screenshot 2025-04-08 at 06 27 50

Let me try to clarify here. Losses per token are normalized by

  • GRPO: the length of the sequence
  • DAPO: the average sequence length in the group
  • BNPO: the average sequence length in the batch
  • TRL's BNPO: the average sequence length in the local batch*; this is what I call bnpo in the code, but it's not 100% correct
  • Dr GRPO: by the maximum possible length of the completion

*a batch is made up of num_devices * gradient_accumulations local batches

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Special cases:
When

  • per_device_batch_size==num_generations, TRL's BNPO is equivalent to DAPO
  • per_device_batch_size==1, TRL's BNPO is equivalent to GRPO
  • gradient_accumualtion_steps==1 and num_devices=1, TRL's BNPO is equivalent to the actual BNPO.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qgallouedec Thanks for the comprehensive support! A minor comment for your future consideration: Dr. GRPO does not constrain the constant normalizer to be MAX_LEN (although it's easier to just use that). This can affect the update scale (related to your recent tweet https://x.com/QGallouedec/status/1908741708021457357). In fact, different constant of x in the setting in your tweet can be absorbed into the constant normalizer we propose in the paper, and MAX_LEN is a convenient example.

slightly vary depending on the local batch size, despite a constant effective batch size.
- `"drgrpo"`: Token-level losses are aggregated by normalizing with a global constant. This method was
introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.14476) to eliminate length bias.
The value of the constant corresponds to `max_completion_length`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, @edbeeching was trying something slightly different in #3231 that did local scaling per batch instead of a global constant. Do you know if there's much difference between the two?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are roughly equivalent, I have closed my PR in favor of this one.

@qgallouedec
Copy link
Member Author

qgallouedec commented Apr 8, 2025

Results mostly match, expect for the the loss and the grad norm, Dr GRPO seems to reduce the range, as expected.

Screenshot 2025-04-08 at 06 19 52 Screenshot 2025-04-08 at 06 19 33

@qgallouedec qgallouedec merged commit 5e2e9cb into main Apr 9, 2025
10 checks passed
@qgallouedec qgallouedec deleted the dr-grpo-loss branch April 9, 2025 18:13
@qgallouedec qgallouedec mentioned this pull request Apr 9, 2025
5 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

the loss is dapo not the dr.grpo ?
6 participants