Aligning Large Language Models with BRAIn

Community Article Published June 11, 2024

TL;DR: We introduce BRAIn, a distribution matching approach for RLHF that achieves SOTA performance on the Anthropic HH and TL;DR summarization tasks, outperforming DPO and other RLHF methods!

Note: This work has been accepted for publication at ICML 2024 (main conference).

Link to arXiv version: https://arxiv.org/pdf/2402.02479


The 3 phases of LLM training

In the past few years, large language models (LLMs) have demonstrated immense prowess on a wide variety of tasks, including multi-turn conversations, creative writing, and mathematical and logical reasoning. These models are often trained in three phases:

  • Self-supervised pretraining
  • Supervised instruction tuning (SFT)
  • Reinforcement Learning from Human Feedback (RLHF)

The self-supervised pretraining phase induces language understanding and generation capabilities in the model, while the supervised instruction tuning phase teaches the model to follow natural language instructions. In the RLHF phase, the model is encouraged to exhibit behaviours that are desirable to us. The notion of desirability could be explicit (for instance, no profanity in the output) or implicit in human preferences for certain outputs over others.

Alignment to human preferences

So, how does one go about aligning a model to human preferences? PPO-RLHF, the RLHF approach behind GPT-3.5 and GPT-4, achieves this by first training a reward model to mimic human preferences; that is, the reward should be higher for human-preferred outputs than for others. Then, the LLM (also referred to as the policy) is finetuned to generate outputs that receive a high reward from the reward model. We also ensure that the aligned LLM stays close to the SFT LLM (the supervised instruction tuned LLM), thereby preventing it from forgetting the capabilities it acquired in the previous two phases.
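For reference, the "high reward while staying close to the SFT model" recipe is commonly written as the KL-regularized objective below, together with its well-known closed-form maximizer (the PPO-optimal policy discussed next). This is a standard textbook sketch rather than a formula taken from the BRAIn paper; $\pi$ denotes the policy, $r(x,y)$ the reward model's score for output $y$ given prompt $x$, $p_{SFT}$ the SFT model, and $\beta > 0$ the strength of the KL constraint, all introduced here for illustration.

```latex
\max_{\pi} \; \mathbb{E}_{y \sim \pi(\cdot \mid x)}\big[r(x, y)\big]
  - \beta \, \mathrm{KL}\big(\pi(\cdot \mid x) \,\|\, p_{SFT}(\cdot \mid x)\big)
\quad\Longrightarrow\quad
\pi^*(y \mid x) = \frac{p_{SFT}(y \mid x)\, \exp\big(r(x, y)/\beta\big)}
                       {\sum_{y'} p_{SFT}(y' \mid x)\, \exp\big(r(x, y')/\beta\big)}
```

The sum in the denominator runs over all possible outputs, which is why this policy is easy to write down but hard to sample from directly.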

PPO-RLHF has lately been replaced by offline contrastive techniques such as Sequence Likelihood Calibration (SLiC), Direct Preference Optimization (DPO) and its variants. These approaches train the LLM to contrast the preferred/high-reward outputs with the rejected/low-reward outputs. DPO has emerged as the de facto method for aligning high-performing models such as Zephyr, Mixtral and Llama-3.

While the PPO-RLHF and DPO algorithms differ substantially, both have the same final target, referred to as the PPO-optimal policy. Another, less well-known set of methods uses distribution matching to align the LLM to this optimal policy. Ideally, this would require sampling from the optimal policy, which is known to be challenging. Hence, distribution matching methods (DPG, GDC, GDC++) instead sample from a proposal distribution and weight these samples by their importance weights. Despite the clear intuition behind distribution matching, these methods have not been successful for alignment with human feedback.

Our contribution - BRAIn

While investigating why distribution matching methods have not been successful, we observed that their gradient estimates (GDC, GDC++) have high variance. In other words, the update direction at every step varies widely depending on which outputs are sampled from the LLM. We demonstrate this below with a simple toy example:

Assume that the target distribution we are trying to reach is the standard 1D normal distribution $\mathcal{N}(0,1)$. Let the current model distribution be $\mathcal{N}(1,1)$ and the proposal distribution be $\mathcal{N}(\theta,1)$, where $\theta$ is varied from 0 to 1. Below, we plot the variance of the gradient estimate of the different distribution matching objectives with respect to the mean parameter of the model distribution, with samples drawn from the proposal distribution. As can be observed, the variance of the gradient estimates of distribution matching methods (GDC, GDC++) is high when the proposal distribution differs from the target distribution.

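[Figure: variance of the gradient estimate (with respect to the mean of the model distribution) for the different distribution matching objectives, as the proposal mean $\theta$ varies from 0 to 1]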
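To make the toy setup concrete, the NumPy sketch below estimates the variance of a GDC-style gradient estimate (raw importance weights, no self-normalization, no baseline) with respect to the model mean $\mu = 1$, for several proposal means $\theta$. It is our own illustration, not the code behind the figure: the estimator form, the sample size `n = 32`, the function names, and the closed-form check (which we derived for this specific Gaussian setup) are all assumptions of the sketch.

```python
import numpy as np


def log_normal(y, mean):
    """Log-density of N(mean, 1) evaluated at y."""
    return -0.5 * (y - mean) ** 2 - 0.5 * np.log(2 * np.pi)


def gdc_style_gradient(theta, mu=1.0, n=32, rng=None):
    """One importance-sampled estimate of E_{p_T}[d/dmu log q_mu(y)].

    Target p_T = N(0, 1), model q_mu = N(mu, 1), proposal q = N(theta, 1).
    The weights p_T(y) / q(y) are raw importance weights (not self-normalized).
    """
    rng = np.random.default_rng() if rng is None else rng
    y = rng.normal(loc=theta, scale=1.0, size=n)            # samples from the proposal
    w = np.exp(log_normal(y, 0.0) - log_normal(y, theta))   # importance weights p_T / q
    score = y - mu                                          # d/dmu log N(y; mu, 1)
    return np.mean(w * score)


def empirical_variance(theta, n=32, trials=4000, seed=0):
    """Variance of the estimate across repeated independent draws."""
    rng = np.random.default_rng(seed)
    estimates = [gdc_style_gradient(theta, n=n, rng=rng) for _ in range(trials)]
    return float(np.var(estimates))


def analytic_variance(theta, n=32):
    """Closed form for this toy setup with mu = 1 (derived for this sketch):
    per-sample variance exp(theta^2) * (theta^2 + 2*theta + 2) - 1, divided by n."""
    return (np.exp(theta ** 2) * (theta ** 2 + 2 * theta + 2) - 1) / n


for theta in np.linspace(0.0, 1.0, 5):
    print(f"theta={theta:.2f}  empirical={empirical_variance(theta):.4f}  "
          f"analytic={analytic_variance(theta):.4f}")
```

At $\theta = 0$ the proposal equals the target, every weight is exactly 1 and the variance is $1/n$; as $\theta$ moves away from the target mean, the weights become uneven and the variance grows (to roughly $12.6/n$ at $\theta = 1$ by the closed form).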

This investigation motivated us to create BRAIn (Bayesian Reward-conditioned Amortized Inference), which extends distribution matching methods as follows:

  • We generalize the target distribution in PPO-RLHF, DPO and distribution matching methods by using Bayes' rule to incorporate the reward-modelling assumptions.
  • We propose a self-normalized baseline that significantly reduces the variance of the gradient estimate in distribution matching, as shown in the figure above. By incorporating the self-normalized baseline, we achieve SOTA performance on the TL;DR summarization and Anthropic Helpful & Harmless (HH) response generation tasks, and establish DPO as a special case of BRAIn.

The BRAIn objective

Posterior as target

Given an input prompt $x$, the different RLHF algorithms attempt to reach a target distribution $p_T(y|x)$ over the set of outputs $y$. This target distribution depends on two factors:

  • The base distribution, often an SFT model, referred to as $p_{SFT}$
  • The reward function $r(x,y)$

BRAIn uses Bayes' rule to combine the information from these two factors. Specifically, the SFT model acts as the prior, while the reward function is used to define a likelihood term. The resulting posterior is the target $p_T(y|x)$.

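[Figure: the target $p_T(y|x)$ as the posterior obtained from the SFT prior $p_{SFT}$ and the reward-based likelihood via Bayes' rule]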
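In generic form, Bayes' rule gives the target below. Here the reward-based likelihood is written abstractly as $p(\mathcal{G} \mid x, y)$, a placeholder symbol introduced for this sketch; the exact likelihood BRAIn derives from the reward-modelling assumptions is given in the paper. The denominator sums over all outputs, which is the intractable normalization constant referred to below.

```latex
p_T(y \mid x)
  = \frac{p_{SFT}(y \mid x)\; p(\mathcal{G} \mid x, y)}
         {\sum_{y'} p_{SFT}(y' \mid x)\; p(\mathcal{G} \mid x, y')}
```

As one familiar special case, choosing a likelihood proportional to $\exp(r(x,y)/\beta)$ for a temperature $\beta > 0$ recovers the PPO-optimal policy.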

Training with importance weights

Let $q_\theta$ be the model that we wish to align to the target $p_T(y|x)$. Ideally, one could achieve this by sampling from the target and training $q_\theta$ on these samples, as shown below.

[Figure: training objective for $q_\theta$ using samples drawn from the target $p_T$]
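Training $q_\theta$ on samples from the target amounts to maximum likelihood estimation. Since the entropy of $p_T$ does not depend on $\theta$, this is equivalent to minimizing the forward KL divergence from $p_T$ to $q_\theta$. Written out in this post's notation (a standard identity, offered as a worked equation rather than a reproduction of the figure):

```latex
\max_\theta \; \mathbb{E}_{y \sim p_T(\cdot \mid x)}\big[\log q_\theta(y \mid x)\big]
\;\Longleftrightarrow\;
\min_\theta \; \mathrm{KL}\big(p_T(\cdot \mid x) \,\|\, q_\theta(\cdot \mid x)\big)
```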

However, since sampling from the target can be challenging, we instead sample from a proposal distribution $q(y|x)$ and reweight those samples by $\frac{p_T(y|x)}{q(y|x)}$. Since the normalization constant of $p_T$ is intractable, we self-normalize the weights as shown below:

[Figure: objective with self-normalized importance weights $\hat{\alpha}_{y_i}$]
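The self-normalization step can be sketched in a few lines of NumPy. With $n$ samples $y_i$ from the proposal, the weights are $\hat{\alpha}_{y_i} = w_i / \sum_j w_j$ where $w_i = p_T(y_i|x)/q(y_i|x)$, and $p_T$ only needs to be known up to a constant. For illustration we assume the likelihood $\exp(r/\beta)$ with $\beta = 0.1$ (the PPO-optimal special case, not necessarily BRAIn's own likelihood); the function names and all sample values are hypothetical.

```python
import numpy as np


def log_target_unnormalized(log_p_sft, reward, beta=0.1):
    """Hypothetical unnormalized log-target: log p_SFT(y|x) + r(x, y) / beta.

    The exp(r / beta) likelihood is an illustrative assumption (it gives the
    PPO-optimal target); BRAIn's own likelihood is defined in the paper.
    """
    return log_p_sft + reward / beta


def self_normalized_weights(log_target, log_proposal):
    """alpha_hat_i proportional to p_T(y_i|x) / q(y_i|x), normalized to sum to 1.

    Any unknown log-normalizer of p_T shifts every log-weight by the same
    constant, so it cancels in the normalization below.
    """
    log_w = np.asarray(log_target) - np.asarray(log_proposal)
    log_w = log_w - log_w.max()          # subtract the max for numerical stability
    w = np.exp(log_w)
    return w / w.sum()


# Four hypothetical samples for one prompt, all drawn from the SFT model (q = p_SFT).
log_p_sft = np.array([-12.0, -15.0, -9.5, -20.0])   # made-up sequence log-probs
reward = np.array([0.2, 0.9, -0.3, 0.5])            # made-up reward-model scores
alpha_hat = self_normalized_weights(log_target_unnormalized(log_p_sft, reward), log_p_sft)
print(alpha_hat.round(3))
```

When the proposal is the SFT model itself, the $p_{SFT}$ terms cancel and, under this assumed likelihood, $\hat{\alpha}$ reduces to a softmax of $r/\beta$ over the sampled outputs.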

A note on the proposal distribution: what is the ideal distribution to generate samples from? Since we are trying to reach the target, we should ideally sample from the target $p_T$. As this is challenging, we instead sample from a distribution that is as close to the target as possible. At the beginning of training, we sample from the SFT model $p_{SFT}$; as training proceeds, we also include samples from the latest policy.

A self-normalized baseline to reduce variance

The gradient of the above objective is given by

$$\nabla_\theta \mathcal{L}(\theta) = \sum_{i=1}^n \hat{\alpha}_{y_i} \nabla_\theta \log q_\theta(y_i|x)$$

GDC uses this same gradient estimate for LLM alignment, except that its weights are not self-normalized. As shown earlier, the GDC gradient estimate has high variance, which translates into poor performance.

To reduce the variance, we propose to subtract a self-normalized baseline from the above gradient estimate as shown below:

[Figure: BRAIn gradient estimate with the self-normalized baseline subtracted]
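A minimal NumPy sketch of a baseline-corrected update, under an assumption about the baseline's form: we take baseline weights $\hat{\gamma}_{y_i} \propto q_\theta(y_i|x)/q(y_i|x)$, self-normalized over the same samples (our reading; the paper gives the exact expression). The motivation for such a baseline is the score-function identity $\mathbb{E}_{y \sim q_\theta}[\nabla_\theta \log q_\theta(y|x)] = 0$: the sum $\sum_i \hat{\gamma}_{y_i} \nabla_\theta \log q_\theta(y_i|x)$ is a self-normalized estimate of zero, so subtracting it leaves the expected update approximately unchanged, and it can reduce variance when the two sums are positively correlated. To stay runnable without an LLM, the "policy" is a toy categorical distribution over five outputs; all names and numbers are hypothetical.

```python
import numpy as np


def normalize_log_weights(log_w):
    """Self-normalize importance weights given in log space."""
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()


def brain_coefficients(log_target, log_policy, log_proposal):
    """Per-sample coefficients alpha_hat_i - gamma_hat_i on grad log q_theta(y_i|x).

    alpha_hat ~ p_T / q       (target vs. proposal), as in the objective above.
    gamma_hat ~ q_theta / q   (policy vs. proposal): assumed form of the
                              self-normalized baseline.
    """
    alpha_hat = normalize_log_weights(log_target - log_proposal)
    gamma_hat = normalize_log_weights(log_policy - log_proposal)
    return alpha_hat - gamma_hat


def log_softmax(z):
    z = z - np.max(z)
    return z - np.log(np.exp(z).sum())


def categorical_gradient(logits, samples, coeffs):
    """sum_i c_i * grad_logits log softmax(logits)[y_i] for a toy categorical policy."""
    probs = np.exp(log_softmax(logits))
    grad = np.zeros_like(logits)
    for y, c in zip(samples, coeffs):
        grad += c * (np.eye(len(logits))[y] - probs)   # score of a categorical: one-hot minus probs
    return grad


rng = np.random.default_rng(0)
K = 5                                                          # toy output-space size
log_q = log_softmax(np.zeros(K))                               # uniform proposal
log_p_T = log_softmax(np.array([2.0, 1.0, 0.0, -1.0, -2.0]))   # hypothetical target
logits = np.zeros(K)                                           # current policy (uniform)
samples = rng.integers(0, K, size=16)                          # draws from the proposal
coeffs = brain_coefficients(log_p_T[samples], log_softmax(logits)[samples], log_q[samples])
print(categorical_gradient(logits, samples, coeffs).round(3))
```

With this construction the coefficients always sum to zero, and they vanish whenever $q_\theta$ matches $p_T$ up to a constant on the sampled outputs, so the update stops once the policy agrees with the target there.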

While the connection with the distribution matching objective of GDC is direct, we establish the connection with DPO in the paper.


Experimental results

We evaluate BRAIn on two tasks:

  • Summarization: We use the Reddit TL;DR dataset for this task.
  • Helpful & Harmless response generation: We use the Anthropic HH dataset for this task.

We evaluate the models by their win-rate against gold, that is, the fraction of test samples on which the generated response is preferred over the gold response. We compute this quantity using two evaluators: (1) Train RM, the reward function used for aligning the SFT model, and (2) LLM eval, in which we prompt Mixtral 8x7B to compare the two outputs and declare a winner. The performance of BRAIn against other baselines is displayed in the figure below:

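[Figure: win-rate against gold of BRAIn and the baselines on TL;DR and Anthropic HH, under Train RM and LLM eval]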

As can be observed, BRAIn outperforms the other baselines on both evaluation measures.
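For the Train RM measure, the win-rate against gold can be sketched as a simple comparison of reward-model scores per test prompt. This is a minimal illustration: the function name is hypothetical, and counting ties as losses is an assumption; for LLM eval, the per-prompt win would instead come from the judge's verdict.

```python
import numpy as np


def win_rate_vs_gold(reward_generated, reward_gold):
    """Fraction of test prompts on which the generated response beats the gold one.

    A "win" here means a strictly higher reward-model score; counting ties
    as losses is an assumption of this sketch.
    """
    reward_generated = np.asarray(reward_generated, dtype=float)
    reward_gold = np.asarray(reward_gold, dtype=float)
    return float(np.mean(reward_generated > reward_gold))


# Hypothetical scores for four test prompts: wins on prompts 1 and 3.
print(win_rate_vs_gold([1.2, 0.3, -0.5, 2.0], [0.8, 0.9, -1.0, 2.0]))  # 0.5
```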

We also study the impact of the self-normalized baseline on performance. The table below lists the win-rate (in %) of BRAIn with the self-normalized baseline (BRAIn), with baseline subtraction but without self-normalization (w/o self-norm), and without baseline subtraction (w/o baseline). As the table shows, self-normalization is crucial for achieving reasonable performance in distribution matching.

Task           BRAIn   w/o self-norm   w/o baseline
TL;DR          95.2    61.4            61.1
Anthropic HH   95.4    59.1            58.3

Other Blogs by our team

From Fiction to Fact: Making Chatbots Grounded in Reality

Acknowledgements

This work was done in collaboration with Ramón Fernandez Astudillo, Yatin Nandwani, Tahira Naseem, Mayank Mishra, Guangxuan Xu, Dinesh Raghu, Sachindra Joshi and Asim Munawar.
