Improving performance with Arena Learning in post-training

Community Article Published September 11, 2024

The Problem

The effectiveness of chat LLMs relies primarily on the high-quality instruction-following data used in post-training, which enables them to communicate effectively with humans. The challenge, however, lies in curating high-quality instruction data and evaluating it effectively.

Existing methods use platforms such as the LMSYS Chatbot Arena for evaluation, which pits chatbot models against each other in conversational challenges judged by human evaluators. While this approach provides robust and comprehensive evaluations, it is resource- and time-intensive, and its unavoidable dependence on humans limits how far model improvements can scale.

Fig. LMSYS Chatbot Arena: model comparison

Also, because the arena has to prioritize which models it evaluates, most models cannot participate in its evaluations. This creates the need for a more efficient and scalable arena-based pipeline to support LLM post-training and evaluation.

The Solution

Arena Learning is a novel technique that reduces the manual effort and time associated with post-training by introducing an automated training and evaluation pipeline. Its advantages include:

  • Iterative training.
  • Automated evaluation without humans in the loop. A "judge model" automatically imitates human annotators in judging a pair of responses from two models and provides rankings, scores, and explanations.

In this post, we attempt to replicate the training-data curation process and share the relevant scripts. The complete code is available in a Colab notebook.

How it's done

Fig. Overview of Arena Learning (Source: arena-learning-build-data-flywheel-for-llms-post-training-via-simulated-chatbot-arena)

In the post-training scenario shown in the figure, Arena Learning simulates battles between the target model (referred to here as WizardLM-β) and various state-of-the-art models over a large-scale instruction dataset. The synthetic battle results are then used to improve WizardLM-β through several training strategies.

WizardLM-β is then continuously updated and re-evaluated against the SOTA models.

Implementation details

The judge LLM

To assess the response quality of each LLM, the judge is built by prompt engineering the Llama3-70B-Chat model. Its inputs are the dialogue history, the user instruction, and the responses of two LLMs. Its outputs are a score for each LLM, along with explanations focused on factors such as coherence, factual accuracy, context-awareness, and overall quality, which determine which response is superior. To mitigate position bias, a two-game setup is used in which the positions of the two LLMs are swapped.
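A minimal sketch of the two-game judging setup, assuming a hypothetical `judge` callable that sends a prompt to Llama3-70B-Chat (or any chat model) and returns its text reply. The prompt template, the final "Scores:" line format, and summing each model's scores over both games are illustrative assumptions, not the exact prompt or aggregation used by Arena Learning.

```python
import re
from typing import Callable, Tuple

# Hypothetical judge prompt; not the paper's exact wording.
JUDGE_TEMPLATE = """You are an impartial judge. Given the dialogue history, the user instruction
and two assistant responses, rate each response from 1 to 10 for coherence, factual accuracy,
context-awareness and overall quality. Explain briefly, then end with the line
"Scores: <score_A> <score_B>".

[Dialogue history]
{history}

[User instruction]
{instruction}

[Response A]
{response_a}

[Response B]
{response_b}
"""


def parse_scores(judge_output: str) -> Tuple[float, float]:
    """Extract the two scores from the judge's 'Scores: a b' line."""
    match = re.search(r"Scores:\s*([\d.]+)\s+([\d.]+)", judge_output)
    if match is None:
        raise ValueError("judge output has no 'Scores:' line")
    return float(match.group(1)), float(match.group(2))


def two_game_battle(judge: Callable[[str], str], history: str, instruction: str,
                    response_1: str, response_2: str) -> str:
    """Judge the pair twice with positions swapped and compare summed scores (assumption)."""
    game1 = judge(JUDGE_TEMPLATE.format(history=history, instruction=instruction,
                                        response_a=response_1, response_b=response_2))
    game2 = judge(JUDGE_TEMPLATE.format(history=history, instruction=instruction,
                                        response_a=response_2, response_b=response_1))
    s1_game1, s2_game1 = parse_scores(game1)  # model 1 shown in position A
    s2_game2, s1_game2 = parse_scores(game2)  # model 2 shown in position A
    total_1, total_2 = s1_game1 + s1_game2, s2_game1 + s2_game2
    if total_1 > total_2:
        return "model_1"
    if total_2 > total_1:
        return "model_2"
    return "tie"
```

Because each model appears once in each position, a judge that always favors position A (e.g. always answering "Scores: 8 5") produces a tie instead of a spurious win.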

Collecting a large-scale instruction training dataset for Arena Learning

A large-scale corpus of conversational data, D, is required to train WizardLM-β. The initial model is first trained on 10k randomly sampled ShareGPT conversations. Instructions are then collected from the following openly available datasets:

  • WizardLM
  • Stanford Alpaca
  • Stack Exchange preferences
  • LMSYS Chat
  • Flan
  • OpenOrca

The collected instructions are then refined with the following steps:

  1. Filter out all illegal and toxic conversations, using one or more LLMs as classifiers.

Here is sample code that does this with dataformer, a library that sends many asynchronous requests while respecting the rate limits of different API providers and leveraging a cache.

# Install the library first: pip install dataformer (in Colab: !pip install dataformer)
import json

from dataformer.llms.asyncllm import AsyncLLM

PROMPT = (
    "Categorize the text as 'USE' or 'DONT' based on whether it contains any illegal or toxic "
    "language or references. If it contains toxic or illegal content, label it 'DONT'; "
    "otherwise label it 'USE'. The text is as follows:\n\n"
)


def generate_data(data, api_provider, model_name, api_key, max_requests_per_minute,
                  max_tokens_per_minute, max_concurrent_requests):
  llm = AsyncLLM(api_provider=api_provider, model=model_name, api_key=api_key,
                 max_requests_per_minute=max_requests_per_minute,
                 max_tokens_per_minute=max_tokens_per_minute,
                 max_concurrent_requests=max_concurrent_requests)

  request_list = []
  text = []
  # Turn each conversation (a list of {"role", "content"} dicts) into one classification request
  for passage in data:
    # json.dumps escapes quotes and newlines inside the messages, keeping the JSON valid
    conversation_json = json.dumps(
        [{"role": turn["role"], "content": turn["content"]} for turn in passage], indent=2)
    # Rough guard on model input length (characters, not tokens); change or drop as required
    if len(conversation_json) > 8192:
      continue

    user_text = (PROMPT + conversation_json +
                 "\n\nGive results only in categories 'USE' or 'DONT', don't give any other content.")
    request_list.append({"messages": [{"role": "user", "content": user_text}]})
    text.append(passage)

  response_list = llm.generate(request_list)
  # Collect one USE/DONT label per conversation; later keep only the records labeled USE
  answers = []
  for row in response_list:
    try:
      answers.append(row[-1]['choices'][0]['message']['content'])
    except Exception:
      print(row)
      answers.append(None)  # keep answers aligned with text
  return text, answers
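The classifier returns raw labels, so a small post-processing step is needed to obtain `filtered_data`, the list of conversations that the next step iterates over. A minimal sketch, assuming `text` and `answers` come from `generate_data` above; `keep_use_records` is a hypothetical helper, and the normalization rule (strip whitespace, quotes, and periods, then compare case-insensitively) is an assumption about how the model formats its answer.

```python
def keep_use_records(text, answers):
    """Return the conversations whose classifier label is USE (hypothetical helper)."""
    kept = []
    for conversation, label in zip(text, answers):
        # Normalize answers like " 'use'. " to "USE"; failed requests (None) are dropped
        if label is not None and label.strip().strip("'\".").upper() == "USE":
            kept.append(conversation)
    return kept
```

Usage: `filtered_data = keep_use_records(text, answers)`. Anything that is not clearly USE is discarded, which errs on the side of removing conversations the classifier was unsure about.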
  2. Remove conversations with instruction lengths of less than 10. Here is the code to do this (it measures length in characters):
# filtered_data: conversations (lists of {"role", "content"} dicts) labeled USE in step 1
MIN_INSTRUCTION_LEN = 10  # measured in characters; use a tokenizer if you need token counts

token_len_filtered = []
for row in filtered_data:
  new_tp = []
  for convo_id, turn in enumerate(row):
    # Check only the instructions given by the user
    if turn['role'] != 'user':
      continue
    if len(turn['content']) < MIN_INSTRUCTION_LEN:
      # Skip the whole conversation if any instruction is too short
      print(turn)
      new_tp = []
      break
    # Keep the user turn together with the reply that follows it, if there is one
    new_tp.append(turn)
    if convo_id + 1 < len(row):
      new_tp.append(row[convo_id + 1])

  if new_tp:
    token_len_filtered.append(new_tp)

# token_len_filtered holds the retained conversations; use these records in the next steps
print(token_len_filtered[0], len(token_len_filtered))
  3. Eliminate duplicate instructions with prefixes of 10, and remove near-duplicates with the MinHashLSH technique. Here is sample code for the MinHash deduplication with datatrove:
import os
import gc
import logging
from datatrove.pipeline.dedup import MinhashDedupSignature
from datatrove.pipeline.dedup.minhash import (
    MinhashConfig,
    MinhashDedupBuckets,
    MinhashDedupCluster,
    MinhashDedupFilter,
)
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens import TokensCounter
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.executor import LocalPipelineExecutor

# Configure logging
logging.basicConfig(level=logging.INFO)

# Configuration for Minhash
minhash_config = MinhashConfig(use_64bit_hashes=True)  # better precision -> fewer false positives (collisions)

# Paths for local data
LOCAL_DATA_PATH = "/content/sample_data/input_deduplicate" #put jsonl file here in this folder
LOCAL_MINHASH_BASE_PATH = "/content/sample_data/minhash_deduplicate"
LOCAL_LOGS_FOLDER = "/content/sample_data/log_deduplicate"

# Ensure output directories exist
os.makedirs(LOCAL_MINHASH_BASE_PATH, exist_ok=True)
os.makedirs(LOCAL_LOGS_FOLDER, exist_ok=True)

# Total tasks for local execution
TOTAL_TASKS = 50

# This is the original data that we want to deduplicate
INPUT_READER = JsonlReader(LOCAL_DATA_PATH)

# Stage 1: Compute minhash signatures for each task
stage1 = LocalPipelineExecutor(
    pipeline=[
        INPUT_READER,
        MinhashDedupSignature(output_folder=f"{LOCAL_MINHASH_BASE_PATH}/signatures", config=minhash_config),
    ],
    tasks=TOTAL_TASKS,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/signatures",
)

# Run Stage 1 and collect garbage
try:
    logging.info("Running Stage 1: Minhash Signatures")
    stage1.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 1 failed: {e}")

# Stage 2: Find matches between signatures in each bucket
stage2 = LocalPipelineExecutor(
    pipeline=[
        MinhashDedupBuckets(
            input_folder=f"{LOCAL_MINHASH_BASE_PATH}/signatures",
            output_folder=f"{LOCAL_MINHASH_BASE_PATH}/buckets",
            config=minhash_config,
        ),
    ],
    tasks=minhash_config.num_buckets,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/buckets",
    depends=stage1,
)

# Run Stage 2 and collect garbage
try:
    logging.info("Running Stage 2: Minhash Buckets")
    stage2.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 2 failed: {e}")

# Stage 3: Create clusters of duplicates using the results from all buckets
stage3 = LocalPipelineExecutor(
    pipeline=[
        MinhashDedupCluster(
            input_folder=f"{LOCAL_MINHASH_BASE_PATH}/buckets",
            output_folder=f"{LOCAL_MINHASH_BASE_PATH}/remove_ids",
            config=minhash_config,
        ),
    ],
    tasks=1,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/clusters",
    depends=stage2,
)

# Run Stage 3 and collect garbage
try:
    logging.info("Running Stage 3: Minhash Clusters")
    stage3.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 3 failed: {e}")

# Stage 4: Read the original input data and remove all but 1 sample per duplicate cluster
deduplicated_output_folder = f"{LOCAL_MINHASH_BASE_PATH}/deduplicated_output"
os.makedirs(deduplicated_output_folder, exist_ok=True)

stage4 = LocalPipelineExecutor(
    pipeline=[
        INPUT_READER,
        TokensCounter(),  # See how many tokens we had before and after deduplication
        MinhashDedupFilter(
            input_folder=f"{LOCAL_MINHASH_BASE_PATH}/remove_ids",
            exclusion_writer=JsonlWriter(f"{LOCAL_MINHASH_BASE_PATH}/removed"),
        ),
        JsonlWriter(output_folder=deduplicated_output_folder),
    ],
    tasks=TOTAL_TASKS,
    logging_dir=f"{LOCAL_LOGS_FOLDER}/filter",
    depends=stage3,
)

# Execute the final stage
try:
    logging.info("Running Stage 4: Deduplication and Writing Output")
    stage4.run()
    gc.collect()
except Exception as e:
    logging.error(f"Stage 4 failed: {e}")

# Verify the output
if not os.listdir(deduplicated_output_folder):
    logging.error("Deduplicated output folder is empty.")
else:
    logging.info("Deduplicated output has been successfully written.")
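The datatrove pipeline above handles near-duplicates; the exact-prefix part of this step can be sketched in plain Python. This is a minimal sketch under stated assumptions: the post does not say whether the prefix of 10 is measured in words, tokens, or characters, so here it is assumed to be the first 10 whitespace-separated words of a conversation's first user instruction, lowercased. `dedup_by_prefix` is a hypothetical helper.

```python
def dedup_by_prefix(conversations, prefix_words=10):
    """Keep the first conversation for each distinct instruction prefix.

    Assumption: the prefix is the first `prefix_words` words of the first user turn,
    lowercased and whitespace-normalized.
    """
    seen = set()
    kept = []
    for conv in conversations:
        first_user = next((turn["content"] for turn in conv if turn["role"] == "user"), "")
        key = " ".join(first_user.lower().split()[:prefix_words])
        if key in seen:
            continue  # an earlier conversation already starts with this prefix
        seen.add(key)
        kept.append(conv)
    return kept
```

Exact-prefix matching is cheap and catches templated instructions that differ only at the end, while MinHashLSH catches paraphrased or lightly edited duplicates; running the cheap pass first shrinks the input to the MinHash stages.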
  4. To prevent test-data leakage, use the gte-large embedding model to exclude the top 5 most semantically similar training instructions for the instructions in the following benchmarks:
  • WizardArena
  • Arena-Hard Auto
  • MT-Bench
  • AlpacaEval
  • Open LLM Leaderboard

Here is sample code using the dataformer library mentioned earlier. Before running it, merge the data from all of the benchmarks above into a single evaluation dataset.

import json

import torch
from dataformer.llms import AsyncLLM


def batch_inputs(texts, batch_size):
    """Group texts into embedding requests of at most batch_size inputs each."""
    return [{"input": texts[k:k + batch_size]} for k in range(0, len(texts), batch_size)]


def embed(llm, texts, batch_size=200):
    """Embed texts in batches and return an (N, dim) tensor.

    Assumes responses come back in request order.
    """
    embeddings = []
    for one_res in llm.generate(batch_inputs(texts, batch_size)):
        # one_res[1] is the API response, holding one embedding per input in the batch
        for item in one_res[1]['data']:
            embeddings.append(torch.tensor(item['embedding']))
    return torch.stack(embeddings)


def get_similarity(url, train_data, evaluation_data, model, api_key, max_requests_per_minute,
                   top_k=5, output_path="/content/sample_data/filtered_data_excluded.jsonl"):
    llm = AsyncLLM(base_url=url, model=model, api_key=api_key,
                   max_requests_per_minute=max_requests_per_minute)

    # Embed the training instructions and the merged benchmark instructions
    embeddings_train = embed(llm, [str(item['text']) for item in train_data])
    embeddings_test = embed(llm, [str(item) for item in evaluation_data])
    print("all embeddings done")

    # Cosine similarity between every benchmark and every training instruction:
    # L2-normalize, then a matrix product gives a (num_test, num_train) matrix
    train_norm = torch.nn.functional.normalize(embeddings_train, dim=1)
    test_norm = torch.nn.functional.normalize(embeddings_test, dim=1)
    similarities = test_norm @ train_norm.T

    # For each benchmark instruction, exclude its top_k most similar training instructions
    k = min(top_k, similarities.shape[1])
    excluded_indices = set(similarities.topk(k, dim=1).indices.flatten().tolist())
    print("similarities calculated")

    # Filter out the excluded indices
    filtered_data = [item for idx, item in enumerate(train_data) if idx not in excluded_indices]

    # Save the filtered data as JSONL
    with open(output_path, 'w') as f:
        for item in filtered_data:
            f.write(json.dumps(item) + "\n")
    print(f"Done: data saved in {output_path}")
    return filtered_data, similarities, excluded_indices
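To check the exclusion rule without calling an embedding API, here is a toy, pure-Python version of the same idea: for each benchmark embedding, rank all training embeddings by cosine similarity and mark the top `top_k` as leaked. The 2-D vectors are made up for illustration; real gte-large embeddings are much higher-dimensional, and `leaked_train_indices` is a hypothetical helper.

```python
import math


def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def leaked_train_indices(train_emb, test_emb, top_k=5):
    """Union over benchmark items of the top_k most similar training indices."""
    excluded = set()
    for t in test_emb:
        ranked = sorted(range(len(train_emb)), key=lambda i: cosine(t, train_emb[i]), reverse=True)
        excluded.update(ranked[:top_k])
    return excluded


train = [[1, 0], [0, 1], [0.9, 0.1], [-1, 0]]
print(leaked_train_indices(train, [[1, 0]], top_k=2))  # {0, 2}
```

The number of removed training examples is at most `top_k` times the number of benchmark instructions, since different benchmark items can share neighbors.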
  5. Filter by language if required (e.g., with the langdetect or fasttext-langdetect libraries).

After these steps, the refined 276K dataset D is randomly split into 9 equal parts, D = {D0, D1, D2, ..., DN}, for the subsequent iterative training and updates. The simulated arena battle outcomes are then used to generate training data for WizardLM-β, tailored to different training strategies: supervised fine-tuning (SFT), direct preference optimization (DPO), and proximal policy optimization (PPO).
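A minimal sketch of the random, equal split, assuming the refined dataset fits in memory as a Python list. `split_into_parts` and its `seed` parameter are illustrative names, not part of any library; when the size is not divisible by the number of parts, the first parts get one extra record.

```python
import random


def split_into_parts(dataset, num_parts=9, seed=0):
    """Shuffle the dataset and split it into num_parts near-equal parts D0..DN."""
    items = list(dataset)
    random.Random(seed).shuffle(items)  # fixed seed keeps the split reproducible
    size, remainder = divmod(len(items), num_parts)
    parts, start = [], 0
    for k in range(num_parts):
        end = start + size + (1 if k < remainder else 0)
        parts.append(items[start:end])
        start = end
    return parts
```

Fixing the seed matters here: each later iteration must battle on a part the model has not been trained on yet, so the split has to be identical across runs.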

Iterative Battle and Model Evolving

Arena Learning uses an iterative process for training and improving WizardLM-β:

  1. Train the initial version, WizardLM-β-SFT-I0, on D0
  2. Select the top-ranking SOTA models M from the WizardArena test set
  3. Battle WizardLM-β-SFT-I0 against M on D1
  4. Extract the instances where WizardLM-β's response is inferior
  5. Use the winning model's response as the target output to fine-tune the next model, WizardLM-β-SFT-I1
  6. For DPO: battle WizardLM-β-SFT-I1 against M on D2 and treat the win/loss responses as <choice, reject> pairs
  7. For PPO: battle WizardLM-β-DPO-I1 against M on D3 to obtain <choice, reject> pairs
  8. In the second iteration, I2, select the best model, WizardLM-β-PPO-I1, as the initial competitor
  9. Repeat the process to train the next SFT, DPO, and PPO models
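Steps 4 to 7 turn battle outcomes into training data. A minimal sketch of that conversion, assuming each battle has already been reduced to a hypothetical `Battle` record holding WizardLM-β's response, the strongest opponent response from M, and the judge's verdict; the record layout, field names, and skipping ties are assumptions, not the paper's exact data format.

```python
from dataclasses import dataclass


@dataclass
class Battle:
    instruction: str
    target_response: str    # WizardLM-β's response
    opponent_response: str  # best response among the SOTA models M
    winner: str             # judge verdict: "target", "opponent", or "tie"


def build_sft_data(battles):
    """SFT: keep the instances WizardLM-β lost; the winner's response is the target output."""
    return [{"instruction": b.instruction, "output": b.opponent_response}
            for b in battles if b.winner == "opponent"]


def build_preference_pairs(battles):
    """DPO/PPO: turn decisive battles into <choice, reject> pairs; ties are skipped."""
    pairs = []
    for b in battles:
        if b.winner == "opponent":
            pairs.append({"prompt": b.instruction, "chosen": b.opponent_response,
                          "rejected": b.target_response})
        elif b.winner == "target":
            pairs.append({"prompt": b.instruction, "chosen": b.target_response,
                          "rejected": b.opponent_response})
    return pairs
```

In the iterative scheme, `build_sft_data` would be applied to the battles on D1, and `build_preference_pairs` to those on D2 (DPO) and D3 (PPO), each time with the latest WizardLM-β checkpoint as the target.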

Test Data Generation

The test data comprises two subsets:

Diverse Subset

  • Captures a broad range of topics, styles, and conversational contexts
  • Uses text clustering with ~500 categories (clusters)
  • Employs state-of-the-art embedding models (e.g., gte-large)
  • Selects 2 representative samples from each cluster, resulting in 1000 records
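A toy sketch of the clustering-based selection, assuming the embeddings are already computed. The real pipeline uses gte-large embeddings and ~500 clusters; here a tiny pure-Python k-means with deterministic farthest-point initialization runs on 2-D points in place of whatever clustering library you prefer. Treating the samples closest to each centroid as the "representatives" is an assumption, since the post does not define the criterion.

```python
import math


def assign(points, centers):
    """Index of the nearest center for every point."""
    return [min(range(len(centers)), key=lambda c: math.dist(p, centers[c])) for p in points]


def kmeans(points, k, iters=20):
    """Plain k-means with farthest-point initialization (deterministic)."""
    centers = [list(points[0])]
    while len(centers) < k:
        farthest = max(points, key=lambda p: min(math.dist(p, c) for c in centers))
        centers.append(list(farthest))
    for _ in range(iters):
        labels = assign(points, centers)
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # move each center to the mean of its members
                centers[c] = [sum(coord) / len(members) for coord in zip(*members)]
    return assign(points, centers), centers


def diverse_subset(points, k, per_cluster=2):
    """Indices of the per_cluster points closest to each cluster centroid."""
    labels, centers = kmeans(points, k)
    selected = []
    for c in range(k):
        members = [i for i, lab in enumerate(labels) if lab == c]
        members.sort(key=lambda i: math.dist(points[i], centers[c]))
        selected.extend(members[:per_cluster])
    return sorted(selected)
```

With ~500 clusters and 2 picks per cluster this yields roughly 1000 records; clusters with fewer than 2 members contribute fewer.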

Hard Subset

  • Designed for complex and challenging scenarios
  • Randomly selects 10,000 records from the 500 categories
  • Uses GPT-4-1106-preview to rate difficulty (0-10 scale)
  • Selects top 1000 entries for the hard test set
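A minimal sketch of the difficulty-based selection, assuming a hypothetical `rate_fn` that sends a prompt to GPT-4-1106-preview (or any LLM) and returns its text reply. The prompt wording and the parsing rule (take the first number in the reply and reject anything outside 0 to 10) are assumptions; the random pre-sampling of 10,000 records is omitted.

```python
import re

# Hypothetical rating prompt; not the exact prompt used to build the hard subset.
DIFFICULTY_PROMPT = (
    "Rate the difficulty of the following instruction on a scale from 0 to 10, "
    "where 10 is the most complex and challenging. Reply with the number only.\n\n"
    "Instruction: {instruction}"
)


def parse_difficulty(reply):
    """Return the first number in the reply if it lies in [0, 10], else None."""
    match = re.search(r"\d+(?:\.\d+)?", reply)
    if match is None:
        return None
    value = float(match.group())
    return value if 0 <= value <= 10 else None


def select_hardest(records, rate_fn, top_n=1000):
    """Rate every record and keep the top_n highest-rated ones."""
    scored = []
    for rec in records:
        score = parse_difficulty(rate_fn(DIFFICULTY_PROMPT.format(instruction=rec)))
        if score is not None:  # unparseable ratings are dropped
            scored.append((score, rec))
    scored.sort(key=lambda pair: pair[0], reverse=True)  # stable: ties keep input order
    return [rec for _, rec in scored[:top_n]]
```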

Limitations

  • The judge model may fail to accurately imitate human evaluators
  • Risk of generating unethical or misleading information

Conclusion

Arena Learning offers a cost-effective and reliable alternative to traditional human-based evaluation systems. It progressively enhances and scales the capabilities of large language models, providing an effective way to improve the post-training process while mitigating costs.
