Different results with and without flash-attn

Discussion #3, opened by g-h-chen

Hi there,
I trained your model with two settings:
1. The default config.
2. Setting config._attn_implementation = "flash_attention_2" to enable FlashAttention-2.

Under setting 2, training is twice as fast as under setting 1, but the loss rises sharply and is extremely unstable.

What is the correct way to enable flash-attn? Thanks in advance!
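For reference, transformers 4.36 and later (as far as we know) document choosing the attention backend through the attn_implementation argument of from_pretrained, rather than by setting the private config._attn_implementation attribute. The sketch below is a hypothetical helper (build_load_kwargs is our own name, not a library function) that assembles those keyword arguments and rejects float32, since FlashAttention-2 only runs in fp16 or bf16. It uses only the standard library so it can be checked without a GPU; the model id in the comments is an assumption for illustration.

```python
# Hypothetical helper: build keyword arguments for from_pretrained.
# Assumes transformers >= 4.36, where attn_implementation is accepted.

FA2_DTYPES = {"float16", "bfloat16"}  # FlashAttention-2 supports only these


def build_load_kwargs(use_flash_attention: bool, dtype_name: str = "bfloat16") -> dict:
    """Return kwargs for AutoModelForCausalLM.from_pretrained (a sketch).

    dtype_name is a torch dtype attribute name such as "bfloat16"; convert it
    with getattr(torch, name) before passing it as torch_dtype.
    """
    if use_flash_attention:
        if dtype_name not in FA2_DTYPES:
            raise ValueError(
                f"FlashAttention-2 needs float16 or bfloat16, got {dtype_name!r}"
            )
        attn = "flash_attention_2"
    else:
        # "eager" selects the plain PyTorch attention implementation.
        attn = "eager"
    return {"attn_implementation": attn, "torch_dtype": dtype_name}


# Assumed usage with torch and transformers installed (not executed here):
#   kwargs = build_load_kwargs(True)
#   kwargs["torch_dtype"] = getattr(torch, kwargs["torch_dtype"])
#   model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", **kwargs)

if __name__ == "__main__":
    print(build_load_kwargs(True))
    print(build_load_kwargs(False, "float32"))
```

Passing the choice at load time keeps the model's attention classes consistent with the config from the start, which is one reason to prefer it over mutating the config attribute afterwards.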

Hi @g-h-chen, thanks for raising the issue!

Could you share the training script, or a minimal example script that reproduces the unstable training?

@g-h-chen I have the same problem (the loss rises) when I set config._attn_implementation = "flash_attention_2". Is there any way to fix this?

Hi @zhumj34, could you please share the training script?

I will take a look, but it is much easier when the training script is available :)

Hi @g-h-chen @zhumj34, could you please update to the main branch of transformers and retrain the model?
A recent commit to the Phi model in the library fixed an issue where fp16 logits became NaNs.

Please install transformers from main before re-running the training script:

git clone https://github.com/huggingface/transformers.git
cd transformers
pip install .

Please let me know if this update fixes the issue that you are facing.

Hi, just a heads-up: because of recent changes to the main branch of transformers, the model in this repo can't be loaded properly due to a weights mismatch. I will most likely fix this today.

Hi @susnato, sorry for the late reply. My code is based on LLaVA: I simply replaced LLaMA with phi-2 and loaded llava-phi with flash_attention_2 (similar to this code: https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/phi#combining-phi-and-flash-attention-2). The training script is identical to the llava-v1.5 one.

OK. About two weeks ago, I fine-tuned llava-phi with the llava-v1.5 fine-tuning script. At inference time, I loaded the model in fp16 and the output text was '!!!!! ... !!!!!', which was caused by NaNs. At the time, I thought I had made a mistake somewhere.

I'll try this update to verify whether it solves the NaN problem. Sorry for the late reply.
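Why NaNs tend to show up as '!!!!!': if every logit in a row is NaN, every comparison against it is False, so a greedy left-to-right argmax never moves past index 0. In GPT-2-style vocabularies, which we believe phi-2's tokenizer is based on, id 0 is '!' (worth confirming against your own tokenizer). The pure-Python sketch below illustrates that comparison behaviour; it is not torch.argmax itself, and the three-entry vocabulary is hypothetical.

```python
def greedy_argmax(logits):
    """Index of the largest logit, scanning left to right (ties keep the first)."""
    best = 0
    for i in range(1, len(logits)):
        # Any comparison involving NaN is False, so a NaN never replaces `best`.
        if logits[i] > logits[best]:
            best = i
    return best


# Hypothetical vocabulary fragment: id 0 -> "!" as in GPT-2-style tokenizers.
vocab = {0: "!", 1: "Hello", 2: " world"}

nan = float("nan")
healthy_steps = [[0.1, 2.5, 0.3], [0.2, 0.1, 3.0]]
broken_steps = [[nan, nan, nan]] * 5  # NaN logits at every decoding step

print("".join(vocab[greedy_argmax(row)] for row in healthy_steps))  # Hello world
print("".join(vocab[greedy_argmax(row)] for row in broken_steps))   # !!!!!
```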

Note that I pretrained and fine-tuned llava-phi in bf16. Is there anything wrong with this setup? I'll retrain the model with this version of transformers.
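On the bf16 question, one plausible mechanism (not a diagnosis of this particular model): bf16 keeps float32's 8-bit exponent, so its range reaches about 3.4e38, while fp16 tops out at 65504. Values that were fine during bf16 training can therefore overflow to inf when the same model runs in fp16, and a later inf - inf (for example the max-subtraction in softmax) produces NaN. The stdlib-only sketch below emulates both formats with struct; the bf16 conversion simply truncates the float32 bits, whereas hardware normally rounds to nearest even, so treat it as an approximation.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE half precision; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the fp16 range; mimic the
        # infinity a GPU kernel would produce instead.
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Approximate bfloat16 by keeping the top 16 bits of the float32 pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


activation = 70000.0  # larger than the fp16 maximum of 65504
print(to_fp16(activation))  # inf
print(to_bf16(activation))  # 69632.0 (coarser, but finite)

# Once a value is inf, shifting by the row maximum (as softmax does) gives NaN.
row = [to_fp16(activation), 1.0]
shifted = [v - max(row) for v in row]
print(shifted)  # [nan, -inf]
```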

After this update, loading the model in fp16 no longer produces '!!!!! ... !!!!!'. I think the NaN problem has been solved. However, I can't load the model properly because of the weights mismatch.

Hi @zhumj34, I apologise for the long delay.
I have updated the checkpoint, and it should work now. Please install the latest transformers version (4.38.0.dev0) by running:

pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers

After updating the library, you should be able to properly load the weights.
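Since the updated checkpoint expects 4.38.0.dev0 or newer, checking the installed version first can save a confusing weights-mismatch error. This is a hypothetical stdlib-only helper (release_tuple and transformers_is_new_enough are our own names); it compares only the numeric release part, so 4.38.0.dev0 counts as 4.38.0. That is a simplification: packaging.version would order dev releases properly.

```python
import re
from importlib import metadata

REQUIRED = (4, 38, 0)  # from the 4.38.0.dev0 requirement in this thread


def release_tuple(version: str) -> tuple:
    """Extract the leading numeric release, e.g. '4.38.0.dev0' -> (4, 38, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) for p in match.group(1).split("."))
    # Pad to three components so (4, 38) compares like (4, 38, 0).
    return parts + (0,) * (3 - len(parts))


def transformers_is_new_enough(required=REQUIRED) -> bool:
    """True if an installed transformers release is at least `required`."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return False
    return release_tuple(installed) >= required


if __name__ == "__main__":
    print(transformers_is_new_enough())
```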

Please let me know if you are facing any issues with it.

Hi @g-h-chen, regarding the problem you are facing:

Under setting 2, training is twice as fast as under setting 1, but the loss rises sharply and is extremely unstable.

There is an ongoing issue in transformers where other people report the same behaviour. I believe Gugarosa is working on a fix.
In the meantime, please fine-tune without FA2 (now that I have updated the checkpoint, the weights load properly). Sorry for the inconvenience.
