IndicTrans2
Collection
Models(En-Indic, Indic-En, Indic-Indic) in 2 variants (base and dist) and Benchmarks (IN22-Gen and IN22-Conv) released as a part of IndicTrans2.
•
8 items
•
Updated
•
10
This is the model card of IndicTrans2 Indic-En 1.1B variant.
Here are the metrics for the particular checkpoint.
Please refer to Appendix D: Model Card
of the preprint for further details on model training, intended use, data, metrics, limitations and recommendations.
Please refer to the github repository for a detail description on how to use HF compatible IndicTrans2 models for inference.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
# recommended to run this on a gpu with flash_attn installed
# don't set attn_implemetation if you don't have flash_attn
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
src_lang, tgt_lang = "hin_Deva", "eng_Latn"
model_name = "ai4bharat/indictrans2-indic-en-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
trust_remote_code=True,
torch_dtype=torch.float16, # performance might slightly vary for bfloat16
attn_implementation="flash_attention_2"
).to(DEVICE)
ip = IndicProcessor(inference=True)
input_sentences = [
"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
]
batch = ip.preprocess_batch(
input_sentences,
src_lang=src_lang,
tgt_lang=tgt_lang,
)
# Tokenize the sentences and generate input encodings
inputs = tokenizer(
batch,
truncation=True,
padding="longest",
return_tensors="pt",
return_attention_mask=True,
).to(DEVICE)
# Generate translations using the model
with torch.no_grad():
generated_tokens = model.generate(
**inputs,
use_cache=True,
min_length=0,
max_length=256,
num_beams=5,
num_return_sequences=1,
)
# Decode the generated tokens into text
with tokenizer.as_target_tokenizer():
generated_tokens = tokenizer.batch_decode(
generated_tokens.detach().cpu().tolist(),
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
# Postprocess the translations, including entity replacement
translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
for input_sentence, translation in zip(input_sentences, translations):
print(f"{src_lang}: {input_sentence}")
print(f"{tgt_lang}: {translation}")
model_name
parameter. Please read the model card of the RoPE-IT2 models for more information about the generation.flash_attention_2
for efficient generation.If you consider using our work then please cite using:
@article{gala2023indictrans,
title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=vfT4YuzAYA},
note={}
}