xwwu committed on
Commit
26c2f02
·
verified ·
1 Parent(s): 5e54ccb

Upload folder using huggingface_hub

Browse files
configuration_hformer.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
class HformerConfig(PretrainedConfig):
    """Configuration object for the H-former model.

    Each named constructor argument is stored unchanged on the instance
    under the same attribute name. Any additional keyword arguments are
    handed to ``PretrainedConfig.__init__``.

    Args:
        num_query_token: Stored as ``num_query_token``. Defaults to 32.
        visual_hidden_size: Stored as ``visual_hidden_size``. Defaults to
            4096.
        llm_hidden_size: Stored as ``llm_hidden_size``. Defaults to 768.
        cross_attention_freq: Stored as ``cross_attention_freq``. Defaults
            to 2.
        bert: Stored as ``bert``. Defaults to ``"bert-base-uncased"``.
        bias: Stored as ``bias``. Defaults to ``True``.
        qformer_pth: Stored as ``qformer_pth``. Defaults to ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of each field is decided by the model code
    that reads this config, not by this class -- confirm the semantics
    there before relying on them.
    """

    model_type = "hformer"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        num_query_token=32,
        visual_hidden_size=4096,
        llm_hidden_size=768,
        cross_attention_freq=2,
        bert="bert-base-uncased",
        bias=True,
        qformer_pth=None,
        **kwargs,
    ):
        # Record the model-specific fields first; the base-class
        # initialiser is invoked last with whatever kwargs remain.
        self.num_query_token = num_query_token
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.bias = bias
        self.bert = bert
        self.cross_attention_freq = cross_attention_freq
        self.qformer_pth = qformer_pth
        super().__init__(**kwargs)
configuration_projector.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
class ProjectorConfig(PretrainedConfig):
    """Configuration object for the projector model.

    Each named constructor argument is stored unchanged on the instance
    under the same attribute name. Any additional keyword arguments are
    handed to ``PretrainedConfig.__init__``.

    Args:
        visual_hidden_size: Stored as ``visual_hidden_size``. Defaults to
            4096.
        llm_hidden_size: Stored as ``llm_hidden_size``. Defaults to 4096.
        depth: Stored as ``depth``. Defaults to 2.
        hidden_act: Stored as ``hidden_act``. Defaults to ``"gelu"``.
        bias: Stored as ``bias``. Defaults to ``True``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): how these fields are consumed is defined by the
    projector model code, which is not part of this class -- confirm the
    semantics there.
    """

    model_type = "projector"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        visual_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act="gelu",
        bias=True,
        **kwargs,
    ):
        # Record the model-specific fields first; the base-class
        # initialiser is invoked last with whatever kwargs remain.
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias
        super().__init__(**kwargs)
llm/added_tokens.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "<|endoftext|>": 64001,
3
+ "<|startoftext|>": 64000
4
+ }
llm/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/export/share/models/Yi-6B-Chat",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 4096,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 11008,
14
+ "max_position_embeddings": 4096,
15
+ "model_type": "llama",
16
+ "num_attention_heads": 32,
17
+ "num_hidden_layers": 32,
18
+ "num_key_value_heads": 4,
19
+ "pretraining_tp": 1,
20
+ "rms_norm_eps": 1e-05,
21
+ "rope_scaling": null,
22
+ "rope_theta": 5000000.0,
23
+ "tie_word_embeddings": false,
24
+ "torch_dtype": "float16",
25
+ "transformers_version": "4.37.0",
26
+ "use_cache": true,
27
+ "vocab_size": 64000
28
+ }
llm/generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 6,
3
+ "do_sample": true,
4
+ "eos_token_id": 7,
5
+ "max_length": 4096,
6
+ "pad_token_id": 0,
7
+ "temperature": 0.6,
8
+ "top_p": 0.8,
9
+ "transformers_version": "4.37.0"
10
+ }
llm/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5495e9968c760074c2d0cdd45991fdabe3c23f6574a693755292559783e5b31f
3
+ size 9943068568
llm/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ede952f47385073681c0fcca29d371a413848440bfdf15c09889eede9ccdaf30
3
+ size 2179035976
llm/model.safetensors.index.json ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 12122071040
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00001-of-00002.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00001-of-00002.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
179
+ "model.layers.26.input_layernorm.weight": "model-00001-of-00002.safetensors",
180
+ "model.layers.26.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
181
+ "model.layers.26.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
182
+ "model.layers.26.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
183
+ "model.layers.26.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
184
+ "model.layers.26.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
185
+ "model.layers.26.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
186
+ "model.layers.26.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
187
+ "model.layers.26.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
188
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
189
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
190
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
191
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
192
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
193
+ "model.layers.27.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
194
+ "model.layers.27.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
195
+ "model.layers.27.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
196
+ "model.layers.27.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
197
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
198
+ "model.layers.28.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.28.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
200
+ "model.layers.28.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
201
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.28.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
203
+ "model.layers.28.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
204
+ "model.layers.28.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
205
+ "model.layers.28.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
206
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
207
+ "model.layers.29.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
208
+ "model.layers.29.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
209
+ "model.layers.29.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
210
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
211
+ "model.layers.29.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
212
+ "model.layers.29.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
213
+ "model.layers.29.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
214
+ "model.layers.29.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
215
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
216
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
217
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
218
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
219
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
220
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
221
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
222
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
224
+ "model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
225
+ "model.layers.30.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
226
+ "model.layers.30.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
227
+ "model.layers.30.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
228
+ "model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
229
+ "model.layers.30.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
230
+ "model.layers.30.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
231
+ "model.layers.30.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
232
+ "model.layers.30.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
233
+ "model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
234
+ "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
235
+ "model.layers.31.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
236
+ "model.layers.31.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
237
+ "model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
238
+ "model.layers.31.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
239
+ "model.layers.31.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
240
+ "model.layers.31.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
241
+ "model.layers.31.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
242
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
243
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
244
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
246
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
247
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
248
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
249
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
250
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
251
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
252
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
253
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
254
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
255
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
256
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
257
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
258
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
260
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
262
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
264
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
265
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
266
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
268
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
269
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
270
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
271
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
272
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
273
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
274
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
275
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
276
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
277
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
278
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
279
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
280
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
281
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
282
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
283
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
284
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
285
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
286
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
287
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
288
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
289
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
290
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
291
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
292
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
293
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
294
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
295
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
296
+ "model.norm.weight": "model-00002-of-00002.safetensors"
297
+ }
298
+ }
llm/special_tokens_map.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|im_sep|>"
6
+ ],
7
+ "bos_token": {
8
+ "content": "<|startoftext|>",
9
+ "lstrip": false,
10
+ "normalized": true,
11
+ "rstrip": false,
12
+ "single_word": false
13
+ },
14
+ "eos_token": {
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "pad_token": {
22
+ "content": "<unk>",
23
+ "lstrip": false,
24
+ "normalized": true,
25
+ "rstrip": false,
26
+ "single_word": false
27
+ },
28
+ "unk_token": {
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ }
35
+ }
llm/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
llm/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:386c49cf943d71aa110361135338c50e38beeff0a66593480421f37b319e1a39
3
+ size 1033105
llm/tokenizer_config.json ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "6": {
30
+ "content": "<|im_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "7": {
38
+ "content": "<|im_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "8": {
46
+ "content": "<|im_sep|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "64000": {
54
+ "content": "<|startoftext|>",
55
+ "lstrip": false,
56
+ "normalized": true,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "64001": {
62
+ "content": "<|endoftext|>",
63
+ "lstrip": false,
64
+ "normalized": true,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ }
69
+ },
70
+ "additional_special_tokens": [
71
+ "<|im_start|>",
72
+ "<|im_end|>",
73
+ "<|im_sep|>"
74
+ ],
75
+ "bos_token": "<|startoftext|>",
76
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
77
+ "clean_up_tokenization_spaces": false,
78
+ "encode_special_tokens": true,
79
+ "eos_token": "<|endoftext|>",
80
+ "legacy": true,
81
+ "model_max_length": 4096,
82
+ "pad_token": "<unk>",
83
+ "padding_side": "right",
84
+ "sp_model_kwargs": {},
85
+ "spaces_between_special_tokens": false,
86
+ "tokenizer_class": "LlamaTokenizer",
87
+ "unk_token": "<unk>",
88
+ "use_default_system_prompt": true
89
+ }
modeling_hformer.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.manual_seed(1024)
3
+
4
+ import torch.nn as nn
5
+ from transformers import PreTrainedModel
6
+
7
+ from .configuration_hformer import HformerConfig
8
+ from .qformer_src import BertConfig, BertLMHeadModel
9
+
10
+ from transformers import BertTokenizerFast as BertTokenizer
11
+
12
+ from .configuration_projector import ProjectorConfig
13
+ from .modeling_projector import ProjectorModel
14
+ import torch.nn.functional as F
15
+ from transformers.activations import ACT2FN
16
+
17
+
18
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` subclass whose forward pass defers entirely to the parent."""

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` using the inherited ``nn.LayerNorm.forward``."""
        return super().forward(x)
22
+
23
class HformerModel(PreTrainedModel):
    """Connector that fuses a Q-Former branch with an MLP projector branch.

    Both branches read the same visual features. The Q-Former (a BERT with
    cross-attention, driven by learned query tokens) is projected with
    ``llm_proj`` and averaged over dim 1; that average is added to the output
    of the MLP ``connector``. A residual bottleneck ``ffn`` refines the sum.
    """

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build the Q-Former, query tokens, projections, connector and FFN.

        Args:
            config: configuration object; this constructor reads
                ``visual_hidden_size``, ``num_query_token``, ``bert``,
                ``llm_hidden_size``, ``cross_attention_freq``, ``qformer_pth``
                and ``bias`` from it.
        """
        super().__init__(config)
        self.gradient_checkpointing = False
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Start from the BERT config named by ``config.bert`` and attach the
        # extra cross-attention settings before instantiating the Q-Former.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )
        # Hard-wired off: the text-side modules of the Q-Former are kept.
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learned queries, created as zeros and then filled from a normal
        # distribution with the encoder's initializer_range as std.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # Not referenced by forward() in this class; it still contributes
        # parameters to the module (presumably expected by saved checkpoints
        # -- TODO confirm before removing).
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # The tokenizer is only used to size the embedding table after the
        # extra "[DEC]" bos token is registered.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # NOTE(review): torch.load unpickles the file, so ``qformer_pth``
            # should only ever point at a trusted checkpoint.
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            # Runs before ``connector`` and ``ffn`` exist, so only the modules
            # created above can receive weights; strict=False tolerates
            # missing and unexpected keys.
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # ProjectorConfig names this parameter ``depth``. The previous keyword
        # ``projector_depth`` was swallowed by **kwargs and the depth only
        # matched because the default is also 2.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            depth=2)
        self.connector = ProjectorModel(projector_config)

        # Bias-free bottleneck MLP: llm_hidden_size -> llm_hidden_size // 4
        # -> llm_hidden_size, applied as a residual in forward().
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size//4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size//4, config.llm_hidden_size, bias=False)
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register forward hooks that mark submodule outputs as requiring grad.

        Hooks are attached to ``Qformer``, ``llm_proj``, ``ln_vision``,
        ``connector`` and ``ffn``.
        """
        def make_inputs_require_grad(module, input, output):
            # Tuple outputs: flag the first two entries; otherwise flag the
            # single output object.
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        # NOTE(review): forward() calls ``self.Qformer.bert`` rather than
        # ``self.Qformer`` itself, so this first hook may never fire -- confirm.
        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Intentional no-op: this model does not implement checkpointing."""
        pass

    def forward(self, x_):
        """Fuse Q-Former and MLP views of the visual features.

        Args:
            x_: visual feature tensor; dim 0 is used as the batch size
                (assumed to be (batch, tokens, visual_hidden_size) -- TODO
                confirm against the caller).

        Returns:
            ``int_feat + ffn(int_feat)``, where ``int_feat`` is the connector
            output plus the dim-1 mean of the projected Q-Former output.
        """
        if self.gradient_checkpointing and self.training:
            print('Not support gradient checkpointing')
        # Only the Q-Former branch sees layer-normed features; the connector
        # below receives the raw input.
        x = self.ln_vision(x_)
        # Share the single learned query set across the batch.
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )

        q_feat = self.llm_proj(query_output.last_hidden_state)

        mlp_outputs = self.connector(x_)
        mlp_feat = mlp_outputs

        # Average over dim 1 and re-insert a singleton axis so the summary
        # broadcasts against every position of the connector output.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:,None]
        out = int_feat + self.ffn(int_feat)

        return out
