Utilizing hidden states
#53 by TatonkaHF
Hello,
I want to use the hidden states from the last layer, and for this I run the following code:
lm_out: MaskedLMOutput = self.lm(
    encoder_input_ids,
    encoder_attention_mask,
    labels=encoder_labels,
    output_hidden_states=True,
    return_dict=True,
)
hidden_states = lm_out.hidden_states[-1]
Without flash attention, hidden_states has three dimensions with the shape [batch_size, seq_length, hidden_size].
However, when using flash attention (if I understand it correctly), the hidden states come back with only two dimensions, and the size of the first dimension is not constant across batches.
Is there a way to obtain the results in the shape [batch_size, seq_length, hidden_size] while using flash attention? Or should I just pad the flattened output back to batch_size * seq_length tokens and reshape it myself?
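For reference, this is roughly what I had in mind for the manual padding route. It's only a sketch, assuming the flattened hidden states keep the same row-major token order as the attention mask (i.e. non-pad tokens appear batch by batch, left to right); pad_hidden_states is just a name I made up here:

import torch

def pad_hidden_states(flat_hidden: torch.Tensor,
                      attention_mask: torch.Tensor) -> torch.Tensor:
    """Scatter a [total_tokens, hidden_size] tensor back into
    [batch_size, seq_length, hidden_size], with zeros at padded positions."""
    batch_size, seq_length = attention_mask.shape
    hidden_size = flat_hidden.shape[-1]
    padded = flat_hidden.new_zeros(batch_size, seq_length, hidden_size)
    # place each unpadded token back at its original (batch, position) slot,
    # assuming flat_hidden follows the row-major order of the mask
    padded[attention_mask.bool()] = flat_hidden
    return padded

Would something like this be the intended approach, or is there a built-in way to get the padded shape back?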