UltraGist is a context compression method that can flexibly, effectively, and efficiently handle various context lengths and compression ratios. We apply UltraGist to Llama-2-7b-chat.
Usage
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Usage example: load the UltraGist checkpoint, answer one long-context
# question read from data/nqa.json, then inspect the compressed memory that
# the model kept while generating.
model_id = "namespace-Pt/ultragist-llama2-7b-chat"

# trust_remote_code=True is passed to both loaders, so code shipped with the
# checkpoint repository may be executed.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # load the entire model on the default gpu
    device_map={"": "cuda"},
    # you can manually set the compression ratio, otherwise the model will
    # automatically choose the most suitable compression ratio from
    # [2,4,8,16,32]
    # ultragist_ratio=[8],
).eval()

# Inference only: disable gradient tracking for everything below.
with torch.no_grad():
    # long context
    with open("data/nqa.json", encoding="utf-8") as f:
        example = json.load(f)
    # The example is indexed with the keys 'context', 'input' and 'answers'.
    content = f"Read this article:\n\n{example['context']}\n\nNow, answer the question based on the above context.\nQuestion:\n{example['input']}"
    messages = [{"role": "user", "content": content}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")

    # reset memory before new compression task
    model.memory.reset()

    # directly call generate to progressively compress the context while
    # generating next tokens
    prompt_length = inputs["input_ids"].shape[1]
    # Slice off the first prompt_length positions so only the tokens that
    # follow the prompt are decoded.
    outputs = model.generate(
        **inputs,
        do_sample=False,
        top_p=1,
        temperature=1,
        max_new_tokens=40,
    )[:, prompt_length:]

    print("*"*20)
    print(f"Input size: {inputs['input_ids'].shape[1]}")
    print(f"Question: {example['input']}")
    print(f"Answers: {example['answers']}")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print("*"*20)

    # extract the compressed memory (including the generated tokens)
    compressed_memory = model.memory.get_memory()
    ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
    print(f"UltraGist size: {ultragist_size}")
    print(f"Raw size: {raw_size}")
    print(f"Sink size: {sink_size}")
    # NOTE(review): compressed_memory[0][0] is presumably the first layer's
    # key tensor -- confirm against the model's remote code.
    print(f"Memory: {compressed_memory[0][0].shape}")
    print("*"*20)
- Downloads last month
- 155
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.