128
+
modeling_projector.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers.activations import ACT2FN
5
+
6
+ from .configuration_projector import ProjectorConfig
7
+
8
+
9
class ProjectorModel(PreTrainedModel):
    """MLP projector: one input Linear followed by ``depth - 1`` (activation, Linear) pairs."""

    _auto_class = 'AutoModel'
    config_class = ProjectorConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: ProjectorConfig) -> None:
        """Assemble the ``nn.Sequential`` stack described by ``config``.

        Reads ``visual_hidden_size``, ``llm_hidden_size``, ``bias``, ``depth``
        and ``hidden_act`` from ``config``.
        """
        super().__init__(config)
        self.gradient_checkpointing = False

        # First layer maps the visual width to the LLM width.
        layers = [
            nn.Linear(
                config.visual_hidden_size,
                config.llm_hidden_size,
                bias=config.bias)
        ]
        # Each further level adds an activation and a width-preserving Linear.
        for _ in range(config.depth - 1):
            layers.extend((
                ACT2FN[config.hidden_act],
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias),
            ))
        self.model = nn.Sequential(*layers)

    def enable_input_require_grads(self):
        """Register a forward hook on ``self.model`` that marks its output as requiring grad."""

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` when it is a ``ProjectorModel``."""
        if not isinstance(module, ProjectorModel):
            return
        module.gradient_checkpointing = value

    def forward(self, x):
        """Run ``x`` through the MLP, via activation checkpointing when enabled and training."""
        if not (self.gradient_checkpointing and self.training):
            return self.model(x)
        return torch.utils.checkpoint.checkpoint(self.model, x)
projector/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "HformerModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_hformer.HformerConfig",
8
+ "AutoModel": "modeling_hformer.HformerModel"
9
+ },
10
+ "bert": "bert-base-uncased",
11
+ "bias": true,
12
+ "cross_attention_freq": 2,
13
+ "llm_hidden_size": 4096,
14
+ "model_type": "hformer",
15
+ "num_query_token": 32,
16
+ "qformer_pth": null,
17
+ "torch_dtype": "float16",
18
+ "transformers_version": "4.37.0",
19
+ "visual_hidden_size": 1024
20
+ }
projector/configuration_hformer.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class HformerConfig(PretrainedConfig):
    """Configuration for the ``hformer`` connector model.

    Stores every constructor argument as an attribute of the same name and
    forwards any remaining keyword arguments to ``PretrainedConfig``.
    """

    model_type = 'hformer'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        num_query_token=32,
        visual_hidden_size=4096,
        llm_hidden_size=768,
        cross_attention_freq=2,
        bert="bert-base-uncased",
        bias=True,
        qformer_pth=None,
        **kwargs,
    ):
        # Feature widths on the visual side and the LLM side.
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size

        # Q-Former related settings.
        self.bert = bert
        self.num_query_token = num_query_token
        self.cross_attention_freq = cross_attention_freq
        self.qformer_pth = qformer_pth

        # Whether Linear layers built from this config carry a bias term.
        self.bias = bias

        # Base-class initialisation runs last, after the fields above are set.
        super().__init__(**kwargs)
projector/configuration_projector.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class ProjectorConfig(PretrainedConfig):
    """Configuration for the ``projector`` MLP.

    Stores every constructor argument as an attribute of the same name and
    forwards any remaining keyword arguments to ``PretrainedConfig``.
    """

    model_type = 'projector'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        visual_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act='gelu',
        bias=True,
        **kwargs,
    ):
        # Shape of the MLP: how many levels and which activation between them.
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias

        # Input and output feature widths.
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size

        # Base-class initialisation runs last, after the fields above are set.
        super().__init__(**kwargs)
projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6fc4a5475e1ee9bc0b4d4d077c509b511aa3d4829f8e539e0ad0b713f465e94
3
+ size 430629054
projector/modeling_hformer.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ torch.manual_seed(1024)
3
+
4
+ import torch.nn as nn
5
+ from transformers import PreTrainedModel
6
+
7
+ from .configuration_hformer import HformerConfig
8
+ from .qformer_src import BertConfig, BertLMHeadModel
9
+
10
+ from transformers import BertTokenizerFast as BertTokenizer
11
+
12
+ from .configuration_projector import ProjectorConfig
13
+ from .modeling_projector import ProjectorModel
14
+ import torch.nn.functional as F
15
+ from transformers.activations import ACT2FN
16
+
17
+
18
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` subclass whose forward pass defers entirely to the parent."""

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` using the inherited ``nn.LayerNorm.forward``."""
        return super().forward(x)
