UltraGist for Llama-2-7b-chat

[Paper] [Github]

UltraGist is a context compression method that flexibly, effectively, and efficiently handles various context lengths and compression ratios. Here we apply UltraGist to Llama-2-7b-chat.

Usage

import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/ultragist-llama2-7b-chat"

tokenizer = AutoTokenizer.from_pretrained(
  model_id, 
  trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
  model_id, 
  trust_remote_code=True, 
  torch_dtype=torch.bfloat16, 
  attn_implementation="sdpa",
  # load the entire model on the default gpu
  device_map={"": "cuda"}, 
  # you can set the compression ratio manually; otherwise the model automatically chooses the most suitable ratio from [2, 4, 8, 16, 32]
  # ultragist_ratio=[8],
).eval()


with torch.no_grad():
  # long context
  with open("data/nqa.json", encoding="utf-8") as f:
    example = json.load(f)
    content = f"Read this article:\n\n{example['context']}\n\nNow, answer the question based on the above context.\nQuestion:\n{example['input']}"
  messages = [{"role": "user", "content": content}]
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")

  # reset memory before new compression task
  model.memory.reset()

  # directly call generate to progressively compress the context while generating next tokens
  outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=40)[:, inputs["input_ids"].shape[1]:]
  print("*"*20)
  print(f"Input size:       {inputs['input_ids'].shape[1]}")
  print(f"Question:         {example['input']}")
  print(f"Answers:          {example['answers']}")
  print(f"Prediction:       {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
  print("*"*20)

  # extract the compressed memory (including the generated tokens)
  compressed_memory = model.memory.get_memory()
  ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
  print(f"UltraGist size:   {ultragist_size}")
  print(f"Raw size:         {raw_size}")
  print(f"Sink size:        {sink_size}")
  print(f"Memory:           {compressed_memory[0][0].shape}")
  print("*"*20)
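
To make the sizes printed above easier to interpret, here is a minimal sketch of the arithmetic behind a compression ratio. It is not part of the UltraGist code: `estimate_compressed_length` and `effective_ratio` are hypothetical helpers, and they assume that a ratio of r turns roughly every r raw tokens into one compressed token, and that the three values returned by `get_memory_size()` are token counts that add up to the total memory length.

```python
import math

# Candidate ratios listed in the usage snippet above.
CANDIDATE_RATIOS = [2, 4, 8, 16, 32]


def estimate_compressed_length(num_tokens: int, ratio: int) -> int:
    """Rough compressed length, ASSUMING `ratio` raw tokens map to one compressed token."""
    if num_tokens < 0 or ratio < 1:
        raise ValueError("num_tokens must be >= 0 and ratio must be >= 1")
    return math.ceil(num_tokens / ratio)


def effective_ratio(input_size: int, ultragist_size: int, raw_size: int, sink_size: int) -> float:
    """Input length divided by total memory length.

    ASSUMPTION: the three sizes are token counts whose sum is the memory length.
    """
    total = ultragist_size + raw_size + sink_size
    if total <= 0:
        raise ValueError("memory is empty")
    return input_size / total


# Estimated compressed length of a 10,000-token context at each candidate ratio.
for ratio in CANDIDATE_RATIOS:
    print(f"ratio {ratio:>2}: ~{estimate_compressed_length(10_000, ratio)} tokens")
```

After running the usage example, calling `effective_ratio(inputs["input_ids"].shape[1], ultragist_size, raw_size, sink_size)` gives a rough picture of how much the context shrank. Because the memory also holds the generated tokens, treat the result as an approximation.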