Utilizing hidden states

#53 opened by TatonkaHF

Hello,
I want to use the last-layer hidden state, and for this I run the following code:

from transformers.modeling_outputs import MaskedLMOutput

lm_out: MaskedLMOutput = self.lm(
    input_ids=encoder_input_ids,
    attention_mask=encoder_attention_mask,
    labels=encoder_labels,
    output_hidden_states=True,  # also return the per-layer hidden states
    return_dict=True,
)

hidden_states = lm_out.hidden_states[-1]

Without flash attention, hidden_states has three dimensions with the shape [batch_size, seq_length, hidden_size].
However, when using flash attention (if I understand it correctly), I get a tensor with only two dimensions, and the size of the first dimension is not constant across batches (presumably it is the total number of non-padding tokens in the batch).
Is there a way to obtain the result in the shape [batch_size, seq_length, hidden_size] while using flash attention? Or should I just pad the output back to [batch_size * seq_length] myself?
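For reference, this is the kind of re-padding I have in mind, as a minimal sketch. It assumes the two-dimensional output concatenates the non-padding tokens in batch-then-sequence (row-major) order, which, as far as I understand, is how flash-attn's unpad_input / pad_input helpers lay them out:

import torch

def repad_hidden_states(unpadded: torch.Tensor,
                        attention_mask: torch.Tensor) -> torch.Tensor:
    # unpadded: [total_tokens, hidden_size]
    # attention_mask: [batch_size, seq_length], 1 for real tokens, 0 for padding
    batch_size, seq_length = attention_mask.shape
    hidden_size = unpadded.shape[-1]
    padded = unpadded.new_zeros(batch_size, seq_length, hidden_size)
    # Boolean indexing selects the real-token positions in row-major order,
    # so assigning the unpadded rows restores [batch_size, seq_length, hidden_size]
    # with zeros at the padding positions.
    padded[attention_mask.bool()] = unpadded
    return padded

padded_hidden_states = repad_hidden_states(hidden_states, encoder_attention_mask)

If the model already uses flash-attn's bert_padding helpers internally, pad_input with the stored token indices would do the same thing without the manual scatter.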
