Dr. GRPO with Gradient Regularization

Implementing Dr. GRPO and Gradient Regularization from scratch

Methods Used

For this project, I built a Dr. GRPO-based RL pipeline with an explicit Gradient Regularization (GR) implementation. Explicit GR has been shown empirically to outperform the implicit regularization provided by a Kullback-Leibler (KL) divergence penalty combined with Reference Resets.
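The Reference-Reset baseline needs two pieces. Here is a minimal sketch of both, under two assumptions. First, it uses the per-token KL estimator from the original GRPO paper: pi_ref/pi_theta - log(pi_ref/pi_theta) - 1. Second, it uses a simple reset rule that copies the current policy weights into the reference model every `reset_every` steps. The toy `nn.Linear` policy and the names `token_logprobs`, `maybe_reset_reference`, and `reset_every` are hypothetical stand-ins, not the pipeline's actual code. Some reset recipes also reinitialize optimizer state at that point, which is omitted here.

```python
import copy

import torch
import torch.nn as nn


def k3_kl(logp_policy: torch.Tensor, logp_ref: torch.Tensor) -> torch.Tensor:
    """Per-token KL estimator used in GRPO: r - log(r) - 1, with r = pi_ref / pi_theta."""
    log_ratio = logp_ref - logp_policy
    return torch.exp(log_ratio) - log_ratio - 1.0


def token_logprobs(model: nn.Module, x: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Log-probability of each chosen action under a toy categorical policy (hypothetical helper)."""
    logp = torch.log_softmax(model(x), dim=-1)
    return logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)


def maybe_reset_reference(step: int, reset_every: int, policy: nn.Module, ref: nn.Module) -> bool:
    """Hypothetical reset schedule: every `reset_every` steps, copy the policy into the reference."""
    if step > 0 and step % reset_every == 0:
        ref.load_state_dict(policy.state_dict())
        return True
    return False


torch.manual_seed(0)
policy = nn.Linear(4, 3)  # toy stand-in for the policy LLM
ref = copy.deepcopy(policy).requires_grad_(False)  # frozen reference copy
x = torch.randn(5, 4)
actions = torch.randint(0, 3, (5,))

# Simulate training drift: a random (non-uniform) perturbation of the policy weights.
with torch.no_grad():
    policy.weight.add_(torch.randn_like(policy.weight))
kl_before = k3_kl(token_logprobs(policy, x, actions), token_logprobs(ref, x, actions)).mean()

# After a reset, the reference matches the current policy again.
maybe_reset_reference(step=100, reset_every=100, policy=policy, ref=ref)
kl_after = k3_kl(token_logprobs(policy, x, actions), token_logprobs(ref, x, actions)).mean()
```

Right after a reset the KL term is exactly zero. The penalty therefore only constrains how far the policy drifts from its most recent snapshot, rather than from the original model.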

A quick clarification on terminology: Dr. GRPO here refers to "GRPO Done Right" (https://arxiv.org/abs/2503.20783). The paper identifies biases built into the normalization terms of the original GRPO objective.

The first is a length bias. Dividing each response's summed token loss by its length |o_i| favors brevity among correct answers and verbosity among incorrect ones. This happens because the 1/|o_i| factor scales the advantage, which is positive when the answer is correct and negative when it is incorrect. For a wrong answer, a longer response spreads the negative advantage over more tokens, so each token is penalized less. Conversely, for a correct answer, a shorter response concentrates the positive advantage over fewer tokens, so each token receives a larger update.

The second bias comes from dividing by the group's reward standard deviation during the advantage calculation itself; the paper calls this "question-level difficulty bias". Low-std questions receive larger parameter updates, because their advantages are divided by a smaller value. Such questions arise when a question is too easy (nearly every answer in the group is correct) or too hard (nearly every answer is incorrect). Dr. GRPO removes both normalization terms.
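To make both biases concrete, here is a small plain-Python sketch comparing GRPO's per-group advantage and per-token weighting with Dr. GRPO's. It makes three assumptions:

- rewards are binary 0/1;
- the standard deviation is the population std (some implementations use the sample std);
- Dr. GRPO normalizes by a hypothetical constant `max_tokens` budget.

The function names are illustrative only.

```python
from statistics import mean, pstdev


def grpo_advantages(rewards, eps=1e-6):
    """Original GRPO: (r - mean) / (std + eps) within one group of sampled answers."""
    mu, sigma = mean(rewards), pstdev(rewards)  # population std (an assumption)
    return [(r - mu) / (sigma + eps) for r in rewards]


def dr_grpo_advantages(rewards):
    """Dr. GRPO: mean-centering only, no division by the std."""
    mu = mean(rewards)
    return [r - mu for r in rewards]


def per_token_weight_grpo(advantage, length):
    """Weight each token of a response gets under GRPO's 1/|o_i| normalization."""
    return advantage / length


def per_token_weight_dr_grpo(advantage, max_tokens):
    """Dr. GRPO-style: divide by a constant (hypothetical budget), independent of length."""
    return advantage / max_tokens


easy = [1, 1, 1, 1, 1, 1, 1, 0]      # 7/8 correct: low (but nonzero) std
balanced = [1, 1, 1, 1, 0, 0, 0, 0]  # 4/8 correct: std = 0.5
for name, group in [("easy", easy), ("balanced", balanced)]:
    # The ratio GRPO / Dr. GRPO equals 1 / (std + eps): the difficulty-dependent rescaling.
    scale = grpo_advantages(group)[0] / dr_grpo_advantages(group)[0]
    print(f"{name}: GRPO rescales advantages by {scale:.2f}x")

# Length bias: the same negative advantage, spread over 10 vs 100 tokens.
print(per_token_weight_grpo(-1.0, 10), per_token_weight_grpo(-1.0, 100))
print(per_token_weight_dr_grpo(-1.0, 512), per_token_weight_dr_grpo(-1.0, 512))
```

Both effects show up directly in the output:

- The easy group's advantages are scaled by roughly 3x, versus 2x for the balanced group.
- Under GRPO, a 100-token wrong answer is penalized 10x less per token than a 10-token one.

Both effects disappear once the normalizers are removed.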

Gradient Regularization (https://arxiv.org/pdf/2602.18037) emphasizes stability in the gradient updates during training. Rather than penalizing the weights directly (as L1/L2 regularization does), gradient regularization pushes the weights toward flatter local minima in the loss landscape by penalizing large gradients, which are characteristic of sharp minima. This encourages generalization and, in Proxy Reward settings (such as LLM-as-a-Judge), improves proxy rewards as well!
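As a rough illustration of what "explicit" gradient regularization means, the sketch below adds a squared gradient-norm penalty, lam * ||grad L||^2, to a loss. It differentiates through that penalty with double backprop (`create_graph=True`). This generic formulation is assumed here for illustration. The paper's exact penalty may differ, for example in its choice of norm, or by using a cheaper finite-difference approximation. `gradient_regularized_loss` and `lam` are hypothetical names. Two toy 1-D quadratics stand in for a sharp and a flat minimum, evaluated at points with the same raw loss.

```python
import torch


def gradient_regularized_loss(loss_fn, params, lam):
    """Return (loss + lam * ||grad loss||^2, detached penalty): a generic squared-gradient-norm penalty."""
    loss = loss_fn(params)
    # create_graph=True keeps the graph of the gradient, so the penalty itself is differentiable.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    grad_sq_norm = sum((g ** 2).sum() for g in grads)
    return loss + lam * grad_sq_norm, grad_sq_norm.detach()


def sharp(p):
    """Sharp toy minimum at w = 0 (high curvature)."""
    return 10.0 * p[0] ** 2


def flat(p):
    """Flat toy minimum at w = 0 (low curvature)."""
    return 0.1 * p[0] ** 2


# Points where both raw losses equal 0.1.
w_sharp = torch.tensor(0.1, dtype=torch.float64, requires_grad=True)
w_flat = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)

total_sharp, pen_sharp = gradient_regularized_loss(sharp, [w_sharp], lam=0.1)
total_flat, pen_flat = gradient_regularized_loss(flat, [w_flat], lam=0.1)

# Because the penalty is part of the graph, an ordinary backward() pass includes it.
total_sharp.backward()
```

At equal raw loss (0.1), the sharp point pays a penalty of 4.0 versus 0.04 for the flat one. This is the pressure toward flatter regions described above. Double backprop adds a second backward pass, so it is noticeably more expensive at LLM scale.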

Motivation

I found this paper very interesting because of its proposition that GR techniques improve generalizability by biasing the policy toward "flat rewards" (and that these flat rewards in turn imply a robust policy). That would prove very useful when training on domains that require this ability to generalize. Furthermore, the paper claims that GR mitigates Reward Hacking in these Proxy Reward settings. This combination of claims is what led me to actually implement the paper, as I believe it would translate into far better performance on science-based tasks. Deterministic regimes such as coding and math have strictly verifiable rewards that circumvent the need for LLM-as-a-Judge proxies. Domains that rely on such judges, however, suffer greatly from inaccurate proxy rewards, in large part due to Reward Hacking (https://arxiv.org/abs/2507.08794, https://arxiv.org/pdf/2507.17849, https://arxiv.org/pdf/2409.12822).

The paper itself focused primarily on two datasets: GSM8k and AlpacaFarm. The former, Grade School Math 8k, has verifiable answers, so it cannot establish GR's performance gains in Proxy Reward settings. AlpacaFarm, on the other hand, focuses on learning from human feedback, which falls short in its ability to measure GR's gains in generalizability. (Strictly speaking, measuring "generalizability" is a tall ask and an active area of research in its own right... I will probably just resort to adversarial perturbations as a preliminary measure.)

Problem Formulation

That prompted me to investigate further with my own experiments, specifically within the science domain! I would have liked to train on the open-ended subset of the Dr. SCI dataset (https://arxiv.org/pdf/2602.08321), as it seems perfect for the task at hand. However, since the dataset has yet to be made public, I pivoted to the 60k biology+chemistry+physics dataset released by CamelAI. It is not an MCQ-based dataset, so by working through the long reasoning chains this difficult dataset requires, the trained models should hopefully demonstrate a deeper understanding of the underlying concepts. The trained model will then be compared against a Reference-Reset-based GRPO model and a baseline GRPO model (ablating the Gradient Regularization component) on the MMLU-Pro dataset.

(Edit 3/15: As of today, there does seem to be a Hugging Face dataset repo that follows the Dr. SCI paper to generate and source the dataset [https://huggingface.co/datasets/MiniByte-666/Dr.SCI]! Unfortunately, I had already finished my training run by the time it was uploaded, so the rest of my analysis is based on the CamelAI dataset.)

Notes during Implementation

This project was an interesting start, and I ended up going down several rabbit holes I had not initially planned on. FSDP (Fully Sharded Data Parallel) was something of an afterthought: I had heard of it before, but I had never looked too deeply into it. When I started actually conducting these training runs myself, it became apparent that I needed some form of data parallelism to avoid blowing my wallet on GPUs. Two things surfaced from my attempted training runs: how FSDP interacts with PEFT methods that use quantization (QLoRA), and the bitsandbytes (bnb) config that interaction required. I learned a lot from this project, on both the implementation and the theory side!
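For reference, here is roughly what a QLoRA quantization config looks like when following Hugging Face's FSDP + QLoRA guidance. This is an assumed setup, not necessarily the exact config used for these runs. The FSDP-specific detail is `bnb_4bit_quant_storage`. It should match the dtype the rest of the model is loaded in (e.g. `torch_dtype=torch.bfloat16` in `from_pretrained`), so that FSDP can flatten and shard all parameters together.

```python
import torch
from transformers import BitsAndBytesConfig

# Assumed 4-bit NF4 QLoRA setup, adapted for FSDP.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Store the packed 4-bit weights in a float dtype (default is uint8) so FSDP can
    # shard them alongside the model's other bf16 parameters.
    bnb_4bit_quant_storage=torch.bfloat16,
)
```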

If I may be perfectly candid, what stood out most during this process were the implementation subtleties that never surface when reading the papers themselves. Perhaps this is somewhat indicative of my lack of exposure to implementation, which I am attempting to change :) My background in Machine Learning stemmed from naïve numpy implementations of Transformers, and all the logic within their underlying layers, during my MS. From there, upon joining industry, I quickly transitioned to a higher-level, abstracted view that mostly involved making calls to the LLM. I had always been obsessed with building from scratch, so I was blindsided by all the abstractions PyTorch introduces to make my life easier.

Of course, I knew that during training we introduce dropout layers to further regularize the weights and prevent overspecialization of neurons, and that we backpropagate the gradient from the final loss using the chain rule. What was missing from those papers and lectures were the practical details:

- For evaluation during training, we have to explicitly turn off dropout layers.
- PyTorch's autograd records every operation on tensors that require gradients so it can backpropagate through them later. We therefore have to turn that tracking off in specific places, such as when sampling rollouts during decoding.
- Tensors constantly need to be reshaped into the dimensions an operation expects.
- There are n other minutiae related to actually implementing these systems.
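Here is a minimal sketch of both switches on a toy model, a stand-in for the policy rather than the actual pipeline:

- `model.eval()` turns dropout off, so repeated forward passes are deterministic.
- `torch.no_grad()` stops autograd from recording a graph, so the outputs carry no `grad_fn`.

`torch.inference_mode()` is a stricter alternative when the outputs are only used for generation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for a policy network that contains dropout.
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.Linear(16, 4))
x = torch.randn(2, 16)

# Training mode: dropout is active and autograd records the graph for backward().
model.train()
train_out = model(x)

# Rollout / evaluation mode: disable dropout and stop recording the graph.
model.eval()
with torch.no_grad():
    eval_out_1 = model(x)
    eval_out_2 = model(x)
model.train()  # switch back before the next optimization step
```

Forgetting the final `model.train()` silently trains with dropout disabled, which is exactly the kind of bug that never shows up in the equations.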

Admittedly, it was slightly demoralizing when I first started the implementation. I knew all the math behind it, so why was it so hard just to figure out which of the n PyTorch functions to use? Initially, my ego bristled at the thought of using AI even to figure out what function to use; it felt like cheating. Eventually, I came to terms with it: if my goal was to learn how to build intricate systems, this was the best way. I resigned myself to using it as a supporting aid, mainly to double-check implementations.

For example, I had tried implementing the cross-entropy loss myself (as I didn't know there was a PyTorch function for it...). I took the log softmax and used torch.gather to isolate the log probs of the ground-truth tokens. When I checked my code with AI, I found that while my implementation was technically sound, it had a major computational inefficiency! It stemmed from calling F.log_softmax(shift_logits, dim=-1), which computes the log prob of every token in the vocabulary. That materializes an enormous intermediate tensor of shape [batch_size, seq_len, vocab_size] in GPU memory. The proposed solution was to use -F.cross_entropy with reduction="none", which returns the per-token log probs of the ground-truth tokens in a single call. One caveat: in eager-mode PyTorch, F.cross_entropy is itself implemented as a log_softmax followed by nll_loss, so it still builds that vocabulary-sized intermediate. Fully avoiding it requires a fused cross-entropy kernel that computes the normalization on the fly.
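Here is a small sketch checking that the two approaches agree, with random logits standing in for model outputs (the shapes are arbitrary). It also shows the reshape `F.cross_entropy` needs, since it expects either the class dimension second or flattened `[N, C]` inputs. It is meant to illustrate the API, not to benchmark memory.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for model outputs
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Next-token prediction: position t predicts token t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

# Approach 1: log_softmax over the full vocabulary, then gather the target token's log prob.
logp_all = F.log_softmax(shift_logits, dim=-1)  # [B, T-1, V]
logp_gather = torch.gather(logp_all, -1, shift_labels.unsqueeze(-1)).squeeze(-1)

# Approach 2: negated per-token cross-entropy. Flatten batch and time so the
# class dimension is last of a 2-D [N, V] input, then restore [B, T-1].
logp_ce = -F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
    reduction="none",
).view(batch_size, seq_len - 1)
```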