22
+
23
class HformerModel(PreTrainedModel):
    """Connector that fuses a Q-Former branch with an MLP projector branch.

    Both branches read the same visual features. The Q-Former (a BERT with
    cross-attention, driven by learned query tokens) is projected with
    ``llm_proj`` and averaged over dim 1; that average is added to the output
    of the MLP ``connector``. A residual bottleneck ``ffn`` refines the sum.
    """

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build the Q-Former, query tokens, projections, connector and FFN.

        Args:
            config: configuration object; this constructor reads
                ``visual_hidden_size``, ``num_query_token``, ``bert``,
                ``llm_hidden_size``, ``cross_attention_freq``, ``qformer_pth``
                and ``bias`` from it.
        """
        super().__init__(config)
        self.gradient_checkpointing = False
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Start from the BERT config named by ``config.bert`` and attach the
        # extra cross-attention settings before instantiating the Q-Former.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )
        # Hard-wired off: the text-side modules of the Q-Former are kept.
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learned queries, created as zeros and then filled from a normal
        # distribution with the encoder's initializer_range as std.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # Not referenced by forward() in this class; it still contributes
        # parameters to the module (presumably expected by saved checkpoints
        # -- TODO confirm before removing).
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # The tokenizer is only used to size the embedding table after the
        # extra "[DEC]" bos token is registered.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # NOTE(review): torch.load unpickles the file, so ``qformer_pth``
            # should only ever point at a trusted checkpoint.
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            # Runs before ``connector`` and ``ffn`` exist, so only the modules
            # created above can receive weights; strict=False tolerates
            # missing and unexpected keys.
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # ProjectorConfig names this parameter ``depth``. The previous keyword
        # ``projector_depth`` was swallowed by **kwargs and the depth only
        # matched because the default is also 2.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            depth=2)
        self.connector = ProjectorModel(projector_config)

        # Bias-free bottleneck MLP: llm_hidden_size -> llm_hidden_size // 4
        # -> llm_hidden_size, applied as a residual in forward().
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size//4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size//4, config.llm_hidden_size, bias=False)
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register forward hooks that mark submodule outputs as requiring grad.

        Hooks are attached to ``Qformer``, ``llm_proj``, ``ln_vision``,
        ``connector`` and ``ffn``.
        """
        def make_inputs_require_grad(module, input, output):
            # Tuple outputs: flag the first two entries; otherwise flag the
            # single output object.
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        # NOTE(review): forward() calls ``self.Qformer.bert`` rather than
        # ``self.Qformer`` itself, so this first hook may never fire -- confirm.
        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Intentional no-op: this model does not implement checkpointing."""
        pass

    def forward(self, x_):
        """Fuse Q-Former and MLP views of the visual features.

        Args:
            x_: visual feature tensor; dim 0 is used as the batch size
                (assumed to be (batch, tokens, visual_hidden_size) -- TODO
                confirm against the caller).

        Returns:
            ``int_feat + ffn(int_feat)``, where ``int_feat`` is the connector
            output plus the dim-1 mean of the projected Q-Former output.
        """
        if self.gradient_checkpointing and self.training:
            print('Not support gradient checkpointing')
        # Only the Q-Former branch sees layer-normed features; the connector
        # below receives the raw input.
        x = self.ln_vision(x_)
        # Share the single learned query set across the batch.
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )

        q_feat = self.llm_proj(query_output.last_hidden_state)

        mlp_outputs = self.connector(x_)
        mlp_feat = mlp_outputs

        # Average over dim 1 and re-insert a singleton axis so the summary
        # broadcasts against every position of the connector output.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:,None]
        out = int_feat + self.ffn(int_feat)

        return out
128
+
projector/modeling_projector.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers.activations import ACT2FN
5
+
6
+ from .configuration_projector import ProjectorConfig
7
+
8
+
9
class ProjectorModel(PreTrainedModel):
    """MLP projector: one input Linear followed by ``depth - 1`` (activation, Linear) pairs."""

    _auto_class = 'AutoModel'
    config_class = ProjectorConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: ProjectorConfig) -> None:
        """Assemble the ``nn.Sequential`` stack described by ``config``.

        Reads ``visual_hidden_size``, ``llm_hidden_size``, ``bias``, ``depth``
        and ``hidden_act`` from ``config``.
        """
        super().__init__(config)
        self.gradient_checkpointing = False

        # First layer maps the visual width to the LLM width.
        layers = [
            nn.Linear(
                config.visual_hidden_size,
                config.llm_hidden_size,
                bias=config.bias)
        ]
        # Each further level adds an activation and a width-preserving Linear.
        for _ in range(config.depth - 1):
            layers.extend((
                ACT2FN[config.hidden_act],
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias),
            ))
        self.model = nn.Sequential(*layers)

    def enable_input_require_grads(self):
        """Register a forward hook on ``self.model`` that marks its output as requiring grad."""

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` when it is a ``ProjectorModel``."""
        if not isinstance(module, ProjectorModel):
            return
        module.gradient_checkpointing = value

    def forward(self, x):
        """Run ``x`` through the MLP, via activation checkpointing when enabled and training."""
        if not (self.gradient_checkpointing and self.training):
            return self.model(x)
        return torch.utils.checkpoint.checkpoint(self.model, x)
projector/qformer_src.py ADDED
@@ -0,0 +1,1206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Dict, Any
6
+
7
+ import torch
8
+ from torch import Tensor, device, dtype, nn
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ import torch.nn.functional as F
13
+
14
+ from transformers.activations import ACT2FN
15
+ from transformers.file_utils import (
16
+ ModelOutput,
17
+ )
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ BaseModelOutputWithPoolingAndCrossAttentions,
21
+ CausalLMOutputWithCrossAttentions,
22
+ MaskedLMOutput,
23
+ MultipleChoiceModelOutput,
24
+ NextSentencePredictorOutput,
25
+ QuestionAnsweringModelOutput,
26
+ SequenceClassifierOutput,
27
+ TokenClassifierOutput,
28
+ )
29
+ from transformers.modeling_utils import (
30
+ PreTrainedModel,
31
+ apply_chunking_to_forward,
32
+ find_pruneable_heads_and_indices,
33
+ prune_linear_layer,
34
+ )
35
+ from transformers.utils import logging
36
+ from transformers.models.bert.configuration_bert import BertConfig
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
class BertEmbeddings(nn.Module):
    """Build input embeddings from token ids and/or pre-computed query embeddings."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(max_positions, hidden)

        # The attribute keeps its CamelCase name (rather than snake_case) so
        # variable names line up with TensorFlow checkpoints.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings) holding 0..max-1; it is
        # registered so it is serialized together with the module.
        all_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed ``input_ids`` (optionally prefixed by ``query_embeds``), then LayerNorm and dropout.

        When ``input_ids`` is None the query embeddings alone are normalized.
        """
        has_text = input_ids is not None
        seq_length = input_ids.size()[1] if has_text else 0

        if position_ids is None:
            # Slice the default positions, offset by the cached prefix length.
            start = past_key_values_length
            position_ids = self.position_ids[:, start : seq_length + start].clone()

        if not has_text:
            hidden_states = query_embeds
        else:
            hidden_states = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                hidden_states = hidden_states + self.position_embeddings(position_ids)
            if query_embeds is not None:
                # Queries come first along the sequence axis, text after them.
                hidden_states = torch.cat((query_embeds, hidden_states), dim=1)

        return self.dropout(self.LayerNorm(hidden_states))
