Fine-tuning the Gemma-2b-it model for a precision transliteration task
I am facing issues fine-tuning the gemma-2b-it model for a transliteration task.
I want to create a LoRA for transliteration from Roman to Devanagari and vice versa. I have tried multiple combinations of
- LoRA rank
- LoRA decay
- Target modules
- LoRA dropout rate

etc. However, when I run inference with the resulting LoRA in oobabooga, it does not generate the desired result; it just repeats the user content. When I used similar parameters to train a LoRA for the same task on the gemma-9b model, it worked fine.
Can someone help here or share some thoughts?
Hi @grishi911991 ,
The issue above might have the following causes:
The Gemma-9B model has a greater capacity to understand and generate complex language patterns, potentially leading to more accurate and contextually appropriate outputs.
Example:
Input prompt: "Translate the following English sentence to French: 'The quick brown fox jumps over the lazy dog.'"
Gemma-2B-IT output: "Le rapide renard brun saute par-dessus le chien paresseux."
Gemma-9B output: "Le rapide renard brun bondit par-dessus le chien paresseux."
In this example, both models provide correct translations. However, the Gemma-9B model uses the verb "bondit" (leaps) instead of "saute" (jumps), which may be the more appropriate choice in certain contexts.
The **Gemma-9B model** has more parameters than the Gemma-2B-IT model, which allows for more nuanced language understanding and generation, potentially leading to more refined outputs.
- Ensure that the dataset contains a sufficient number of training examples for a 2B parameter model. A larger dataset helps the model learn the intricate patterns within the data, leading to a deeper understanding and more accurate results.
Thank you.
@GopiUppari thanks for responding, but I am working on transliteration rather than translation. Even so, I would expect only some dip in quality, but in my case the model is not performing as expected at all.
I used LoRA rank 4 with alpha 8, and in another run rank 2 with alpha 4.
LoRA decay: 0.02, 0.05, 0.1
LoRA dropout: 0.1 and 0.2
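For readers who want to reproduce the sweep, the combinations listed above can be enumerated as a small grid. This is only a sketch: it assumes "LoRA decay" refers to the `weight_decay` training argument (as in the script in this thread) and that every rank/alpha pair was tried with every decay and dropout value, which the post does not state explicitly. The function name `build_grid` is made up for illustration.

```python
from itertools import product

# Values taken from the post; the full cross-product is an assumption.
rank_alpha_pairs = [(4, 8), (2, 4)]
weight_decays = [0.02, 0.05, 0.1]
dropouts = [0.1, 0.2]


def build_grid(rank_alpha_pairs, weight_decays, dropouts):
    """Return one dict of hyperparameters per combination."""
    grid = []
    for (rank, alpha), decay, dropout in product(
        rank_alpha_pairs, weight_decays, dropouts
    ):
        grid.append(
            {
                "r": rank,
                "lora_alpha": alpha,
                "weight_decay": decay,
                "lora_dropout": dropout,
            }
        )
    return grid


grid = build_grid(rank_alpha_pairs, weight_decays, dropouts)
for cfg in grid:
    print(cfg)
```

Note that alpha is twice the rank in every pair, so the LoRA scaling factor alpha / r stays at 2 across all of these runs.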
This is the script I used:
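import os
import wandb
from datasets import load_dataset
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

#os.environ["CUDA_VISIBLE_DEVICES"] = "1"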
hf_token = "xxx"
login(token = hf_token)
wandb.login(key="xxxx")
run = wandb.init(
project='Fine-tune gemma2-2b on PF',
job_type="training",
anonymous="allow"
)
model_id = "google/gemma-2-2b-it"
#quantization_config_loading = GPTQConfig(bits=8, disable_exllama=True)
dataset = load_dataset("yyyy", data_files={'train': "yyyy", 'validation': "yyy"})
max_seq_length = 2048
#model = AutoModelForCausalLM.from_pretrained(model_id,quantization_config=quantization_config_loading)
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()
#model = prepare_model_for_kbit_training(model)
from peft import LoraConfig, get_peft_model
config = LoraConfig(
r=8,
lora_alpha=16,
target_modules= ["q_proj", "k_proj", "v_proj", "o_proj"],
#layers_to_transform = list(range(12, 26)),
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
import transformers
args=transformers.TrainingArguments(
per_device_train_batch_size = 10,
per_device_eval_batch_size = 5,
gradient_accumulation_steps = 1,
warmup_steps = 100,
num_train_epochs=2,
eval_strategy="steps",
eval_steps=500,
save_steps=500,
learning_rate=2e-4,
fp16=True, #use mixed precision training
logging_steps=10,
lr_scheduler_type = "cosine",
weight_decay = 0.02,
output_dir="gemma2_2b_training_hn",
report_to="wandb",
optim="adamw_hf"
)
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
peft_config=config,
dataset_text_field="text",
tokenizer=tokenizer,
packing=False,
max_seq_length=max_seq_length)
trainer.train()
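The script reads training examples from a `text` field, but the thread does not show what that field contains. As a hedged sketch, one way to build it for an instruction-tuned Gemma model is to wrap each pair in Gemma's turn markers (`<start_of_turn>` and `<end_of_turn>`). The helper name `to_gemma_chat_text`, the instruction wording, and the sample pair are hypothetical and not taken from the original dataset.

```python
def to_gemma_chat_text(
    source,
    target,
    instruction="Transliterate the following text from Roman script to Devanagari:",
):
    """Format one transliteration pair as a single user/model exchange.

    The <bos> token is left out on purpose, on the assumption that the
    tokenizer adds it during training.
    """
    return (
        "<start_of_turn>user\n"
        f"{instruction}\n{source}<end_of_turn>\n"
        "<start_of_turn>model\n"
        f"{target}<end_of_turn>\n"
    )


# Hypothetical sample pair, for illustration only.
example = to_gemma_chat_text("namaste", "नमस्ते")
print(example)
```

Whichever format the training data uses, the inference frontend needs to build its prompts the same way, so it may be worth comparing the prompt oobabooga actually sends with one training example.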
Hi @grishi911991 ,
Could you please try the following settings:
- Increase rank (r=16) and adjust alpha if needed.
- Try lower learning rates (learning_rate = 1e-4 or 5e-5).
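- Increase gradient_accumulation_steps and lower per_device_train_batch_size to match, which mitigates GPU memory issues while keeping the effective batch size unchanged.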
- Consider more epochs (e.g., 3–5) for better convergence and check if validation metrics improve.
Thank you.
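A note on the gradient accumulation suggestion: raising `gradient_accumulation_steps` alone does not reduce memory use. The saving comes from lowering `per_device_train_batch_size` and using accumulation to keep the number of examples per optimizer step the same. A minimal sketch of that arithmetic, assuming a single GPU (the helper `effective_batch_size` is made up for illustration):

```python
def effective_batch_size(
    per_device_batch_size, gradient_accumulation_steps, num_devices=1
):
    """Number of examples contributing to each optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


# Settings from the script in this thread: batch size 10, no accumulation.
original = effective_batch_size(10, 1)

# Hypothetical lower-memory alternative: half the batch, two accumulation steps.
lower_memory = effective_batch_size(5, 2)

print(original, lower_memory)
```

Both settings give the optimizer the same number of examples per update, but the second keeps only half as many sequences in memory at once.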