99
+
100
+
101
class BertSelfAttention(nn.Module):
    """Multi-head attention usable as self-attention or as cross-attention.

    When built with ``is_cross_attention`` true, the key/value projections take
    inputs of width ``config.encoder_width``; otherwise ``config.hidden_size``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # Keys and values read encoder features when cross-attending.
        kv_width = config.encoder_width if is_cross_attention else config.hidden_size
        self.key = nn.Linear(kv_width, self.all_head_size)
        self.value = nn.Linear(kv_width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient tensor passed in (used as a tensor hook in ``forward``)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients stored by ``save_attn_gradients``."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the given attention probabilities."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the map stored by ``save_attention_map``."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and swap dims 1 and 2."""
        split_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute attention and return ``(context[, probs], (key, value))``."""
        # Passing encoder states turns this call into cross-attention: keys
        # and values come from the encoder and its mask replaces the
        # self-attention mask.
        is_cross_attention = encoder_hidden_states is not None
        kv_input = encoder_hidden_states if is_cross_attention else hidden_states

        key_layer = self.transpose_for_scores(self.key(kv_input))
        value_layer = self.transpose_for_scores(self.value(kv_input))
        if is_cross_attention:
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Prepend the cached keys/values along the sequence axis.
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        past_key_value = (key_layer, value_layer)

        # Raw scores: dot product of every query with every key.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            cols = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = rows - cols
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            # Match the query dtype (fp16 compatibility).
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive and is supplied by the caller.
            attention_scores = attention_scores + attention_mask

        # Turn scores into probabilities over the key axis.
        attention_probs = attention_scores.softmax(dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout here removes whole attention links; the un-dropped
        # probabilities are what gets returned when requested.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Move heads back next to head_size and merge them into one axis.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if output_attentions:
            outputs = (context_layer, attention_probs)
        else:
            outputs = (context_layer,)

        return outputs + (past_key_value,)
266
+
267
+
268
class BertSelfOutput(nn.Module):
    """Post-attention block: dense projection, dropout, then residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
280
+
281
+
282
class BertAttention(nn.Module):
    """Attention sub-layer: ``BertSelfAttention`` followed by ``BertSelfOutput``."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and update bookkeeping."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink query/key/value on their default dim and the output
        # projection on dim 1.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head counters consistent with the pruned layers and
        # remember which heads are gone.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention, apply the output block, and pass through any extra outputs."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context = self_outputs[0]
        extras = self_outputs[1:]  # attention probs (if requested) and the key/value pair
        attention_output = self.output(context, hidden_states)
        return (attention_output,) + extras
337
+
338
+
339
class BertIntermediate(nn.Module):
    """Feed-forward expansion: Linear from hidden_size to intermediate_size, then activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # ``hidden_act`` is either a name looked up in ACT2FN or an
        # activation object used as-is.
        activation = config.hidden_act
        self.intermediate_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )

    def forward(self, hidden_states):
        """Apply the dense layer followed by the configured activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
352
+
353
+
354
class BertOutput(nn.Module):
    """Feed-forward contraction: Linear back to hidden_size, dropout, residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
366
+
367
+
368
class BertLayer(nn.Module):
    """Single transformer layer with separate query and text feed-forward paths.

    Self-attention runs over the whole input sequence. The first
    ``query_length`` positions may additionally cross-attend to
    ``encoder_hidden_states`` (only in layers whose index is a multiple of
    ``config.cross_attention_freq`` when ``config.add_cross_attention`` is
    set) and are passed through ``intermediate_query``/``output_query``; the
    remaining positions use ``intermediate``/``output``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        # Cross-attention is only instantiated every `cross_attention_freq` layers.
        wants_cross_attention = bool(
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        )
        if wants_cross_attention:
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
        self.has_cross_attention = wants_cross_attention

        # Feed-forward stack for the non-query positions.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Dedicated feed-forward stack for the query positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run self-attention, optional cross-attention and the feed-forward paths.

        Args:
            hidden_states: Layer input; dimension 1 is the sequence dimension.
            attention_mask: Mask forwarded to the attention modules.
            head_mask: Head mask forwarded to the attention modules.
            encoder_hidden_states: States the query positions cross-attend to;
                required when this layer has cross-attention and
                ``query_length > 0``.
            encoder_attention_mask: Mask for ``encoder_hidden_states``.
            past_key_value: Cached tuple; its first two entries are passed to
                self-attention.
            output_attentions: Forwarded to the attention modules.
            query_length: Number of leading positions treated as queries.

        Returns:
            Tuple ``(layer_output, *attention_extras, present_key_value)``
            where ``attention_extras`` are the middle entries returned by the
            attention modules and ``present_key_value`` is the last entry
            returned by self-attention.
        """
        # Only the first two cached entries belong to self-attention.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attention_outputs[0]
        extras = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                extras = extras + cross_attention_outputs[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                # Positions after the queries go through the text feed-forward.
                text_output = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, text_output], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return (layer_output,) + extras + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Apply ``intermediate`` then ``output`` with a residual on the input."""
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Apply ``intermediate_query`` then ``output_query`` with a residual."""
        return self.output_query(
            self.intermediate_query(attention_output), attention_output
        )
475
+
476
+
477
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer in order and collect the requested side outputs.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Self-attention mask forwarded to each layer.
            head_mask: Indexable per-layer head masks, or ``None``.
            encoder_hidden_states: States used by cross-attention layers.
            encoder_attention_mask: Mask for ``encoder_hidden_states``.
            past_key_values: Indexable per-layer caches, or ``None``.
            use_cache: When truthy, the last entry of each layer's output is
                collected and returned as the new cache.
            output_attentions: When truthy, collect attention weights.
            output_hidden_states: When truthy, collect the input of every
                layer plus the final output.
            return_dict: Return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple of the non-``None`` values.
            query_length: Number of leading query positions, forwarded to
                each layer.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions`` or, when
            ``return_dict`` is falsy, a tuple of the non-``None`` entries
            among (hidden states, cache, all hidden states, self-attentions,
            cross-attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    # `Logger.warn` is a deprecated alias of `Logger.warning`.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Non-tensor arguments are bound through the closure
                    # instead of being passed through `checkpoint`.
                    def custom_forward(*inputs):
                        return module(
                            *inputs, past_key_value, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # A layer only emits cross-attention weights at index 2 when it
                # owns a cross-attention module AND query positions were
                # present. Otherwise index 2 is the key/value cache, which must
                # not be reported as attention weights.
                if (
                    all_cross_attentions is not None
                    and layer_module.has_cross_attention
                    and query_length > 0
                ):
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
580
+
581
+
582
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first position."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here simply means picking the first position.
        leading_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_state))
595
+
596
+
597
class BertPredictionHeadTransform(nn.Module):
    """Dense projection, activation and LayerNorm applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either a key of ACT2FN or the activation callable itself.
        activation = config.hidden_act
        self.transform_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(transform_act_fn(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
612
+
613
+
614
class BertLMPredictionHead(nn.Module):
    """Map hidden states to vocabulary scores.

    The hidden states go through :class:`BertPredictionHeadTransform` and a
    linear decoder of size ``config.hidden_size -> config.vocab_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate per-token
        # bias parameter is attached to it below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Share one Parameter between `self.bias` and `self.decoder.bias` so
        # the bias is correctly resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return the decoder output for the transformed ``hidden_states``."""
        return self.decoder(self.transform(hidden_states))
632
+
633
+
634
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing :class:`BertLMPredictionHead` as ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores computed from ``sequence_output``."""
        return self.predictions(sequence_output)
642
+
643
+
644
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that supplies weight initialization and the pretrained-model
    loading interface inherited from ``PreTrainedModel``.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; LayerNorm is reset to
        weight 1 and bias 0; linear biases are zeroed.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
665
+
666
+
667
class BertModel(BertPreTrainedModel):
    """
    BERT-style model built from :class:`BertEmbeddings`, :class:`BertEncoder`
    and an optional :class:`BertPooler`.

    The model can behave as an encoder (with only self-attention) as well as a
    decoder, following the architecture described in `Attention is all you need
    <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki
    Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and
    Illia Polosukhin.

    Causal (decoder) masking is requested per call through the ``is_decoder``
    argument of :meth:`forward`. When the configuration has
    :obj:`add_cross_attention` set to :obj:`True`, cross-attention layers are
    created and an :obj:`encoder_hidden_states` is then expected as an input
    to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Must be 2- or
                3-dimensional.
            input_shape (:obj:`Tuple[int]`):
                The ``(batch_size, seq_length)`` shape used to build the causal mask.
            device: (:obj:`torch.device`):
                The device on which the causal mask is created.
            is_decoder (:obj:`bool`):
                Whether a causal mask is combined with a 2D padding mask.
            has_query (:obj:`bool`):
                When the attention mask is longer than the causal mask, whether the extra prefix rows are
                prevented from attending to the causal part (UniLM-style mask).

        Returns:
            :obj:`torch.Tensor` The extended attention mask cast to :obj:`self.dtype`, holding ``0.0`` for
            positions to attend to and ``-10000.0`` for masked positions.

        Raises:
            ValueError: If :obj:`attention_mask` is neither 2- nor 3-dimensional.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Prefix rows are zero over the causal columns.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            dim=1,
                        )
                    # Every row may attend to all prefix columns.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        dim=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        input_ids / query_embeds:
            At least one of them must be given; ``query_embeds`` is required when ``input_ids`` is None.
            Both are forwarded to the embedding module.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, or a list of such tensors, `optional`):
            Hidden states used by the cross-attention layers.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, or a list of such tensors, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. Mask values
            selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            Defaults to an all-ones mask when ``encoder_hidden_states`` is given.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
            Precomputed key and value states of the attention blocks. The cached length is read from
            dimension 2 of ``past_key_values[0][0]`` minus ``config.query_length``.
        use_cache (:obj:`bool`, `optional`):
            Forwarded unchanged to the encoder; when truthy the encoder returns key value states.
        is_decoder (:obj:`bool`):
            When True a causal mask is built from ``input_ids.shape``, so ``input_ids`` must be given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # NOTE: `use_cache` is passed through as given; there is no fallback to
        # `self.config.use_cache` here.

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # Length of the cached text part (the cache also holds the query positions).
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
            if past_key_values is not None
            else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # The causal part covers only the text tokens, hence `input_ids.shape`.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder
            )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
956
+
957
+
958
class BertLMHeadModel(BertPreTrainedModel):
    """:class:`BertModel` followed by an LM head, used for next-token prediction."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder module of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder module of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        Compute vocabulary scores and, when ``labels`` are given, the
        left-to-right language-modeling loss.

        encoder_hidden_states, encoder_attention_mask, attention_mask, position_ids, head_mask:
            Forwarded unchanged to the underlying :class:`BertModel`.
        query_embeds:
            Query embeddings forwarded to the underlying model; ignored (set to None) when
            ``past_key_values`` is given. When used, the query positions are dropped from the sequence
            output before the LM head is applied.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for the next-token loss. Scores are shifted by one position against the labels.
            Passing labels forces ``use_cache`` to False.
        past_key_values:
            Precomputed key and value states forwarded to the underlying model.
        use_cache (:obj:`bool`, `optional`):
            Forwarded to the underlying model unless ``labels`` are given.
        return_logits (:obj:`bool`):
            When True, return only the scores for all but the last position.
        reduction (:obj:`str`):
            ``reduction`` argument of the cross-entropy loss (label smoothing 0.1). With ``"none"`` the
            per-token losses are summed per sequence.

        Returns:
            ``CausalLMOutputWithCrossAttentions``, or a tuple when ``return_dict`` resolves to False, or the
            bare score tensor when ``return_logits`` is True.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Keep only the positions that follow the query embeddings.
            num_queries = query_embeds.shape[1]
            sequence_output = outputs[0][:, num_queries:, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            criterion = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = criterion(
                shifted_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Collapse per-token losses into one value per sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            if lm_loss is None:
                return output
            return (lm_loss,) + output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        The attention mask is prefixed with ones covering the query positions,
        and ``input_ids`` is cut to its last token when ``past`` is given.
        """
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        full_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        step_input_ids = input_ids if past is None else input_ids[:, -1:]

        return {
            "input_ids": step_input_ids,
            "query_embeds": query_embeds,
            "attention_mask": full_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Re-index dimension 0 of every cached state with ``beam_idx``."""
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past
        )
1119
+
1120
+
1121
class BertForMaskedLM(BertPreTrainedModel):
    """:class:`BertModel` followed by an LM head, used for masked language modeling."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder module of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder module of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        Compute vocabulary scores and, when ``labels`` are given, the masked-LM loss.

        query_embeds:
            Optional query embeddings forwarded to the underlying model. When given, the query positions
            are dropped from the sequence output before the LM head is applied.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        return_logits (:obj:`bool`):
            When True, return only the prediction scores.

        Returns:
            ``MaskedLMOutput``, or a tuple when ``return_dict`` resolves to False, or the bare score tensor
            when ``return_logits`` is True.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Default to the full sequence so the head also works without query
        # embeddings; previously `sequence_output` was left unbound in that
        # case and the next line raised UnboundLocalError.
        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
qformer_src.py ADDED
@@ -0,0 +1,1206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Dict, Any
6
+
7
+ import torch
8
+ from torch import Tensor, device, dtype, nn
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ import torch.nn.functional as F
13
+
14
+ from transformers.activations import ACT2FN
15
+ from transformers.file_utils import (
16
+ ModelOutput,
17
+ )
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPastAndCrossAttentions,
20
+ BaseModelOutputWithPoolingAndCrossAttentions,
21
+ CausalLMOutputWithCrossAttentions,
22
+ MaskedLMOutput,
23
+ MultipleChoiceModelOutput,
24
+ NextSentencePredictorOutput,
25
+ QuestionAnsweringModelOutput,
26
+ SequenceClassifierOutput,
27
+ TokenClassifierOutput,
28
+ )
29
+ from transformers.modeling_utils import (
30
+ PreTrainedModel,
31
+ apply_chunking_to_forward,
32
+ find_pruneable_heads_and_indices,
33
+ prune_linear_layer,
34
+ )
35
+ from transformers.utils import logging
36
+ from transformers.models.bert.configuration_bert import BertConfig
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
class BertEmbeddings(nn.Module):
    """Embed token ids (word + optional absolute position) and prepend query embeddings.

    The combined sequence is passed through LayerNorm and then dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.vocab_size, width, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)

        # The attribute keeps the TensorFlow variable name (not snake_case) so
        # that TensorFlow checkpoint files can still be loaded.
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Row vector [[0, 1, ..., max_position_embeddings - 1]]; registering it
        # as a buffer makes it part of the serialized module state.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Return the normalized embedding sequence.

        Args:
            input_ids: Optional token ids; when given they are looked up in
                ``word_embeddings``.
            position_ids: Optional explicit positions. When omitted, a slice of
                the ``position_ids`` buffer starting at ``past_key_values_length``
                and as long as ``input_ids`` is used.
            query_embeds: Optional embeddings concatenated in front of the token
                embeddings along dim 1, or used alone when ``input_ids`` is None.
            past_key_values_length: Offset of the first token position.

        Returns:
            ``dropout(LayerNorm(embeddings))``.
        """
        if input_ids is None:
            # Query-only input: no token or position lookup takes place.
            sequence = query_embeds
        else:
            if position_ids is None:
                first = past_key_values_length
                last = input_ids.size()[1] + past_key_values_length
                position_ids = self.position_ids[:, first:last].clone()

            sequence = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                sequence = sequence + self.position_embeddings(position_ids)

            if query_embeds is not None:
                sequence = torch.cat((query_embeds, sequence), dim=1)

        return self.dropout(self.LayerNorm(sequence))
99
+
100
+
101
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention, usable for self- or cross-attention.

    When constructed with ``is_cross_attention`` set, the key and value
    projections take inputs of width ``config.encoder_width`` instead of
    ``config.hidden_size``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        evenly_split = config.hidden_size % config.num_attention_heads == 0
        if not evenly_split and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Keys/values are projected from the encoder stream for cross-attention
        # and from the hidden states otherwise.
        kv_width = config.encoder_width if is_cross_attention else config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(kv_width, self.all_head_size)
        self.value = nn.Linear(kv_width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Store the gradients handed over by the attention-probability hook."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients stored by :meth:`save_attn_gradients`."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store an attention-probability tensor for later inspection."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the tensor stored by :meth:`save_attention_map`."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        per_head = x.view(
            *x.size()[:-1], self.num_attention_heads, self.attention_head_size
        )
        return per_head.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return ``(context[, probs], (key, value))``.

        Args:
            hidden_states: Source of the queries (and of keys/values unless
                ``encoder_hidden_states`` is given).
            attention_mask: Optional additive mask applied to the scaled scores;
                replaced by ``encoder_attention_mask`` in the cross-attention case.
            head_mask: Optional multiplier applied to the dropped-out probabilities.
            encoder_hidden_states: When not None, keys and values are projected
                from it (cross-attention).
            encoder_attention_mask: Mask used in the cross-attention case.
            past_key_value: Optional ``(key, value)`` pair prepended along dim 2
                in the self-attention case.
            output_attentions: When truthy, the (pre-dropout) attention
                probabilities are included in the result.

        Returns:
            A tuple ending with the ``(key_layer, value_layer)`` pair used.
        """
        # Cross-attention is selected purely by the presence of encoder states.
        is_cross_attention = encoder_hidden_states is not None
        kv_source = encoder_hidden_states if is_cross_attention else hidden_states

        key_layer = self.transpose_for_scores(self.key(kv_source))
        value_layer = self.transpose_for_scores(self.value(kv_source))
        if is_cross_attention:
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            cached_keys, cached_values = past_key_value[0], past_key_value[1]
            key_layer = torch.cat([cached_keys, key_layer], dim=2)
            value_layer = torch.cat([cached_values, value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        past_key_value = (key_layer, value_layer)

        # Raw scores: dot product of every query with every key.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            row_ids = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            col_ids = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = row_ids - col_ids
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            ).to(dtype=query_layer.dtype)  # fp16 compatibility

            query_term = torch.einsum(
                "bhld,lrd->bhlr", query_layer, positional_embedding
            )
            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + query_term
            else:
                key_term = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = attention_scores + query_term + key_term

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive and is expected to be precomputed by the caller.
            attention_scores = attention_scores + attention_mask

        # Normalize the scores into probabilities over the key axis.
        attention_probs = attention_scores.softmax(dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout here removes whole key positions from a query's view, as in
        # the original Transformer paper.
        dropped_probs = self.dropout(attention_probs)
        if head_mask is not None:
            dropped_probs = dropped_probs * head_mask

        context_layer = torch.matmul(dropped_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if output_attentions:
            return (context_layer, attention_probs, past_key_value)
        return (context_layer, past_key_value)
266
+
267
+
268
class BertSelfOutput(nn.Module):
    """Post-attention block: dense projection, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
280
+
281
+
282
class BertAttention(nn.Module):
    """Attention sub-layer: a ``BertSelfAttention`` followed by a ``BertSelfOutput``."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections of this sub-layer.

        Does nothing when ``heads`` is empty. Otherwise the query/key/value
        projections and the output dense layer are pruned, the head counters
        are updated and the pruned heads are remembered in ``pruned_heads``.
        """
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the three input projections, then the output projection
        # (along its input dimension).
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in sync with the pruned projections.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Apply attention and the output block.

        Returns:
            ``(attention_output,)`` followed by everything the inner attention
            returned after its first element.
        """
        inner = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # The residual connection uses the sub-layer's own input.
        attention_output = self.output(inner[0], hidden_states)
        return (attention_output,) + inner[1:]
337
+
338
+
339
class BertIntermediate(nn.Module):
    """Feed-forward expansion: linear map to ``intermediate_size`` plus activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN; any other value is used as the
        # activation directly.
        self.intermediate_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )

    def forward(self, hidden_states):
        """Return ``intermediate_act_fn(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
352
+
353
+
354
class BertOutput(nn.Module):
    """Feed-forward contraction: dense back to ``hidden_size``, dropout, residual, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
366
+
367
+
368
class BertLayer(nn.Module):
    """One transformer layer with separate feed-forward branches for query and text tokens.

    A cross-attention block is added when ``config.add_cross_attention`` is set
    and ``layer_num`` is a multiple of ``config.cross_attention_freq``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        self.has_cross_attention = bool(
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        )
        if self.has_cross_attention:
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )

        # Feed-forward branch for the non-query (text) positions.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Feed-forward branch for the query positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run self-attention, optional cross-attention on the query prefix, and feed-forward.

        Args:
            hidden_states: Layer input; the first ``query_length`` positions
                along dim 1 are treated as query tokens.
            attention_mask: Mask forwarded to the attention blocks.
            head_mask: Head mask forwarded to the attention blocks.
            encoder_hidden_states: Required when this layer has cross-attention
                and ``query_length > 0``.
            encoder_attention_mask: Mask for the encoder states.
            past_key_value: Optional cache; its first two entries are passed to
                self-attention.
            output_attentions: Forwarded to the attention blocks.
            query_length: Number of leading query positions.

        Returns:
            ``(layer_output, *attention_extras, present_key_value)``.
        """
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_results[0]
        present_key_value = self_results[-1]
        # Whatever sits between the output and the cache (attention weights).
        outputs = self_results[1:-1]

        def run_ffn(branch, states):
            return apply_chunking_to_forward(
                branch, self.chunk_size_feed_forward, self.seq_len_dim, states
            )

        if query_length > 0:
            query_states = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_results = self.crossattention(
                    query_states,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_states = cross_results[0]
                # Append cross-attention weights when they were produced.
                outputs = outputs + cross_results[1:-1]

            layer_output = run_ffn(self.feed_forward_chunk_query, query_states)
            if attention_output.shape[1] > query_length:
                # Remaining positions go through the text feed-forward branch.
                text_output = run_ffn(
                    self.feed_forward_chunk, attention_output[:, query_length:, :]
                )
                layer_output = torch.cat([layer_output, text_output], dim=1)
        else:
            layer_output = run_ffn(self.feed_forward_chunk, attention_output)

        return (layer_output,) + outputs + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Apply the text feed-forward branch (``intermediate`` then ``output``)."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Apply the query feed-forward branch (``intermediate_query`` then ``output_query``)."""
        expanded = self.intermediate_query(attention_output)
        return self.output_query(expanded, attention_output)
475
+
476
+
477
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules applied in order."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer and collect the requested intermediate results.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Mask forwarded unchanged to every layer.
            head_mask: Optional per-layer head masks, indexed by layer number.
            encoder_hidden_states: Encoder states forwarded to every layer.
            encoder_attention_mask: Encoder mask forwarded to every layer.
            past_key_values: Optional per-layer caches, indexed by layer number.
            use_cache: When truthy, each layer's last output is collected as the
                new cache. Forced to ``False`` while gradient checkpointing is
                active in training mode.
            output_attentions: When truthy, self-attention weights of every
                layer are collected. Cross-attention weights are collected only
                from layers that actually produced them, i.e. layers with
                ``has_cross_attention`` set, and only when ``query_length > 0``.
            output_hidden_states: When truthy, the input of every layer and the
                final output are collected.
            return_dict: Selects a ``BaseModelOutputWithPastAndCrossAttentions``
                (truthy) or a plain tuple without ``None`` entries (falsy).
            query_length: Number of leading query positions, forwarded to the
                layers.

        Returns:
            The final hidden states together with the collected cache, hidden
            states, self-attentions and cross-attentions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    # `Logger.warn` is a deprecated alias; use `warning`.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module, layer_past):
                    # Bind this layer's cache entry at factory time so that a
                    # later recomputation does not see another iteration's value.
                    def custom_forward(*inputs):
                        return module(
                            *inputs, layer_past, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module, past_key_value),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # Only a layer with a cross-attention block that was fed query
                # tokens has cross-attention weights at index 2; for every other
                # layer that slot holds the key/value cache instead.
                if (
                    all_cross_attentions is not None
                    and layer_module.has_cross_attention
                    and query_length > 0
                ):
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
580
+
581
+
582
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first position."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here simply means picking index 0 along dim 1.
        leading = hidden_states[:, 0]
        return self.activation(self.dense(leading))
595
+
596
+
597
class BertPredictionHeadTransform(nn.Module):
    """Transform applied before the LM decoder: dense, activation, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN; any other value is used as the
        # activation directly.
        self.transform_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(transform_act_fn(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
612
+
613
+
614
class BertLMPredictionHead(nn.Module):
    """Language-model head: head transform followed by a projection to the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate per-token
        # bias parameter is defined below and attached to it. (Per the original
        # note, the decoder weights are meant to match the input embeddings.)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Share the parameter object so the bias follows the decoder when
        # `resize_token_embeddings` resizes it.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))
632
+
633
+
634
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a ``BertLMPredictionHead`` under the name ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores computed by ``predictions``."""
        return self.predictions(sequence_output)
642
+
643
+
644
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that supplies weight initialization and relies on
    ``PreTrainedModel`` for downloading and loading pretrained models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; LayerNorm is reset to
        weight 1 / bias 0; linear biases are zeroed.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Plain normal init; the TF version uses truncated_normal instead,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
665
+
666
+
667
+ class BertModel(BertPreTrainedModel):
668
+ """
669
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
670
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
671
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
672
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
673
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
674
+ input to the forward pass.
675
+ """
676
+
677
+ def __init__(self, config, add_pooling_layer=False):
678
+ super().__init__(config)
679
+ self.config = config
680
+
681
+ self.embeddings = BertEmbeddings(config)
682
+
683
+ self.encoder = BertEncoder(config)
684
+
685
+ self.pooler = BertPooler(config) if add_pooling_layer else None
686
+
687
+ self.init_weights()
688
+
689
+ def get_input_embeddings(self):
690
+ return self.embeddings.word_embeddings
691
+
692
+ def set_input_embeddings(self, value):
693
+ self.embeddings.word_embeddings = value
694
+
695
+ def _prune_heads(self, heads_to_prune):
696
+ """
697
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
698
+ class PreTrainedModel
699
+ """
700
+ for layer, heads in heads_to_prune.items():
701
+ self.encoder.layer[layer].attention.prune_heads(heads)
702
+
703
+ def get_extended_attention_mask(
704
+ self,
705
+ attention_mask: Tensor,
706
+ input_shape: Tuple[int],
707
+ device: device,
708
+ is_decoder: bool,
709
+ has_query: bool = False,
710
+ ) -> Tensor:
711
+ """
712
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
713
+
714
+ Arguments:
715
+ attention_mask (:obj:`torch.Tensor`):
716
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
717
+ input_shape (:obj:`Tuple[int]`):
718
+ The shape of the input to the model.
719
+ device: (:obj:`torch.device`):
720
+ The device of the input to the model.
721
+
722
+ Returns:
723
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
724
+ """
725
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
726
+ # ourselves in which case we just need to make it broadcastable to all heads.
727
+ if attention_mask.dim() == 3:
728
+ extended_attention_mask = attention_mask[:, None, :, :]
729
+ elif attention_mask.dim() == 2:
730
+ # Provided a padding mask of dimensions [batch_size, seq_length]
731
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
732
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
733
+ if is_decoder:
734
+ batch_size, seq_length = input_shape
735
+
736
+ seq_ids = torch.arange(seq_length, device=device)
737
+ causal_mask = (
738
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
739
+ <= seq_ids[None, :, None]
740
+ )
741
+
742
+ # add a prefix ones mask to the causal mask
743
+ # causal and attention masks must have same type with pytorch version < 1.3
744
+ causal_mask = causal_mask.to(attention_mask.dtype)
745
+
746
+ if causal_mask.shape[1] < attention_mask.shape[1]:
747
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
748
+ if has_query: # UniLM style attention mask
749
+ causal_mask = torch.cat(
750
+ [
751
+ torch.zeros(
752
+ (batch_size, prefix_seq_len, seq_length),
753
+ device=device,
754
+ dtype=causal_mask.dtype,
755
+ ),
756
+ causal_mask,
757
+ ],
758
+ axis=1,
759
+ )
760
+ causal_mask = torch.cat(
761
+ [
762
+ torch.ones(
763
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
764
+ device=device,
765
+ dtype=causal_mask.dtype,
766
+ ),
767
+ causal_mask,
768
+ ],
769
+ axis=-1,
770
+ )
771
+ extended_attention_mask = (
772
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
773
+ )
774
+ else:
775
+ extended_attention_mask = attention_mask[:, None, None, :]
776
+ else:
777
+ raise ValueError(
778
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
779
+ input_shape, attention_mask.shape
780
+ )
781
+ )
782
+
783
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
784
+ # masked positions, this operation will create a tensor which is 0.0 for
785
+ # positions we want to attend and -10000.0 for masked positions.
786
+ # Since we are adding it to the raw scores before the softmax, this is
787
+ # effectively the same as removing these entirely.
788
+ extended_attention_mask = extended_attention_mask.to(
789
+ dtype=self.dtype
790
+ ) # fp16 compatibility
791
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
792
+ return extended_attention_mask
793
+
794
+ def forward(
795
+ self,
796
+ input_ids=None,
797
+ attention_mask=None,
798
+ position_ids=None,
799
+ head_mask=None,
800
+ query_embeds=None,
801
+ encoder_hidden_states=None,
802
+ encoder_attention_mask=None,
803
+ past_key_values=None,
804
+ use_cache=None,
805
+ output_attentions=None,
806
+ output_hidden_states=None,
807
+ return_dict=None,
808
+ is_decoder=False,
809
+ ):
810
+ r"""
811
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
812
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
813
+ the model is configured as a decoder.
814
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
815
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
816
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
817
+ - 1 for tokens that are **not masked**,
818
+ - 0 for tokens that are **masked**.
819
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
820
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
821
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
822
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
823
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
824
+ use_cache (:obj:`bool`, `optional`):
825
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
826
+ decoding (see :obj:`past_key_values`).
827
+ """
828
+ output_attentions = (
829
+ output_attentions
830
+ if output_attentions is not None
831
+ else self.config.output_attentions
832
+ )
833
+ output_hidden_states = (
834
+ output_hidden_states
835
+ if output_hidden_states is not None
836
+ else self.config.output_hidden_states
837
+ )
838
+ return_dict = (
839
+ return_dict if return_dict is not None else self.config.use_return_dict
840
+ )
841
+
842
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
843
+
844
+ if input_ids is None:
845
+ assert (
846
+ query_embeds is not None
847
+ ), "You have to specify query_embeds when input_ids is None"
848
+
849
+ # past_key_values_length
850
+ past_key_values_length = (
851
+ past_key_values[0][0].shape[2] - self.config.query_length
852
+ if past_key_values is not None
853
+ else 0
854
+ )
855
+
856
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
857
+
858
+ embedding_output = self.embeddings(
859
+ input_ids=input_ids,
860
+ position_ids=position_ids,
861
+ query_embeds=query_embeds,
862
+ past_key_values_length=past_key_values_length,
863
+ )
864
+
865
+ input_shape = embedding_output.size()[:-1]
866
+ batch_size, seq_length = input_shape
867
+ device = embedding_output.device
868
+
869
+ if attention_mask is None:
870
+ attention_mask = torch.ones(
871
+ ((batch_size, seq_length + past_key_values_length)), device=device
872
+ )
873
+
874
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
875
+ # ourselves in which case we just need to make it broadcastable to all heads.
876
+ if is_decoder:
877
+ extended_attention_mask = self.get_extended_attention_mask(
878
+ attention_mask,
879
+ input_ids.shape,
880
+ device,
881
+ is_decoder,
882
+ has_query=(query_embeds is not None),
883
+ )
884
+ else:
885
+ extended_attention_mask = self.get_extended_attention_mask(
886
+ attention_mask, input_shape, device, is_decoder
887
+ )
888
+
889
+ # If a 2D or 3D attention mask is provided for the cross-attention
890
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
891
+ if encoder_hidden_states is not None:
892
+ if type(encoder_hidden_states) == list:
893
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
894
+ 0
895
+ ].size()
896
+ else:
897
+ (
898
+ encoder_batch_size,
899
+ encoder_sequence_length,
900
+ _,
901
+ ) = encoder_hidden_states.size()
902
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
903
+
904
+ if type(encoder_attention_mask) == list:
905
+ encoder_extended_attention_mask = [
906
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
907
+ ]
908
+ elif encoder_attention_mask is None:
909
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
910
+ encoder_extended_attention_mask = self.invert_attention_mask(
911
+ encoder_attention_mask
912
+ )
913
+ else:
914
+ encoder_extended_attention_mask = self.invert_attention_mask(
915
+ encoder_attention_mask
916
+ )
917
+ else:
918
+ encoder_extended_attention_mask = None
919
+
920
+ # Prepare head mask if needed
921
+ # 1.0 in head_mask indicate we keep the head
922
+ # attention_probs has shape bsz x n_heads x N x N
923
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
924
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
925
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
926
+
927
+ encoder_outputs = self.encoder(
928
+ embedding_output,
929
+ attention_mask=extended_attention_mask,
930
+ head_mask=head_mask,
931
+ encoder_hidden_states=encoder_hidden_states,
932
+ encoder_attention_mask=encoder_extended_attention_mask,
933
+ past_key_values=past_key_values,
934
+ use_cache=use_cache,
935
+ output_attentions=output_attentions,
936
+ output_hidden_states=output_hidden_states,
937
+ return_dict=return_dict,
938
+ query_length=query_length,
939
+ )
940
+ sequence_output = encoder_outputs[0]
941
+ pooled_output = (
942
+ self.pooler(sequence_output) if self.pooler is not None else None
943
+ )
944
+
945
+ if not return_dict:
946
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
947
+
948
+ return BaseModelOutputWithPoolingAndCrossAttentions(
949
+ last_hidden_state=sequence_output,
950
+ pooler_output=pooled_output,
951
+ past_key_values=encoder_outputs.past_key_values,
952
+ hidden_states=encoder_outputs.hidden_states,
953
+ attentions=encoder_outputs.attentions,
954
+ cross_attentions=encoder_outputs.cross_attentions,
955
+ )
956
+
957
+
958
class BertLMHeadModel(BertPreTrainedModel):
    """BERT backbone with a language-modeling head used for next-token prediction.

    The backbone is built without a pooling layer and accepts optional
    ``query_embeds`` in addition to ``input_ids``. Positions that belong to the
    query embeddings are stripped from the backbone output before the LM head
    is applied, so logits are produced for the text positions only.
    """

    # Checkpoint-key patterns tolerated as unexpected / missing when loading
    # weights (presumably consumed by the transformers loading code -- confirm
    # against the installed version).
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection module of the LM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection module of the LM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        Run the backbone and LM head, optionally computing a next-token loss.

        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Hidden states forwarded to the backbone for cross-attention.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask for ``encoder_hidden_states``: 1 for tokens that are **not masked**, 0 for tokens that are
            **masked**.
        query_embeds (`optional`):
            Embeddings passed to the backbone alongside ``input_ids``. The first ``query_embeds.shape[1]``
            positions of the backbone output are dropped before the LM head. Ignored (set to ``None``) whenever
            ``past_key_values`` is given.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for the left-to-right language modeling loss. Logits are shifted by one position against the
            labels. Tokens with index ``-100`` are ignored by the loss. Passing labels forces ``use_cache=False``.
        past_key_values (`optional`):
            Cached key/value states from a previous call, forwarded to the backbone.
        use_cache (:obj:`bool`, `optional`):
            Forwarded to the backbone; when true the returned output carries ``past_key_values``.
        return_logits (:obj:`bool`, `optional`):
            If true, return only ``prediction_scores[:, :-1, :]`` and skip loss computation and output packing.
        is_decoder (:obj:`bool`, `optional`):
            Forwarded to the backbone.
        reduction (:obj:`str`, `optional`):
            Reduction passed to ``CrossEntropyLoss``. With ``"none"`` the per-token losses are summed per sample,
            giving one value per batch row.

        Returns:
            The logits tensor when ``return_logits`` is true; otherwise a tuple when ``return_dict`` resolves to
            false, or a ``CausalLMOutputWithCrossAttentions``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # A training step does not need the key/value cache.
        if labels is not None:
            use_cache = False
        # With a cache present the query embeddings are not fed again
        # (presumably they are already part of the cached states).
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Keep only the positions that follow the query embeddings.
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Collapse per-token losses into one summed loss per sample.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        query_embeds,
        past=None,
        attention_mask=None,
        past_key_values=None,
        **model_kwargs
    ):
        """Build the keyword arguments for one ``forward`` call during generation.

        The cache may arrive under the legacy name ``past`` or under
        ``past_key_values`` (the keyword used by recent transformers releases);
        ``past`` takes precedence when both are given. The attention mask is
        extended on the left with ones covering the query positions, and when a
        cache is present only the last token of ``input_ids`` is kept.
        """
        # Accept the cache under either keyword so it is not silently dropped
        # into **model_kwargs when the caller uses the newer name.
        if past is None:
            past = past_key_values

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along dim 0 according to ``beam_idx``."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past
        )
1119
+
1120
+
1121
class BertForMaskedLM(BertPreTrainedModel):
    """BERT backbone with a masked-language-modeling head.

    The backbone is built without a pooling layer and accepts optional
    ``query_embeds``. When they are supplied, their positions are stripped
    from the backbone output before the MLM head is applied.
    """

    # Checkpoint-key patterns tolerated as unexpected / missing when loading
    # weights (presumably consumed by the transformers loading code -- confirm
    # against the installed version).
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection module of the MLM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection module of the MLM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        Run the backbone and MLM head, optionally computing the masked-LM loss.

        query_embeds (`optional`):
            Embeddings passed to the backbone alongside ``input_ids``. When given, the first
            ``query_embeds.shape[1]`` positions of the backbone output are dropped before the MLM head; when
            omitted, the full backbone output is used.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        return_logits (:obj:`bool`, `optional`):
            If true, return only the prediction scores and skip loss computation and output packing.
        is_decoder (:obj:`bool`, `optional`):
            Forwarded to the backbone.

        Returns:
            The logits tensor when ``return_logits`` is true; otherwise a tuple when ``return_dict`` resolves to
            false, or a ``MaskedLMOutput``.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Default to the full backbone output so that calls without
        # query_embeds do not hit an unbound ``sequence_output``.
        sequence_output = outputs[0]
        if query_embeds is not None:
            # Keep only the positions that follow the query embeddings.
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
visual_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14-336",
3
+ "architectures": [
4
+ "CLIPVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "quick_gelu",
9
+ "hidden_size": 1024,
10
+ "image_size": 588,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 24,
19
+ "patch_size": 14,
20
+ "projection_dim": 768,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.37.0"
23
+ }
visual_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d423ccbd6a035272c301b1aca7167745207625e1d9c3586411553ab8a2a0ca87
3
+ size 609495528
visual_encoder/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 588,
4
+ "width": 588
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 588
26
+ }
27
+ }