ariG23498 HF staff commited on
Commit
ae81e0f
·
1 Parent(s): 9a138bb

chore: adding lolcats configs scrc and src

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml +52 -0
  2. configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml +52 -0
  3. configs/experiment/eval_alpaca_clean.yaml +56 -0
  4. configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml +58 -0
  5. configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml +56 -0
  6. configs/experiment/no_distill_alpaca_clean.yaml +29 -0
  7. configs/model/base_llama3_1_8b.yaml +15 -0
  8. configs/model/base_llama3_8b.yaml +15 -0
  9. configs/model/base_mistral_7b.yaml +15 -0
  10. configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +40 -0
  11. configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +40 -0
  12. configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml +34 -0
  13. configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml +34 -0
  14. configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml +36 -0
  15. configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml +35 -0
  16. configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml +35 -0
  17. configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +39 -0
  18. configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +39 -0
  19. configs/model/distill_llama3_1_8b_lk_t2r.yaml +35 -0
  20. configs/model/distill_llama3_8b_lk_smd_fd64.yaml +29 -0
  21. configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml +33 -0
  22. configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml +33 -0
  23. configs/model/distill_llama3_8b_lk_t2r.yaml +29 -0
  24. configs/model/distill_mistral_7b_lk_smd_fd64.yaml +29 -0
  25. configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml +35 -0
  26. configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml +35 -0
  27. configs/model/distill_mistral_7b_lk_t2r.yaml +29 -0
  28. csrc/__init__.py +6 -0
  29. csrc/causal_attention.cpp +225 -0
  30. csrc/causal_attention.py +77 -0
  31. csrc/causal_attention_cuda.cu +1483 -0
  32. csrc/causal_attention_kv_cuda.cu +1483 -0
  33. csrc/setup.py +53 -0
  34. src/__init__.py +0 -0
  35. src/dataloaders/__init__.py +22 -0
  36. src/dataloaders/alpaca_clean.py +149 -0
  37. src/dataloaders/alpaca_clean_instruct.py +148 -0
  38. src/dataloaders/utils/__init__.py +4 -0
  39. src/dataloaders/utils/llama3.py +62 -0
  40. src/dataloaders/utils/packing.py +80 -0
  41. src/dataloaders/utils/setup.py +123 -0
  42. src/finetune.py +68 -0
  43. src/model/__init__.py +0 -0
  44. src/model/convert_model.py +173 -0
  45. src/model/feature_map.py +306 -0
  46. src/model/linear_attention/__init__.py +23 -0
  47. src/model/linear_attention/linear_attention.py +459 -0
  48. src/model/linear_attention/linear_window_attention_sw.py +339 -0
  49. src/model/linear_attention/linear_window_attention_sw_linear.py +522 -0
  50. src/model/linear_attention/linear_window_attention_sw_long.py +23 -0
configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: alpaca_clean
3
+ dataset_config:
4
+ name: default
5
+ path: yahma/alpaca-cleaned
6
+ chunk_size: 1024 # sequence length for distilling
7
+ concat_data: true
8
+ cache_dir: 'data/alpaca' # Change this to where you want to save
9
+ pretrained_model_config: # will be updated based on model_config
10
+ pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
11
+ cache_dir: '/scratch/'
12
+ preprocess_config: null
13
+
14
+ dataloader:
15
+ batch_size: 1
16
+ num_workers: 2
17
+ drop_last: false
18
+ pin_memory: true
19
+
20
+ optimizer:
21
+ optim: adamw_torch_fused
22
+ lr: 0.01
23
+ weight_decay: 0.0
24
+
25
+ lr_scheduler:
26
+ lr_scheduler_type: reduce_lr_on_plateau
27
+ mode: min
28
+ factor: 0.1
29
+ patience: 10
30
+ min_lr: 0.00001
31
+
32
+ trainer: # HuggingFace Trainer-like arguments
33
+ name: distill_attention_xent_mse
34
+ reverse_kl: false
35
+ mse_factor: 1000
36
+ xent_factor: 0
37
+
38
+ bf16: true
39
+ train_split: train
40
+ val_split: validation
41
+ num_train_epochs: 2
42
+ gradient_accumulation_steps: 8
43
+ seed: 42
44
+ batch_size: 1
45
+ load_best_model_at_end: true
46
+ greater_is_better: false
47
+ metric_for_best_model: distill/eval/loss
48
+ logging_steps: 100
49
+ evaluation_strategy: steps
50
+ max_steps: -1
51
+ eval_steps: 100
52
+ max_eval_batches: null
configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: alpaca_clean
3
+ dataset_config:
4
+ name: default
5
+ path: yahma/alpaca-cleaned
6
+ chunk_size: 1024 # sequence length for distilling
7
+ concat_data: true
8
+ cache_dir: 'data/alpaca' # Change this to where you want to save
9
+ pretrained_model_config: # will be updated based on model_config
10
+ pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B'
11
+ cache_dir: '/data_persistent2/sim_data/llama-3_1-8b/'
12
+ preprocess_config: null
13
+
14
+ dataloader:
15
+ batch_size: 1
16
+ num_workers: 2
17
+ drop_last: false
18
+ pin_memory: true
19
+
20
+ optimizer:
21
+ optim: adamw_torch_fused
22
+ lr: 0.01
23
+ weight_decay: 0.0
24
+
25
+ lr_scheduler:
26
+ lr_scheduler_type: reduce_lr_on_plateau
27
+ mode: min
28
+ factor: 0.1
29
+ patience: 10
30
+ min_lr: 0.00001
31
+
32
+ trainer: # HuggingFace Trainer-like arguments
33
+ name: distill_attention_xent_mse
34
+ reverse_kl: false
35
+ mse_factor: 1000
36
+ xent_factor: 1
37
+
38
+ bf16: true
39
+ train_split: train
40
+ val_split: validation
41
+ num_train_epochs: 2
42
+ gradient_accumulation_steps: 8
43
+ seed: 42
44
+ batch_size: 1
45
+ load_best_model_at_end: true
46
+ greater_is_better: false
47
+ metric_for_best_model: distill/eval/loss
48
+ logging_steps: 100
49
+ evaluation_strategy: steps
50
+ max_steps: -1
51
+ eval_steps: 100
52
+ max_eval_batches: null
configs/experiment/eval_alpaca_clean.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: alpaca_clean
3
+ dataset_config:
4
+ name: alpaca
5
+ path: yahma/alpaca-cleaned
6
+ chunk_size: 1024 # sequence length for distilling
7
+ concat_data: true
8
+ cache_dir: 'data/alpaca' # Change this to where you want to save
9
+ pretrained_model_config:
10
+ pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
11
+ cache_dir: '/scratch/'
12
+ preprocess_config: null
13
+
14
+ dataloader:
15
+ batch_size: 1
16
+ num_workers: 2
17
+ drop_last: false
18
+ pin_memory: true
19
+
20
+ optimizer:
21
+ optim: adamw_torch_fused
22
+ lr: 1e-4
23
+ weight_decay: 0.0
24
+
25
+ lr_scheduler:
26
+ lr_scheduler_type: reduce_lr_on_plateau
27
+ mode: min
28
+ factor: 0.1
29
+ patience: 10
30
+ min_lr: 0.00001
31
+
32
+ trainer: # HuggingFace Trainer-like arguments
33
+ name: finetune_seq2seq
34
+ bf16: true
35
+ train_split: train
36
+ val_split: test
37
+ num_train_epochs: 2
38
+ gradient_accumulation_steps: 8
39
+ seed: 42
40
+ batch_size: 1
41
+ load_best_model_at_end: true
42
+ greater_is_better: true
43
+ metric_for_best_model: eval/rouge/geometric_mean
44
+ logging_steps: 100
45
+ evaluation_strategy: steps
46
+ max_steps: -1
47
+ eval_steps: 100
48
+ max_eval_batches: null
49
+
50
+ finetune:
51
+ method: lora
52
+ kwargs:
53
+ r: 8
54
+ lora_alpha: 16
55
+ lora_dropout: 0 # 0.05
56
+ target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: alpaca_clean
3
+ dataset_config:
4
+ name: default
5
+ path: yahma/alpaca-cleaned
6
+ chunk_size: 1024
7
+ concat_data: true
8
+ cache_dir: "data/alpaca"
9
+ pretrained_model_config:
10
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
11
+ cache_dir: "/data_persistent2/sim_data/"
12
+ preprocess_config: null
13
+
14
+ dataloader:
15
+ batch_size: 1
16
+ num_workers: 2
17
+ drop_last: false
18
+ pin_memory: true
19
+
20
+ optimizer:
21
+ optim: adamw_torch_fused
22
+ lr: 1e-4
23
+ weight_decay: 0.0
24
+
25
+ lr_scheduler:
26
+ lr_scheduler_type: reduce_lr_on_plateau
27
+ mode: min
28
+ factor: 0.1
29
+ patience: 10
30
+ min_lr: 0.00001
31
+
32
+ trainer: # HuggingFace Trainer-like arguments
33
+ name: default_lm
34
+ bf16: true
35
+ train_split: train
36
+ val_split: validation
37
+ num_train_epochs: 2
38
+ gradient_accumulation_steps: 8
39
+ seed: 42
40
+ batch_size: 1
41
+ load_best_model_at_end: true
42
+ greater_is_better: false
43
+ metric_for_best_model: eval/loss # eval/rouge/geometric_mean
44
+ logging_steps: 100
45
+ evaluation_strategy: steps
46
+ max_steps: -1
47
+ eval_steps: 100
48
+ max_eval_batches: null
49
+ num_save_ckpt_steps: 200
50
+
51
+ finetune:
52
+ method: lora
53
+ kwargs:
54
+ r: 8
55
+ lora_alpha: 16
56
+ lora_dropout: 0 # 0.05
57
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
58
+ trainable_weights: ['feature_map_q.mlp.layer', 'feature_map_k.mlp.layer', 'window_factors']
configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: alpaca_clean
3
+ dataset_config:
4
+ name: default
5
+ path: yahma/alpaca-cleaned
6
+ chunk_size: 1024
7
+ concat_data: true
8
+ cache_dir: "data/alpaca"
9
+ pretrained_model_config:
10
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
11
+ cache_dir: "/scratch/"
12
+ preprocess_config: null
13
+
14
+ dataloader:
15
+ batch_size: 1
16
+ num_workers: 2
17
+ drop_last: false
18
+ pin_memory: true
19
+
20
+ optimizer:
21
+ optim: adamw_torch_fused
22
+ lr: 1e-4
23
+ weight_decay: 0.0
24
+
25
+ lr_scheduler:
26
+ lr_scheduler_type: reduce_lr_on_plateau
27
+ mode: min
28
+ factor: 0.1
29
+ patience: 10
30
+ min_lr: 0.00001
31
+
32
+ trainer: # HuggingFace Trainer-like arguments
33
+ name: default_lm
34
+ bf16: true
35
+ train_split: train
36
+ val_split: validation
37
+ num_train_epochs: 2
38
+ gradient_accumulation_steps: 8
39
+ seed: 42
40
+ batch_size: 1
41
+ load_best_model_at_end: true
42
+ greater_is_better: false
43
+ metric_for_best_model: eval/loss # eval/rouge/geometric_mean
44
+ logging_steps: 100
45
+ evaluation_strategy: steps
46
+ max_steps: -1
47
+ eval_steps: 100
48
+ max_eval_batches: null
49
+
50
+ finetune:
51
+ method: lora
52
+ kwargs:
53
+ r: 8
54
+ lora_alpha: 16
55
+ lora_dropout: 0 # 0.05
56
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
configs/experiment/no_distill_alpaca_clean.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: alpaca_clean
3
+ dataset_config:
4
+ name: alpaca
5
+ path: yahma/alpaca-cleaned
6
+ chunk_size: 1024 # sequence length for distilling
7
+ concat_data: true
8
+ cache_dir: 'data/alpaca' # Change this to where you want to save
9
+ pretrained_model_config:
10
+ pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
11
+ cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1'
12
+ preprocess_config: null
13
+
14
+ dataloader:
15
+ batch_size: 1
16
+ num_workers: 2
17
+ drop_last: false
18
+ pin_memory: true
19
+
20
+ optimizer:
21
+ optim: adamw_torch_fused
22
+ lr: 0.01
23
+ weight_decay: 0.0
24
+
25
+ lr_scheduler:
26
+ lr_scheduler_type: none
27
+
28
+ trainer: # HuggingFace Trainer-like arguments
29
+ name: null
configs/model/base_llama3_1_8b.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B'
4
+ cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 500000.0
13
+
14
+ attention:
15
+ attention_type: softmax
configs/model/base_llama3_8b.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
4
+ cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 500000.0
13
+
14
+ attention:
15
+ attention_type: softmax
configs/model/base_mistral_7b.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
4
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 10000.0
13
+
14
+ attention:
15
+ attention_type: softmax
configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experimental config for chunked linear attention
2
+ name: llama
3
+ model:
4
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
5
+ cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
6
+ return_dict: true
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ device_map: auto
10
+ low_cpu_mem_usage: true
11
+ torch_dtype: bfloat16
12
+ attn_implementation: flash_attention_2
13
+ rope_theta: 500000.0
14
+ rope_scaling:
15
+ factor: 8.0
16
+ low_freq_factor: 1.0
17
+ high_freq_factor: 4.0
18
+ original_max_position_embeddings: 8192
19
+ rope_type: llama3
20
+
21
+ attention:
22
+ attention_type: lolcats_long_llama_window_sw
23
+ state_chunk_len: 1024
24
+ window_size: 64
25
+ affine_attention_factors: false
26
+ init_window_factor: -2.1972245773362196
27
+ feature_map: softmax_dim
28
+ feature_map_kwargs:
29
+ eps: 1e-12
30
+ # mlp: null # to set
31
+ fullspace: true
32
+ layer_idx: null # to set
33
+ learned_kernel: untied_head_einsum
34
+ learned_kernel_kwargs:
35
+ feature_dim: 64
36
+ skip_connection: false
37
+ bias: false
38
+ zero_init: false
39
+ tie_qk_kernels: false
40
+ train_qk: false
configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experimental config for chunked linear attention
2
+ name: llama
3
+ model:
4
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
5
+ cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
6
+ return_dict: true
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ device_map: auto
10
+ low_cpu_mem_usage: true
11
+ torch_dtype: bfloat16
12
+ attn_implementation: flash_attention_2
13
+ rope_theta: 500000.0
14
+ rope_scaling:
15
+ factor: 8.0
16
+ low_freq_factor: 1.0
17
+ high_freq_factor: 4.0
18
+ original_max_position_embeddings: 8192
19
+ rope_type: llama3
20
+
21
+ attention:
22
+ attention_type: lolcats_long_llama_window_tk
23
+ state_chunk_len: 1024
24
+ window_size: 64
25
+ affine_attention_factors: false
26
+ init_window_factor: -2.1972245773362196
27
+ feature_map: softmax_dim
28
+ feature_map_kwargs:
29
+ eps: 1e-12
30
+ # mlp: null # to set
31
+ fullspace: true
32
+ layer_idx: null # to set
33
+ learned_kernel: untied_head_einsum
34
+ learned_kernel_kwargs:
35
+ feature_dim: 64
36
+ skip_connection: false
37
+ bias: false
38
+ zero_init: false
39
+ tie_qk_kernels: false
40
+ train_qk: false
configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experimental config for chunked linear attention
2
+ name: llama
3
+ model:
4
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
5
+ cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
6
+ return_dict: true
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ device_map: auto
10
+ low_cpu_mem_usage: true
11
+ torch_dtype: bfloat16
12
+ attn_implementation: flash_attention_2
13
+ rope_theta: 500000.0
14
+
15
+ attention:
16
+ attention_type: lolcats_long_llama_window_sw
17
+ state_chunk_len: 1024
18
+ window_size: 64
19
+ affine_attention_factors: false
20
+ init_window_factor: -2.1972245773362196
21
+ feature_map: softmax_dim
22
+ feature_map_kwargs:
23
+ eps: 1e-12
24
+ # mlp: null # to set
25
+ fullspace: true
26
+ layer_idx: null # to set
27
+ learned_kernel: untied_head_einsum
28
+ learned_kernel_kwargs:
29
+ feature_dim: 64
30
+ skip_connection: false
31
+ bias: false
32
+ zero_init: false
33
+ tie_qk_kernels: false
34
+ train_qk: false
configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experimental config for chunked linear attention
2
+ name: llama
3
+ model:
4
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
5
+ cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
6
+ return_dict: true
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ device_map: auto
10
+ low_cpu_mem_usage: true
11
+ torch_dtype: bfloat16
12
+ attn_implementation: flash_attention_2
13
+ rope_theta: 500000.0
14
+
15
+ attention:
16
+ attention_type: lolcats_long_llama_window_tk
17
+ state_chunk_len: 1024
18
+ window_size: 64
19
+ affine_attention_factors: false
20
+ init_window_factor: -2.1972245773362196
21
+ feature_map: softmax_dim
22
+ feature_map_kwargs:
23
+ eps: 1e-12
24
+ # mlp: null # to set
25
+ fullspace: true
26
+ layer_idx: null # to set
27
+ learned_kernel: untied_head_einsum
28
+ learned_kernel_kwargs:
29
+ feature_dim: 64
30
+ skip_connection: false
31
+ bias: false
32
+ zero_init: false
33
+ tie_qk_kernels: false
34
+ train_qk: false
configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Experimental config for chunked linear attention
2
+ name: llama
3
+ model:
4
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
5
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
6
+ return_dict: true
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ device_map: auto
10
+ low_cpu_mem_usage: true
11
+ torch_dtype: bfloat16
12
+ attn_implementation: flash_attention_2 # eager # so we can load attention weights
13
+ rope_theta: 10000.0
14
+
15
+ attention:
16
+ attention_type: lolcats_long_llama_window_sw
17
+ state_chunk_len: 512 # 1024
18
+ window_size: 64
19
+ affine_attention_factors: false
20
+ init_window_factor: -2.1972245773362196
21
+ train_window_factor: true
22
+ train_attention_weights: false
23
+ feature_map: softmax_dim
24
+ feature_map_kwargs:
25
+ eps: 1e-12
26
+ # mlp: null # to set
27
+ fullspace: true
28
+ layer_idx: null # to set
29
+ learned_kernel: untied_head_einsum
30
+ learned_kernel_kwargs:
31
+ feature_dim: 64
32
+ skip_connection: false
33
+ bias: false
34
+ zero_init: false
35
+ tie_qk_kernels: false
36
+ train_qk: false
configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
4
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2 # eager # so we can load attention weights
12
+ rope_theta: 10000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_long_llama_window_tk
16
+ state_chunk_len: 512 # 1024
17
+ window_size: 64
18
+ affine_attention_factors: false
19
+ init_window_factor: -2.1972245773362196
20
+ train_window_factor: true
21
+ train_attention_weights: false
22
+ feature_map: softmax_dim
23
+ feature_map_kwargs:
24
+ eps: 1e-12
25
+ # mlp: null # to set
26
+ fullspace: true
27
+ layer_idx: null # to set
28
+ learned_kernel: untied_head_einsum
29
+ learned_kernel_kwargs:
30
+ feature_dim: 64
31
+ skip_connection: false
32
+ bias: false
33
+ zero_init: false
34
+ tie_qk_kernels: false
35
+ train_qk: false
configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
4
+ cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: eager
12
+ rope_theta: 500000.0
13
+ rope_scaling:
14
+ factor: 8.0
15
+ low_freq_factor: 1.0
16
+ high_freq_factor: 4.0
17
+ original_max_position_embeddings: 8192
18
+ rope_type: llama3
19
+
20
+ attention:
21
+ attention_type: lolcats_llama
22
+ feature_map: softmax_dim
23
+ feature_map_kwargs:
24
+ eps: 1e-12
25
+ # mlp: null # to set
26
+ fullspace: true
27
+ layer_idx: null # to set
28
+ learned_kernel: untied_head_einsum
29
+ learned_kernel_kwargs:
30
+ feature_dim: 64
31
+ skip_connection: false
32
+ bias: false
33
+ zero_init: false
34
+ tie_qk_kernels: false
35
+ train_qk: false
configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
4
+ cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: eager
12
+ rope_theta: 500000.0
13
+ rope_scaling:
14
+ factor: 8.0
15
+ low_freq_factor: 1.0
16
+ high_freq_factor: 4.0
17
+ original_max_position_embeddings: 8192
18
+ rope_type: llama3
19
+
20
+ attention:
21
+ attention_type: lolcats_llama_window_sw
22
+ state_chunk_len: 1024
23
+ window_size: 64
24
+ affine_attention_factors: false
25
+ init_window_factor: -2.1972245773362196
26
+ feature_map: softmax_dim
27
+ feature_map_kwargs:
28
+ eps: 1e-12
29
+ # mlp: null # to set
30
+ fullspace: true
31
+ layer_idx: null # to set
32
+ learned_kernel: untied_head_einsum
33
+ learned_kernel_kwargs:
34
+ feature_dim: 64
35
+ skip_connection: false
36
+ bias: false
37
+ zero_init: false
38
+ tie_qk_kernels: false
39
+ train_qk: false
configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
4
+ cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: eager
12
+ rope_theta: 500000.0
13
+ rope_scaling:
14
+ factor: 8.0
15
+ low_freq_factor: 1.0
16
+ high_freq_factor: 4.0
17
+ original_max_position_embeddings: 8192
18
+ rope_type: llama3
19
+
20
+ attention:
21
+ attention_type: lolcats_llama_window_tk
22
+ state_chunk_len: 1024
23
+ window_size: 64
24
+ affine_attention_factors: false
25
+ init_window_factor: -2.1972245773362196
26
+ feature_map: softmax_dim
27
+ feature_map_kwargs:
28
+ eps: 1e-12
29
+ # mlp: null # to set
30
+ fullspace: true
31
+ layer_idx: null # to set
32
+ learned_kernel: untied_head_einsum
33
+ learned_kernel_kwargs:
34
+ feature_dim: 64
35
+ skip_connection: false
36
+ bias: false
37
+ zero_init: false
38
+ tie_qk_kernels: false
39
+ train_qk: false
configs/model/distill_llama3_1_8b_lk_t2r.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
4
+ cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: eager
12
+ rope_theta: 500000.0
13
+ rope_scaling:
14
+ factor: 8.0
15
+ low_freq_factor: 1.0
16
+ high_freq_factor: 4.0
17
+ original_max_position_embeddings: 8192
18
+ rope_type: llama3
19
+
20
+ attention:
21
+ attention_type: lolcats_llama
22
+ feature_map: relu
23
+ feature_map_kwargs:
24
+ eps: 1e-12
25
+ # mlp: null # to set
26
+ fullspace: true
27
+ layer_idx: null # to set
28
+ learned_kernel: untied_head_einsum
29
+ learned_kernel_kwargs:
30
+ feature_dim: 128
31
+ skip_connection: false
32
+ bias: true
33
+ zero_init: false
34
+ tie_qk_kernels: false
35
+ train_qk: false
configs/model/distill_llama3_8b_lk_smd_fd64.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
4
+ cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 500000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama
16
+ feature_map: softmax_dim
17
+ feature_map_kwargs:
18
+ eps: 1e-12
19
+ # mlp: null # to set
20
+ fullspace: true
21
+ layer_idx: null # to set
22
+ learned_kernel: untied_head_einsum
23
+ learned_kernel_kwargs:
24
+ feature_dim: 64
25
+ skip_connection: false
26
+ bias: false
27
+ zero_init: false
28
+ tie_qk_kernels: false
29
+ train_qk: false
configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
4
+ cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 500000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama_window_sw
16
+ state_chunk_len: 1024
17
+ window_size: 64
18
+ affine_attention_factors: false
19
+ init_window_factor: -2.1972245773362196
20
+ feature_map: softmax_dim
21
+ feature_map_kwargs:
22
+ eps: 1e-12
23
+ # mlp: null # to set
24
+ fullspace: true
25
+ layer_idx: null # to set
26
+ learned_kernel: untied_head_einsum
27
+ learned_kernel_kwargs:
28
+ feature_dim: 64
29
+ skip_connection: false
30
+ bias: false
31
+ zero_init: false
32
+ tie_qk_kernels: false
33
+ train_qk: false
configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
4
+ cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 500000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama_window_tk
16
+ state_chunk_len: 1024
17
+ window_size: 64
18
+ affine_attention_factors: false
19
+ init_window_factor: -2.1972245773362196
20
+ feature_map: softmax_dim
21
+ feature_map_kwargs:
22
+ eps: 1e-12
23
+ # mlp: null # to set
24
+ fullspace: true
25
+ layer_idx: null # to set
26
+ learned_kernel: untied_head_einsum
27
+ learned_kernel_kwargs:
28
+ feature_dim: 64
29
+ skip_connection: false
30
+ bias: false
31
+ zero_init: false
32
+ tie_qk_kernels: false
33
+ train_qk: false
configs/model/distill_llama3_8b_lk_t2r.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
4
+ cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2
12
+ rope_theta: 500000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama
16
+ feature_map: relu
17
+ feature_map_kwargs:
18
+ eps: 1e-12
19
+ # mlp: null # to set
20
+ fullspace: true
21
+ layer_idx: null # to set
22
+ learned_kernel: untied_head_einsum
23
+ learned_kernel_kwargs:
24
+ feature_dim: 128
25
+ skip_connection: false
26
+ bias: true
27
+ zero_init: false
28
+ tie_qk_kernels: false
29
+ train_qk: false
configs/model/distill_mistral_7b_lk_smd_fd64.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
4
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2 # eager # so we can load attention weights
12
+ rope_theta: 10000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama
16
+ feature_map: softmax_dim
17
+ feature_map_kwargs:
18
+ eps: 1e-12
19
+ # mlp: null # to set
20
+ fullspace: true
21
+ layer_idx: null # to set
22
+ learned_kernel: untied_head_einsum
23
+ learned_kernel_kwargs:
24
+ feature_dim: 64
25
+ skip_connection: false
26
+ bias: false
27
+ zero_init: false
28
+ tie_qk_kernels: false
29
+ train_qk: false
configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
4
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2 # eager # so we can load attention weights
12
+ rope_theta: 10000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama_window_sw
16
+ state_chunk_len: 512 # 1024
17
+ window_size: 64
18
+ affine_attention_factors: false
19
+ init_window_factor: -2.1972245773362196
20
+ train_window_factor: true
21
+ train_attention_weights: false
22
+ feature_map: softmax_dim
23
+ feature_map_kwargs:
24
+ eps: 1e-12
25
+ # mlp: null # to set
26
+ fullspace: true
27
+ layer_idx: null # to set
28
+ learned_kernel: untied_head_einsum
29
+ learned_kernel_kwargs:
30
+ feature_dim: 64
31
+ skip_connection: false
32
+ bias: false
33
+ zero_init: false
34
+ tie_qk_kernels: false
35
+ train_qk: false
configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
4
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2 # eager # so we can load attention weights
12
+ rope_theta: 10000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama_window_tk
16
+ state_chunk_len: 512 # 1024
17
+ window_size: 64
18
+ affine_attention_factors: false
19
+ init_window_factor: -2.1972245773362196
20
+ train_window_factor: true
21
+ train_attention_weights: false
22
+ feature_map: softmax_dim
23
+ feature_map_kwargs:
24
+ eps: 1e-12
25
+ # mlp: null # to set
26
+ fullspace: true
27
+ layer_idx: null # to set
28
+ learned_kernel: untied_head_einsum
29
+ learned_kernel_kwargs:
30
+ feature_dim: 64
31
+ skip_connection: false
32
+ bias: false
33
+ zero_init: false
34
+ tie_qk_kernels: false
35
+ train_qk: false
configs/model/distill_mistral_7b_lk_t2r.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: llama
2
+ model:
3
+ pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
4
+ cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
5
+ return_dict: true
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ device_map: auto
9
+ low_cpu_mem_usage: true
10
+ torch_dtype: bfloat16
11
+ attn_implementation: flash_attention_2 # eager # so we can load attention weights
12
+ rope_theta: 10000.0
13
+
14
+ attention:
15
+ attention_type: lolcats_llama
16
+ feature_map: relu
17
+ feature_map_kwargs:
18
+ eps: 1e-12
19
+ # mlp: null # to set
20
+ fullspace: true
21
+ layer_idx: null # to set
22
+ learned_kernel: untied_head_einsum
23
+ learned_kernel_kwargs:
24
+ feature_dim: 128
25
+ skip_connection: false
26
+ bias: true
27
+ zero_init: false
28
+ tie_qk_kernels: false
29
+ train_qk: false
csrc/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+ from .causal_attention import causal_dot_product
csrc/causal_attention.cpp ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //
2
+ // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ // Written by Angelos Katharopoulos <[email protected]>,
4
+ // Apoorv Vyas <[email protected]>
5
+ //
6
+
7
+ #include <torch/extension.h>
8
+
9
+
10
+ /**
11
+ * Compute a*b^T and save it into out.
12
+ *
13
+ * a \in R^A
14
+ * b \in R^B
15
+ */
16
+ inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
17
+ for (int i=0; i<A; i++) {
18
+ float * bi = b;
19
+ for (int j=0; j<B; j++) {
20
+ *out += (*a) * (*bi);
21
+ out++;
22
+ bi++;
23
+ }
24
+ a++;
25
+ }
26
+ }
27
+
28
+
29
+ /**
30
+ * Implement a vector matrix product v*m and save it into out.
31
+ *
32
+ * v \in R^A
33
+ * m \in R^{AxB}
34
+ */
35
+ inline void vm_dot(float *v, float *m, float *out, int A, int B) {
36
+ // TODO: Consider removing the zeroing part and assuming out already
37
+ // contains 0s
38
+ for (int i=0; i<B; i++) {
39
+ out[i] = 0;
40
+ }
41
+
42
+ for (int i=0; i<A; i++) {
43
+ float *oi = out;
44
+ for (int j=0; j<B; j++) {
45
+ *oi += (*v) * (*m);
46
+ oi++;
47
+ m++;
48
+ }
49
+ v++;
50
+ }
51
+ }
52
+
53
+
54
+ /**
55
+ * Implement a vector transposed-matrix product and save it into out.
56
+ *
57
+ * v \in R^B
58
+ * m \in R^{AxB}
59
+ */
60
+ inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
61
+ for (int i=0; i<A; i++) {
62
+ float *vi = v;
63
+ float s = 0;
64
+ for (int j=0; j<B; j++) {
65
+ s += (*vi) * (*m);
66
+ vi++;
67
+ m++;
68
+ }
69
+ // TODO: Should we be aggregating? See the comment on vm_dot.
70
+ *out = s;
71
+ out++;
72
+ }
73
+ }
74
+
75
+
76
+ /**
77
+ * Compute the causally masked dot products of queries, keys and values.
78
+ *
79
+ * Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
80
+ * computation is done efficiently by changing the order of the dot products.
81
+ */
82
+ void causal_dot_product(
83
+ const torch::Tensor queries,
84
+ const torch::Tensor keys,
85
+ const torch::Tensor values,
86
+ torch::Tensor product
87
+ ) {
88
+ // Extract some shapes
89
+ int N = queries.size(0);
90
+ int H = queries.size(1);
91
+ int L = queries.size(2);
92
+ int E = queries.size(3);
93
+ int M = values.size(3);
94
+
95
+ // Create accessors for all the arguments
96
+ auto qa = queries.accessor<float, 4>();
97
+ auto ka = keys.accessor<float, 4>();
98
+ auto va = values.accessor<float, 4>();
99
+ auto pa = product.accessor<float, 4>();
100
+
101
+ #pragma omp parallel for collapse(2)
102
+ for (int n=0; n<N; n++) {
103
+ for (int h=0; h<H; h++) {
104
+ auto kv = torch::zeros({E, M}, queries.options());
105
+ float *kvp = kv.data_ptr<float>();
106
+ for (int l=0; l<L; l++) {
107
+ vvt_dot(
108
+ &ka[n][h][l][0],
109
+ &va[n][h][l][0],
110
+ kvp,
111
+ E,
112
+ M
113
+ );
114
+ vm_dot(
115
+ &qa[n][h][l][0],
116
+ kvp,
117
+ &pa[n][h][l][0],
118
+ E,
119
+ M
120
+ );
121
+ }
122
+ }
123
+ }
124
+ }
125
+
126
+
127
+ /**
128
+ * Compute the gradients of queries, keys and values given the gradient of the
129
+ * causal_dot_product output.
130
+ *
131
+ * Make sure that everything is computed in O(N D^2) complexity.
132
+ */
133
+ void causal_dot_backward(
134
+ const torch::Tensor queries,
135
+ const torch::Tensor keys,
136
+ const torch::Tensor values,
137
+ const torch::Tensor grad_out,
138
+ torch::Tensor grad_queries,
139
+ torch::Tensor grad_keys,
140
+ torch::Tensor grad_values
141
+ ) {
142
+ // Extract some shapes
143
+ int N = queries.size(0);
144
+ int H = queries.size(1);
145
+ int L = queries.size(2);
146
+ int E = queries.size(3);
147
+ int M = values.size(3);
148
+
149
+ // Create accessors for all the arguments
150
+ auto qa = queries.accessor<float, 4>();
151
+ auto ka = keys.accessor<float, 4>();
152
+ auto va = values.accessor<float, 4>();
153
+ auto ga = grad_out.accessor<float, 4>();
154
+ auto gqa = grad_queries.accessor<float, 4>();
155
+ auto gka = grad_keys.accessor<float, 4>();
156
+ auto gva = grad_values.accessor<float, 4>();
157
+
158
+ #pragma omp parallel for collapse(2)
159
+ for (int n=0; n<N; n++) {
160
+ for (int h=0; h<H; h++) {
161
+ auto kv = torch::zeros({E, M}, queries.options());
162
+ float *kvp = kv.data_ptr<float>();
163
+
164
+ // Compute the gradient wrt the queries
165
+ for (int l=0; l<L; l++) {
166
+ vvt_dot(
167
+ &ka[n][h][l][0],
168
+ &va[n][h][l][0],
169
+ kvp,
170
+ E,
171
+ M
172
+ );
173
+ vmt_dot(
174
+ &ga[n][h][l][0],
175
+ kvp,
176
+ &gqa[n][h][l][0],
177
+ E,
178
+ M
179
+ );
180
+ }
181
+
182
+ // Compute the gradient wrt the keys and values
183
+ kv.zero_();
184
+ for (int l=L-1; l>=0; l--) {
185
+ vvt_dot(
186
+ &qa[n][h][l][0],
187
+ &ga[n][h][l][0],
188
+ kvp,
189
+ E,
190
+ M
191
+ );
192
+ vmt_dot(
193
+ &va[n][h][l][0],
194
+ kvp,
195
+ &gka[n][h][l][0],
196
+ E,
197
+ M
198
+ );
199
+ vm_dot(
200
+ &ka[n][h][l][0],
201
+ kvp,
202
+ &gva[n][h][l][0],
203
+ E,
204
+ M
205
+ );
206
+ }
207
+ }
208
+ }
209
+ }
210
+
211
+
212
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
213
+ m.def(
214
+ "causal_dot_product",
215
+ &causal_dot_product,
216
+ "Compute the weighted sum of values but attending only to previous "
217
+ "values."
218
+ );
219
+ m.def(
220
+ "causal_dot_backward",
221
+ &causal_dot_backward,
222
+ "Compute the gradient of queries, keys and values given the gradient "
223
+ "of causal_dot_product."
224
+ );
225
+ }
csrc/causal_attention.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ import torch
8
+
9
+ try:
10
+ from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
11
+ from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
12
+ except ImportError as e:
13
+ print(e)
14
+ causal_dot_product_cuda = causal_dot_backward_cuda = None
15
+
16
+
17
+ class CausalDotProduct(torch.autograd.Function):
18
+ """Compute the weighted sum of values but attending only to previous
19
+ values."""
20
+ dot = {
21
+ # "cpu": causal_dot_product_cpu,
22
+ "cuda": causal_dot_product_cuda
23
+ }
24
+ dot_backward = {
25
+ # "cpu": causal_dot_backward_cpu,
26
+ "cuda": causal_dot_backward_cuda
27
+ }
28
+
29
+ @staticmethod
30
+ def forward(ctx, Q, K, V):
31
+ # Save the inputs for the gradient computation
32
+ ctx.save_for_backward(Q, K, V)
33
+
34
+ # Create the output tensor
35
+ device = Q.device
36
+ N, H, L, _ = Q.shape
37
+ _, _, _, M = V.shape
38
+ product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
39
+
40
+ # Actually perform the dot product
41
+ CausalDotProduct.dot[device.type](
42
+ Q.data,
43
+ K.data,
44
+ V.data,
45
+ product
46
+ )
47
+ # breakpoint()
48
+ # CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
49
+
50
+ return product
51
+
52
+ @staticmethod
53
+ def backward(ctx, grad_out):
54
+ # Extract the saved tensors
55
+ Q, K, V = ctx.saved_tensors
56
+
57
+ # Allocate memory for the gradients
58
+ grad_Q = torch.zeros_like(Q)
59
+ grad_K = torch.zeros_like(K)
60
+ grad_V = torch.zeros_like(V)
61
+
62
+ # Actually compute the gradients
63
+ CausalDotProduct.dot_backward[Q.device.type](
64
+ Q.data,
65
+ K.data,
66
+ V.data,
67
+ grad_out,
68
+ grad_Q,
69
+ grad_K,
70
+ grad_V
71
+ )
72
+
73
+ return grad_Q, grad_K, grad_V
74
+
75
+
76
+ # Alias the autograd functions to python style snake case naming
77
+ causal_dot_product = CausalDotProduct.apply
csrc/causal_attention_cuda.cu ADDED
@@ -0,0 +1,1483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //
2
+ // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ // Written by Angelos Katharopoulos <[email protected]>,
4
+ // Apoorv Vyas <[email protected]>
5
+ //
6
+
7
+ //
8
+ // For modifications made inside namespace nvidia (authored by jdemouth):
9
+ //
10
+ // Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
11
+ //
12
+ // Permission is hereby granted, free of charge, to any person obtaining a copy of
13
+ // this software and associated documentation files (the "Software"), to deal in
14
+ // the Software without restriction, including without limitation the rights to
15
+ // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
16
+ // the Software, and to permit persons to whom the Software is furnished to do so,
17
+ // subject to the following conditions:
18
+ //
19
+ // The above copyright notice and this permission notice shall be included in all
20
+ // copies or substantial portions of the Software.
21
+ //
22
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
24
+ // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
25
+ // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
26
+ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
27
+ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28
+ //
29
+
30
+ #include <torch/extension.h>
31
+ #include <assert.h>
32
+ #include <stdio.h>
33
+
34
+ #define ENABLE_NVIDIA_OPTIMIZATIONS
35
+
36
+ #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
37
+ namespace nvidia {
38
+
39
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ constexpr int THREADS_PER_WARP = 32;
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ static inline __device__ __host__ int div_up(int m, int n) {
50
+ return (m + n-1) / n;
51
+ }
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ static inline __device__ __host__ int round_up(int m, int n) {
56
+ return div_up(m, n) * n;
57
+ }
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ template< typename T >
62
+ struct Lmha_params {
63
+
64
+ // The output buffer. Dimensions [B, H, L, M].
65
+ T *out;
66
+
67
+ // The input Qs. Dimensions [B, H, L, E].
68
+ const T *q;
69
+ // The input Ks. Dimensions [B, H, L, E].
70
+ const T *k;
71
+ // The input Vs. Dimensions [B, H, L, M].
72
+ const T *v;
73
+
74
+ // The different dimensions.
75
+ int B, L, H, E, M;
76
+
77
+ // The strides for the different tensors.
78
+ int q_stride_B, q_stride_H, q_stride_L;
79
+ int k_stride_B, k_stride_H, k_stride_L;
80
+ int v_stride_B, v_stride_H, v_stride_L;
81
+ int o_stride_B, o_stride_H, o_stride_L;
82
+ };
83
+
84
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
85
+
86
+ template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
87
+ __global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
88
+ void lmha_low_occupancy_kernel(Lmha_params<float> params) {
89
+
90
+ // The number of threads per block.
91
+ constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
92
+ // The number of rows per thread.
93
+ constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
94
+ // The number of steps per iteration.
95
+ constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;
96
+
97
+ // Make sure E is a multiple of the warp size.
98
+ static_assert(E % THREADS_PER_WARP == 0, "");
99
+
100
+ // Shared memory to store V/O.
101
+ __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
102
+ // Shared memory buffer to performance the reductions.
103
+ __shared__ float smem_reds[E * WARPS];
104
+
105
+ // The sequence processed by that block.
106
+ const int bi = blockIdx.z;
107
+ // The head processed by that block.
108
+ const int hi = blockIdx.y;
109
+ // The hidden cell in the V/output buffers.
110
+ const int vi = blockIdx.x;
111
+
112
+ // The linear index of the thread.
113
+ const int tidx = threadIdx.x;
114
+
115
+ // Decompose the block in warp/lane.
116
+ const int warp = tidx / THREADS_PER_WARP;
117
+ const int lane = tidx % THREADS_PER_WARP;
118
+
119
+ // The base offset loaded by the thread in Q and K.
120
+ int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
121
+ int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
122
+
123
+ // If we walk backward, account for the extra offset.
124
+ if( GO_BACKWARD ) {
125
+ offset_q += (params.L-1)*params.q_stride_L;
126
+ offset_k += (params.L-1)*params.k_stride_L;
127
+ }
128
+
129
+ // Position the warp at the beginning of the proper timestep.
130
+ if( GO_BACKWARD ) {
131
+ offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
132
+ offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
133
+ } else {
134
+ offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
135
+ offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
136
+ }
137
+
138
+ // Determine the base pointers for Q and K.
139
+ const float *ptr_q = &params.q[offset_q];
140
+ const float *ptr_k = &params.k[offset_k];
141
+
142
+ // Is a given row valid?
143
+ int valid_qk[ROWS_PER_THREAD];
144
+ #pragma unroll
145
+ for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
146
+ valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
147
+ }
148
+
149
+ // The offset to the position loaded by the thread in V.
150
+ int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
151
+ int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
152
+
153
+ // If we walk backward, account for the extra offset.
154
+ if( GO_BACKWARD ) {
155
+ offset_v += (params.L-1)*params.v_stride_L;
156
+ offset_o += (params.L-1)*params.o_stride_L;
157
+ }
158
+
159
+ // We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
160
+ if( GO_BACKWARD ) {
161
+ offset_v -= tidx*params.v_stride_L;
162
+ offset_o -= tidx*params.o_stride_L;
163
+ } else {
164
+ offset_v += tidx*params.v_stride_L;
165
+ offset_o += tidx*params.o_stride_L;
166
+ }
167
+
168
+ // Determine the base pointer for V.
169
+ const float *ptr_v = &params.v[offset_v];
170
+ // The output pointer.
171
+ float *ptr_o = &params.out[offset_o];
172
+
173
+ // The running KVs.
174
+ float running_kv[ROWS_PER_THREAD];
175
+ #pragma unroll
176
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
177
+ running_kv[ri] = 0.f;
178
+ }
179
+
180
+ // Iterate over the timesteps. TODO: Use params.loop_count!!!
181
+ for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
182
+
183
+ // Each thread loads a matrix of elements.
184
+ float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
185
+
186
+ // Trigger the memory loads for Q and K.
187
+ #pragma unroll
188
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
189
+ #pragma unroll
190
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
191
+
192
+ // For Q/K, each warp loads from various timesteps.
193
+ int ti = iter + warp*COLS_PER_THREAD;
194
+ if( GO_BACKWARD ) {
195
+ ti = params.L - 1 - ti;
196
+ }
197
+
198
+ // Is it a valid access?
199
+ int valid;
200
+ if( GO_BACKWARD ) {
201
+ valid = valid_qk[ri] && ti - ci >= 0;
202
+ } else {
203
+ valid = valid_qk[ri] && ti + ci < params.L;
204
+ }
205
+
206
+ // The extra offset to add.
207
+ if( GO_BACKWARD ) {
208
+ offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
209
+ offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
210
+ } else {
211
+ offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
212
+ offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
213
+ }
214
+
215
+ // Load Q/K if they are valid.
216
+ q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
217
+ k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
218
+ }
219
+ }
220
+
221
+ // For the V tensor, we assign contiguous thread to different loads. So, ti is different.
222
+ int ti = iter + tidx;
223
+ if( GO_BACKWARD ) {
224
+ ti = params.L - 1 - ti;
225
+ }
226
+
227
+ // Is it a valid access?
228
+ int valid_vo = tidx < COLS_PER_ITER;
229
+ if( GO_BACKWARD ) {
230
+ valid_vo &= ti >= 0;
231
+ } else {
232
+ valid_vo &= ti < params.L;
233
+ }
234
+
235
+ // Trigger the loads for V.
236
+ float ldg_v = valid_vo ? *ptr_v : 0.f;
237
+
238
+ // Move the load pointers.
239
+ if( GO_BACKWARD ) {
240
+ ptr_q -= COLS_PER_ITER*params.q_stride_L;
241
+ ptr_k -= COLS_PER_ITER*params.k_stride_L;
242
+ ptr_v -= COLS_PER_ITER*params.v_stride_L;
243
+ } else {
244
+ ptr_q += COLS_PER_ITER*params.q_stride_L;
245
+ ptr_k += COLS_PER_ITER*params.k_stride_L;
246
+ ptr_v += COLS_PER_ITER*params.v_stride_L;
247
+ }
248
+
249
+ // Store to shared memory.
250
+ if( tidx < COLS_PER_ITER ) {
251
+ smem_v[tidx] = ldg_v;
252
+ }
253
+
254
+ // Make sure V is in shared memory.
255
+ __syncthreads();
256
+
257
+ // Read V from shared memory.
258
+ float v[COLS_PER_THREAD];
259
+ #pragma unroll
260
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
261
+ v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
262
+ }
263
+
264
+ // Each thread computes local K*V products.
265
+ float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
266
+ #pragma unroll
267
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
268
+ #pragma unroll
269
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
270
+ kv[ri][ci] = 0.f;
271
+ }
272
+ }
273
+
274
+ // Update the K*V^T product.
275
+ #pragma unroll
276
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
277
+ #pragma unroll
278
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
279
+ kv[ri][ci] += k[ri][ci] * v[ci];
280
+ }
281
+ }
282
+
283
+ // We must perform the prefix sums within the thread-block. Start with the thread.
284
+ #pragma unroll
285
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
286
+ #pragma unroll
287
+ for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
288
+ kv[ri][ci] += kv[ri][ci-1];
289
+ }
290
+ }
291
+
292
+ // Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
293
+ #pragma unroll
294
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
295
+ smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
296
+ }
297
+
298
+ // Make sure the data is in shared memory.
299
+ __syncthreads();
300
+
301
+ // Each thread deals with one or more column(s) of the matrix.
302
+ constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
303
+ #pragma unroll
304
+ for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
305
+ if( idx < E ) {
306
+ float sum = smem_reds[idx];
307
+ #pragma unroll
308
+ for( int jj = 1; jj < WARPS; ++jj ) {
309
+ smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
310
+ }
311
+ }
312
+ }
313
+
314
+ // Make sure the reductions are stored in shared memory.
315
+ __syncthreads();
316
+
317
+ // Each thread updates his partial products.
318
+ #pragma unroll
319
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
320
+ float sum = running_kv[ri];
321
+ if( warp > 0 ) {
322
+ sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
323
+ }
324
+ #pragma unroll
325
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
326
+ kv[ri][ci] += sum;
327
+ }
328
+ }
329
+
330
+ // Compute the partial output values for that thread.
331
+ float sum[COLS_PER_THREAD];
332
+ #pragma unroll
333
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
334
+ sum[ci] = q[0][ci] * kv[0][ci];
335
+ #pragma unroll
336
+ for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
337
+ sum[ci] += q[ri][ci] * kv[ri][ci];
338
+ }
339
+ }
340
+
341
+ // Run the parallel reductions inside the warp.
342
+ #pragma unroll
343
+ for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
344
+ #pragma unroll
345
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
346
+ sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
347
+ }
348
+ }
349
+
350
+ // Store the final output to shared memory.
351
+ if( lane == 0 ) {
352
+ #pragma unroll
353
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
354
+ smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
355
+ }
356
+ }
357
+
358
+ // Make sure the data is in shared memory.
359
+ __syncthreads();
360
+
361
+ // Store the output.
362
+ if( valid_vo ) {
363
+ *ptr_o = smem_o[tidx];
364
+ }
365
+
366
+ // Each thread updates his running kv.
367
+ #pragma unroll
368
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
369
+ running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
370
+ }
371
+
372
+ // Move to next location.
373
+ if( GO_BACKWARD ) {
374
+ ptr_o -= COLS_PER_ITER*params.o_stride_L;
375
+ } else {
376
+ ptr_o += COLS_PER_ITER*params.o_stride_L;
377
+ }
378
+ }
379
+ }
380
+
381
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
382
+
383
+ template< int E, bool GO_BACKWARD, int WARPS >
384
+ int lmha_low_occupancy_(const Lmha_params<float> &params) {
385
+
386
+ // Make sure we are not going to launch an invalid grid.
387
+ if( params.H > 65535 || params.B > 65535 ) {
388
+ return 1;
389
+ }
390
+
391
+ // Prepare the grid and trigger the CUDA kernel.
392
+ dim3 grid;
393
+ grid.x = params.M;
394
+ grid.y = params.H;
395
+ grid.z = params.B;
396
+ lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
397
+ return 0;
398
+ }
399
+
400
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
401
+
402
+ template< int E, bool GO_BACKWARD >
403
+ int lmha_low_occupancy_(const Lmha_params<float> &params, int blocks) {
404
+ if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
405
+ return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
406
+ } else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
407
+ return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
408
+ } else {
409
+ return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
410
+ }
411
+ return 1;
412
+ }
413
+
414
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
415
+
416
+ template< int E, typename Params >
417
+ static inline __device__ __host__ int smem_buffer_elts_(const Params &params) {
418
+ int M = round_up(params.M, 4);
419
+ return 2*E + 2*M;
420
+ }
421
+
422
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
423
+
424
+ template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
425
+ __global__
426
+ void lmha_kernel(Lmha_params<float> params) {
427
+
428
+ // Make sure E is a multiple of 4.
429
+ static_assert(E % 4 == 0, "");
430
+
431
+ // The amount of shared memory per buffer (2 buffers for double-buffering).
432
+ const int smem_buffer_elts = smem_buffer_elts_<E>(params);
433
+ // The M dimension for shared memory.
434
+ const int M = round_up(params.M, 4);
435
+
436
+ // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
437
+ extern __shared__ float smem_[];
438
+
439
+ // The various shared memory buffers.
440
+ float *smem_q = &smem_[0*E];
441
+ float *smem_k = &smem_[1*E];
442
+ float *smem_v = &smem_[2*E];
443
+ float *smem_o = &smem_[2*E + M];
444
+
445
+ // The index of the shared memory buffer (for double-buffering).
446
+ int smem_curr = 0;
447
+
448
+ // The sequence processed by that block.
449
+ const int bi = blockIdx.y;
450
+ // The head processed by that block.
451
+ const int hi = blockIdx.x;
452
+
453
+ // The linear index of the thread.
454
+ const int tidx = threadIdx.x;
455
+
456
+ // The offset to the position loaded by the thread in Q.
457
+ int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
458
+ // The offset to the position loaded by the thread in K.
459
+ int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
460
+
461
+ // If we walk backward, account for the extra offset.
462
+ if( GO_BACKWARD ) {
463
+ offset_q += (params.L-1)*params.q_stride_L;
464
+ offset_k += (params.L-1)*params.k_stride_L;
465
+ }
466
+
467
+ // Determine the base pointers for Q and K.
468
+ const float *ptr_q = &params.q[offset_q];
469
+ const float *ptr_k = &params.k[offset_k];
470
+
471
+ // The offset to the position loaded by the thread in V and O.
472
+ int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
473
+ int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
474
+
475
+ // If we walk backward, account for the extra offset.
476
+ if( GO_BACKWARD ) {
477
+ offset_v += (params.L-1)*params.v_stride_L;
478
+ offset_o += (params.L-1)*params.o_stride_L;
479
+ }
480
+
481
+ // Determine the base pointers for V.
482
+ const float *ptr_v = &params.v[offset_v];
483
+
484
+ // Is it an active Q/K thread?
485
+ const int active_qk = tidx < params.E;
486
+
487
+ // Trigger the memory loads for Q and K.
488
+ float ldg_q = 0.f, ldg_k = 0.f;
489
+ if( active_qk ) {
490
+ ldg_q = *ptr_q;
491
+ ldg_k = *ptr_k;
492
+ }
493
+
494
+ // Is it an active V thread?
495
+ const int active_v = tidx < params.M;
496
+
497
+ // Trigger the memory loads for V.
498
+ float ldg_v = 0.f;
499
+ if( active_v ) {
500
+ ldg_v = *ptr_v;
501
+ }
502
+
503
+ // Move the load pointers.
504
+ if( GO_BACKWARD ) {
505
+ ptr_q -= params.q_stride_L;
506
+ ptr_k -= params.k_stride_L;
507
+ ptr_v -= params.v_stride_L;
508
+ } else {
509
+ ptr_q += params.q_stride_L;
510
+ ptr_k += params.k_stride_L;
511
+ ptr_v += params.v_stride_L;
512
+ }
513
+
514
+ // The number of FLOAT4s per head.
515
+ constexpr int FLOAT4s_PER_HEAD = E / 4;
516
+ // The number of FLOAT4s per thread.
517
+ constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
518
+
519
+ // The storage for the K*V^T values.
520
+ float4 kv[FLOAT4s_PER_THREAD];
521
+ #pragma unroll
522
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
523
+ kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
524
+ }
525
+
526
+ // The output pointer.
527
+ float *out_ptr = &params.out[offset_o];
528
+
529
+ // Store to shared memory Q and K.
530
+ if( tidx < E ) {
531
+ smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
532
+ smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
533
+ }
534
+
535
+ // Store to shared memory V. All threads store valid values.
536
+ if( tidx < M ) {
537
+ smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
538
+ }
539
+
540
+ // The position of the thread in the V dimension.
541
+ int vo = tidx / THREADS_PER_HEAD;
542
+ int vi = tidx % THREADS_PER_HEAD;
543
+
544
+ // Iterate over the timesteps.
545
+ for( int ti = 0; ti < params.L; ++ti ) {
546
+
547
+ // Is it the last iteration?
548
+ int is_last = ti == params.L - 1;
549
+
550
+ // Trigger the next loads for Q and K.
551
+ if( !is_last && active_qk ) {
552
+ ldg_q = *ptr_q;
553
+ ldg_k = *ptr_k;
554
+ }
555
+
556
+ // Trigger the next loads for V.
557
+ if( !is_last && active_v ) {
558
+ ldg_v = *ptr_v;
559
+ }
560
+
561
+ // Move the load pointers.
562
+ if( GO_BACKWARD ) {
563
+ ptr_q -= params.q_stride_L;
564
+ ptr_k -= params.k_stride_L;
565
+ ptr_v -= params.v_stride_L;
566
+ } else {
567
+ ptr_q += params.q_stride_L;
568
+ ptr_k += params.k_stride_L;
569
+ ptr_v += params.v_stride_L;
570
+ }
571
+
572
+ // Make sure the data is in shared memory.
573
+ __syncthreads();
574
+
575
+ // Each thread loads 4 values from K.
576
+ float4 k[FLOAT4s_PER_THREAD];
577
+ #pragma unroll
578
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
579
+ int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
580
+ k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
581
+ }
582
+
583
+ // Each thread loads a single V value.
584
+ float v = 0.f;
585
+ if( vo < params.M ) {
586
+ v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
587
+ }
588
+
589
+ // Update the K*V^T product.
590
+ #pragma unroll
591
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
592
+ kv[ii].x += k[ii].x * v;
593
+ kv[ii].y += k[ii].y * v;
594
+ kv[ii].z += k[ii].z * v;
595
+ kv[ii].w += k[ii].w * v;
596
+ }
597
+
598
+ // Load the Q values from shared memory.
599
+ float4 q[FLOAT4s_PER_THREAD];
600
+ #pragma unroll
601
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
602
+ int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
603
+ q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
604
+ }
605
+
606
+ // Compute the partial output value for that thread.
607
+ float sum = 0.f;
608
+ #pragma unroll
609
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
610
+ sum += q[ii].x * kv[ii].x;
611
+ sum += q[ii].y * kv[ii].y;
612
+ sum += q[ii].z * kv[ii].z;
613
+ sum += q[ii].w * kv[ii].w;
614
+ }
615
+
616
+ // Finalize the computation of the sum (if we have more than 1 thread per head).
617
+ if( THREADS_PER_HEAD > 1 ) {
618
+
619
+ // Finalize the sum for each head.
620
+ #pragma unroll
621
+ for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
622
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
623
+ }
624
+
625
+ // Store to shared memory.
626
+ if( vo < M && vi == 0 ) {
627
+ smem_o[smem_curr*smem_buffer_elts + vo] = sum;
628
+ }
629
+
630
+ // Make sure the data is in shared memory.
631
+ __syncthreads();
632
+
633
+ // Active threads read the data to store.
634
+ if( active_v ) {
635
+ sum = smem_o[smem_curr*smem_buffer_elts + tidx];
636
+ }
637
+
638
+ } // THREADS_PER_HEAD > 1.
639
+
640
+ // Store the output. All the threads are active.
641
+ if( active_v ) {
642
+ *out_ptr = sum;
643
+ }
644
+
645
+ // Move to next location.
646
+ if( GO_BACKWARD ) {
647
+ out_ptr -= params.o_stride_L;
648
+ } else {
649
+ out_ptr += params.o_stride_L;
650
+ }
651
+
652
+ // Move the shared memory buffer.
653
+ smem_curr = (smem_curr + 1) % 2;
654
+
655
+ // Store to shared memory for Q and K.
656
+ if( !is_last && tidx < E ) {
657
+ smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
658
+ smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
659
+ }
660
+
661
+ // Store to shared memory for V.
662
+ if( !is_last && tidx < M ) {
663
+ smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
664
+ }
665
+ }
666
+ }
667
+
668
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
669
+
670
+ template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
671
+ int lmha_(const Lmha_params<float> &params) {
672
+ // The M dimension rounded up to 4.
673
+ int M = round_up(params.M, 4);
674
+
675
+ // The number of threads in the block.
676
+ int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
677
+ if( block > 512 || params.B > 65535 ) {
678
+ return 1;
679
+ }
680
+
681
+ // Prepare the kernel.
682
+ dim3 grid(params.H, params.B);
683
+ size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
684
+ lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
685
+ return 0;
686
+ }
687
+
688
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
689
+
690
+ template< bool GO_BACKWARD >
691
+ int lmha(const Lmha_params<float> &params) {
692
+ int blocks = params.B * params.H;
693
+ int res = 1;
694
+ if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
695
+ if( params.E <= 32 ) {
696
+ res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
697
+ } else if( params.E <= 64 ) {
698
+ res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
699
+ } else if( params.E <= 128 ) {
700
+ res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
701
+ } else if( params.E <= 256 ) {
702
+ res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
703
+ }
704
+ } else {
705
+ if( params.E <= 32 ) {
706
+ res = lmha_< 32, 1, GO_BACKWARD>(params);
707
+ } else if( params.E <= 48 ) {
708
+ res = lmha_< 48, 1, GO_BACKWARD>(params);
709
+ } else if( params.E <= 64 ) {
710
+ res = lmha_< 64, 1, GO_BACKWARD>(params);
711
+ } else if( params.E <= 128 ) {
712
+ res = lmha_<128, 2, GO_BACKWARD>(params);
713
+ } else if( params.E <= 256 ) {
714
+ res = lmha_<256, 4, GO_BACKWARD>(params);
715
+ }
716
+ }
717
+ return res;
718
+ }
719
+
720
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
721
+
722
+ template< typename T >
723
+ inline void set_params(Lmha_params<T> &params,
724
+ const torch::Tensor q,
725
+ const torch::Tensor k,
726
+ const torch::Tensor v,
727
+ torch::Tensor o) {
728
+
729
+ // Define the pointers.
730
+ params.out = o.data_ptr<T>();
731
+ params.q = q.data_ptr<T>();
732
+ params.k = k.data_ptr<T>();
733
+ params.v = v.data_ptr<T>();
734
+
735
+ // Define the strides.
736
+ params.q_stride_B = (int) q.stride(0);
737
+ params.q_stride_H = (int) q.stride(1);
738
+ params.q_stride_L = (int) q.stride(2);
739
+ params.k_stride_B = (int) k.stride(0);
740
+ params.k_stride_H = (int) k.stride(1);
741
+ params.k_stride_L = (int) k.stride(2);
742
+ params.v_stride_B = (int) v.stride(0);
743
+ params.v_stride_H = (int) v.stride(1);
744
+ params.v_stride_L = (int) v.stride(2);
745
+ params.o_stride_B = (int) o.stride(0);
746
+ params.o_stride_H = (int) o.stride(1);
747
+ params.o_stride_L = (int) o.stride(2);
748
+
749
+ // Extract the dimensions.
750
+ int N = q.size(0);
751
+ int H = q.size(1);
752
+ int L = q.size(2);
753
+ int E = q.size(3);
754
+ int M = v.size(3);
755
+
756
+ params.B = N;
757
+ params.L = L;
758
+ params.H = H;
759
+ params.E = E;
760
+ params.M = M;
761
+ }
762
+
763
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
764
+
765
+ int lmha_fwd(const torch::Tensor queries,
766
+ const torch::Tensor keys,
767
+ const torch::Tensor values,
768
+ torch::Tensor product) {
769
+
770
+ // Make sure that we are using the correct GPU device
771
+ torch::DeviceGuard _guard(queries.device());
772
+
773
+ // Make sure the inner-most dimension of the tensors is packed.
774
+ assert(queries.stride(3) == 1);
775
+ assert(keys .stride(3) == 1);
776
+ assert(values .stride(3) == 1);
777
+ assert(product.stride(3) == 1);
778
+
779
+ // Extract the dimensions.
780
+ int N = queries.size(0);
781
+ int H = queries.size(1);
782
+ int L = queries.size(2);
783
+ int E = queries.size(3);
784
+ int M = values.size (3);
785
+
786
+ // The structure of params.
787
+ Lmha_params<float> params;
788
+ set_params(params, queries, keys, values, product);
789
+
790
+ // Launch the kernel.
791
+ return lmha<false>(params);
792
+ }
793
+
794
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
795
+
796
+ template< typename T >
797
+ struct Lmha_bwd_params {
798
+
799
+ // The output buffer for K. Dimensions [B, H, L, D].
800
+ T *out_k;
801
+ // The output buffer for V. Dimensions [B, H, L, D].
802
+ T *out_v;
803
+
804
+ // The input Qs. Dimensions [B, H, L, D].
805
+ const T *q;
806
+ // The input Ks. Dimensions [B, H, L, D].
807
+ const T *k;
808
+ // The input Vs. Dimensions [B, H, L, D].
809
+ const T *v;
810
+ // The input Gs. Dimensions [B, H, L, D].
811
+ const T *g;
812
+
813
+ // The dimensions.
814
+ int B, L, H, M, E;
815
+
816
+ // The strides for the input tensors.
817
+ int q_stride_B, q_stride_L, q_stride_H;
818
+ int k_stride_B, k_stride_L, k_stride_H;
819
+ int v_stride_B, v_stride_L, v_stride_H;
820
+ int g_stride_B, g_stride_L, g_stride_H;
821
+
822
+ // The strides for the outputs.
823
+ int out_k_stride_B, out_k_stride_L, out_k_stride_H;
824
+ int out_v_stride_B, out_v_stride_L, out_v_stride_H;
825
+ };
826
+
827
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
828
+
829
+ template< int D, int THREADS_PER_HEAD >
830
+ __global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
831
+ void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
832
+
833
+ // Make sure D is a multiple of 4.
834
+ static_assert(D % 4 == 0, "");
835
+
836
+ // The shared memory buffers.
837
+ __shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
838
+
839
+ // The index of the shared memory buffer (for double-buffering).
840
+ int smem_curr = 0;
841
+
842
+ // The sequence processed by that block.
843
+ const int bi = blockIdx.y;
844
+ // The head processed by that block.
845
+ const int hi = blockIdx.x;
846
+
847
+ // The linear index of the thread.
848
+ const int tidx = threadIdx.x;
849
+
850
+ // Split the threads into two slices.
851
+ int so = tidx / (D*THREADS_PER_HEAD);
852
+ int si = tidx % (D*THREADS_PER_HEAD);
853
+
854
+ // The strides for B/L/H for the Q/G tensors.
855
+ int qg_stride_B, qg_stride_L, qg_stride_H;
856
+ if( so == 0 ) {
857
+ qg_stride_B = params.q_stride_B;
858
+ qg_stride_L = params.q_stride_L;
859
+ qg_stride_H = params.q_stride_H;
860
+ } else {
861
+ qg_stride_B = params.g_stride_B;
862
+ qg_stride_L = params.g_stride_L;
863
+ qg_stride_H = params.g_stride_H;
864
+ }
865
+
866
+ // The strides for B/L/H for the K/V tensors.
867
+ int kv_stride_B, kv_stride_L, kv_stride_H;
868
+ if( so == 0 ) {
869
+ kv_stride_B = params.k_stride_B;
870
+ kv_stride_L = params.k_stride_L;
871
+ kv_stride_H = params.k_stride_H;
872
+ } else {
873
+ kv_stride_B = params.v_stride_B;
874
+ kv_stride_L = params.v_stride_L;
875
+ kv_stride_H = params.v_stride_H;
876
+ }
877
+
878
+ // The hidden size.
879
+ int hidden_size_per_head = 0;
880
+ if( so == 0 ) {
881
+ hidden_size_per_head = params.E;
882
+ } else {
883
+ hidden_size_per_head = params.M;
884
+ }
885
+
886
+ // Where to start reading from.
887
+ int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
888
+ int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
889
+
890
+ // We walk backward, account for the extra offset.
891
+ offset_qg += (params.L-1)*qg_stride_L;
892
+ offset_kv += (params.L-1)*kv_stride_L;
893
+
894
+ // Determine the base pointers for Q, K, V and G.
895
+ const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
896
+ const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
897
+
898
+ // Is it an active thread?
899
+ const int active = si < hidden_size_per_head;
900
+
901
+ // Trigger the memory loads for Q, K, V and G.
902
+ float ldg_qg = 0.f, ldg_kv = 0.f;
903
+ if( active ) {
904
+ ldg_qg = *ptr_qg;
905
+ ldg_kv = *ptr_kv;
906
+ }
907
+
908
+ // Move the load pointers (backward).
909
+ ptr_qg -= qg_stride_L;
910
+ ptr_kv -= kv_stride_L;
911
+
912
+ // The number of FLOAT4s per head.
913
+ constexpr int FLOAT4s_PER_HEAD = D / 4;
914
+ // The number of FLOAT4s per thread.
915
+ constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
916
+
917
+ // The storage for the G*Q^T or Q^T*G values.
918
+ float4 gq[FLOAT4s_PER_THREAD];
919
+ #pragma unroll
920
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
921
+ gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
922
+ }
923
+
924
+ // The strides for B/L/H for the K/V tensors.
925
+ int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
926
+ if( so == 0 ) {
927
+ out_kv_stride_B = params.out_k_stride_B;
928
+ out_kv_stride_L = params.out_k_stride_L;
929
+ out_kv_stride_H = params.out_k_stride_H;
930
+ } else {
931
+ out_kv_stride_B = params.out_v_stride_B;
932
+ out_kv_stride_L = params.out_v_stride_L;
933
+ out_kv_stride_H = params.out_v_stride_H;
934
+ }
935
+
936
+ // Where to start reading from.
937
+ int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
938
+
939
+ // We walk backward, account for the extra offset.
940
+ offset_out_kv += (params.L-1)*out_kv_stride_L;
941
+
942
+ // The output pointer.
943
+ float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
944
+
945
+ // Store to shared memory.
946
+ if( si < D ) {
947
+ smem_[smem_curr].qg[so*D + si] = ldg_qg;
948
+ smem_[smem_curr].kv[so*D + si] = ldg_kv;
949
+ }
950
+
951
+ // The position of the thread in the output dimension.
952
+ int oo = si / THREADS_PER_HEAD % D;
953
+ int oi = si % THREADS_PER_HEAD * 4;
954
+
955
+ // Iterate over the timesteps.
956
+ for( int ti = 0; ti < params.L; ++ti ) {
957
+
958
+ // Is it the last iteration?
959
+ int is_last = ti == params.L - 1;
960
+
961
+ // Trigger the next loads.
962
+ if( !is_last && active ) {
963
+ ldg_qg = *ptr_qg;
964
+ ldg_kv = *ptr_kv;
965
+ }
966
+
967
+ // Move the load pointers.
968
+ ptr_qg -= qg_stride_L;
969
+ ptr_kv -= kv_stride_L;
970
+
971
+ // Make sure the data is in shared memory.
972
+ __syncthreads();
973
+
974
+ // Each thread loads 4 values from G or Q.
975
+ float4 g[FLOAT4s_PER_THREAD];
976
+ #pragma unroll
977
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
978
+ float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
979
+ g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
980
+ }
981
+
982
+ // Each thread loads a single from Q or G value.
983
+ float q = smem_[smem_curr].qg[so*D + oo];
984
+
985
+ // Update the G*Q^T or Q*G^T product.
986
+ #pragma unroll
987
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
988
+ gq[ii].x += g[ii].x * q;
989
+ gq[ii].y += g[ii].y * q;
990
+ gq[ii].z += g[ii].z * q;
991
+ gq[ii].w += g[ii].w * q;
992
+ }
993
+
994
+ // Load the V or K values from shared memory.
995
+ float4 v[FLOAT4s_PER_THREAD];
996
+ #pragma unroll
997
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
998
+ float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
999
+ v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
1000
+ }
1001
+
1002
+ // Compute the partial output value for that thread.
1003
+ float sum = 0.f;
1004
+ #pragma unroll
1005
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
1006
+ sum += v[ii].x * gq[ii].x;
1007
+ sum += v[ii].y * gq[ii].y;
1008
+ sum += v[ii].z * gq[ii].z;
1009
+ sum += v[ii].w * gq[ii].w;
1010
+ }
1011
+
1012
+ // Finalize the computation of the sum (if we have more than 1 thread per head).
1013
+ if( THREADS_PER_HEAD > 1 ) {
1014
+
1015
+ // Finalize the sum for each head.
1016
+ #pragma unroll
1017
+ for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
1018
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
1019
+ }
1020
+
1021
+ // Store to shared memory.
1022
+ if( oi == 0 ) {
1023
+ smem_[smem_curr].out_kv[so*D + oo] = sum;
1024
+ }
1025
+
1026
+ // Make sure the data is in shared memory.
1027
+ __syncthreads();
1028
+
1029
+ // Active threads read the data to store.
1030
+ if( si < hidden_size_per_head ) {
1031
+ sum = smem_[smem_curr].out_kv[so*D + si];
1032
+ }
1033
+
1034
+ } // THREADS_PER_HEAD > 1.
1035
+
1036
+ // Store the output. All the threads are active.
1037
+ if( si < hidden_size_per_head ) {
1038
+ *ptr_out_kv = sum;
1039
+ }
1040
+
1041
+ // Move to next location.
1042
+ ptr_out_kv -= out_kv_stride_L;
1043
+
1044
+ // Move the shared memory buffer.
1045
+ smem_curr = (smem_curr + 1) % 2;
1046
+
1047
+ // Store to shared memory for Q and K.
1048
+ if( !is_last && si < D ) {
1049
+ smem_[smem_curr].qg[so*D + si] = ldg_qg;
1050
+ smem_[smem_curr].kv[so*D + si] = ldg_kv;
1051
+ }
1052
+ }
1053
+ }
1054
+
1055
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1056
+
1057
+ template< int D, int THREADS_PER_HEAD >
1058
+ int lmha_bwd_(const Lmha_bwd_params<float> &params) {
1059
+ int block = D*THREADS_PER_HEAD*2;
1060
+ if( block >= 1024 || params.B > 65535 ) {
1061
+ return 1;
1062
+ }
1063
+ dim3 grid(params.H, params.B);
1064
+ lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
1065
+ return 0;
1066
+ }
1067
+
1068
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1069
+
1070
+ int lmha_bwd(const Lmha_bwd_params<float> &params) {
1071
+ int blocks = params.B * params.H;
1072
+ if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
1073
+ return 1;
1074
+ }
1075
+
1076
+ int hidden_size_per_head = max(params.E, params.M);
1077
+ int res = 1;
1078
+ if( hidden_size_per_head <= 32 ) {
1079
+ res = lmha_bwd_< 32, 1>(params);
1080
+ } else if( hidden_size_per_head <= 64 ) {
1081
+ res = lmha_bwd_< 64, 1>(params);
1082
+ } else if( hidden_size_per_head <= 128 ) {
1083
+ res = lmha_bwd_<128, 2>(params);
1084
+ } else if( hidden_size_per_head <= 256 ) {
1085
+ res = lmha_bwd_<256, 4>(params);
1086
+ }
1087
+ return res;
1088
+ }
1089
+
1090
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1091
+
1092
+ int lmha_bwd(const torch::Tensor queries,
1093
+ const torch::Tensor keys,
1094
+ const torch::Tensor values,
1095
+ const torch::Tensor grad_out,
1096
+ torch::Tensor grad_queries,
1097
+ torch::Tensor grad_keys,
1098
+ torch::Tensor grad_values) {
1099
+
1100
+ // Make sure that we are using the correct GPU device
1101
+ torch::DeviceGuard _guard(queries.device());
1102
+
1103
+ // Make sure the inner-most dimension of the tensors is packed.
1104
+ assert(queries .stride(3) == 1);
1105
+ assert(keys .stride(3) == 1);
1106
+ assert(values .stride(3) == 1);
1107
+ assert(grad_out .stride(3) == 1);
1108
+ assert(grad_queries.stride(3) == 1);
1109
+ assert(grad_keys .stride(3) == 1);
1110
+ assert(grad_values .stride(3) == 1);
1111
+
1112
+ // Extract the dimensions.
1113
+ int N = queries.size(0);
1114
+ int H = queries.size(1);
1115
+ int L = queries.size(2);
1116
+ int E = queries.size(3);
1117
+ int M = values.size (3);
1118
+
1119
+ // Gradient on Q.
1120
+
1121
+ // The structure of params.
1122
+ Lmha_params<float> params;
1123
+ set_params(params, grad_out, values, keys, grad_queries);
1124
+
1125
+ // Launch the kernel.
1126
+ int res = lmha<false>(params);
1127
+ if( res ) {
1128
+ return res;
1129
+ }
1130
+
1131
+ // Gradient on K and V together.
1132
+
1133
+ Lmha_bwd_params<float> bwd_params;
1134
+ bwd_params.out_k = grad_keys.data_ptr<float>();
1135
+ bwd_params.out_v = grad_values.data_ptr<float>();
1136
+ bwd_params.q = queries.data_ptr<float>();
1137
+ bwd_params.k = keys.data_ptr<float>();
1138
+ bwd_params.v = values.data_ptr<float>();
1139
+ bwd_params.g = grad_out.data_ptr<float>();
1140
+
1141
+ bwd_params.B = N;
1142
+ bwd_params.L = L;
1143
+ bwd_params.H = H;
1144
+ bwd_params.E = E;
1145
+ bwd_params.M = M;
1146
+
1147
+ bwd_params.q_stride_B = queries.stride(0);
1148
+ bwd_params.q_stride_H = queries.stride(1);
1149
+ bwd_params.q_stride_L = queries.stride(2);
1150
+ bwd_params.k_stride_B = keys.stride(0);
1151
+ bwd_params.k_stride_H = keys.stride(1);
1152
+ bwd_params.k_stride_L = keys.stride(2);
1153
+ bwd_params.v_stride_B = values.stride(0);
1154
+ bwd_params.v_stride_H = values.stride(1);
1155
+ bwd_params.v_stride_L = values.stride(2);
1156
+ bwd_params.g_stride_B = grad_out.stride(0);
1157
+ bwd_params.g_stride_H = grad_out.stride(1);
1158
+ bwd_params.g_stride_L = grad_out.stride(2);
1159
+
1160
+ bwd_params.out_k_stride_B = grad_keys.stride(0);
1161
+ bwd_params.out_k_stride_H = grad_keys.stride(1);
1162
+ bwd_params.out_k_stride_L = grad_keys.stride(2);
1163
+ bwd_params.out_v_stride_B = grad_values.stride(0);
1164
+ bwd_params.out_v_stride_H = grad_values.stride(1);
1165
+ bwd_params.out_v_stride_L = grad_values.stride(2);
1166
+
1167
+ // Try to run the fused kernel.
1168
+ int fallback = lmha_bwd(bwd_params);
1169
+
1170
+ // If it failed, fallback on separate kernels for K and V.
1171
+ if( fallback ) {
1172
+
1173
+ // Gradient on K.
1174
+
1175
+ // Launch the kernel.
1176
+ set_params(params, values, grad_out, queries, grad_keys);
1177
+ res = lmha<true>(params);
1178
+ if( res ) {
1179
+ return res;
1180
+ }
1181
+
1182
+ // Gradient on V.
1183
+
1184
+ // Launch the kernel.
1185
+ set_params(params, keys, queries, grad_out, grad_values);
1186
+ return lmha<true>(params);
1187
+ }
1188
+
1189
+ // It worked...
1190
+ return 0;
1191
+ }
1192
+
1193
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1194
+
1195
+ } // namespace nvidia
1196
+ #endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
1197
+
1198
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1199
+
1200
+ typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
1201
+
1202
+ #define E_BLOCK_SIZE 8
1203
+
1204
+ __global__ void causal_dot_product_kernel(
1205
+ const float_accessor queries,
1206
+ const float_accessor keys,
1207
+ const float_accessor values,
1208
+ float_accessor result,
1209
+ const int N,
1210
+ const int H,
1211
+ const int L,
1212
+ const int E,
1213
+ const int M
1214
+ ) {
1215
+ int n = blockIdx.y;
1216
+ int h = blockIdx.z;
1217
+
1218
+ int e_start = blockIdx.x * E_BLOCK_SIZE;
1219
+ int m = threadIdx.x % M;
1220
+
1221
+ extern __shared__ float shared_mem[];
1222
+ float* shared_kv = shared_mem;
1223
+
1224
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1225
+ shared_kv[m + e_local * M] = 0;
1226
+ }
1227
+
1228
+ for (int t=0; t<L; t++) {
1229
+ float res = 0;
1230
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1231
+ shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
1232
+ res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
1233
+ }
1234
+ atomicAdd(
1235
+ &result[n][h][t][m],
1236
+ res
1237
+ );
1238
+ }
1239
+ }
1240
+
1241
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1242
+
1243
+ void causal_dot_product_(const torch::Tensor queries,
1244
+ const torch::Tensor keys,
1245
+ const torch::Tensor values,
1246
+ torch::Tensor product) {
1247
+ // Make sure that we are using the correct GPU device
1248
+ torch::DeviceGuard _guard(queries.device());
1249
+
1250
+ int N = queries.size(0);
1251
+ int H = queries.size(1);
1252
+ int L = queries.size(2);
1253
+ int E = queries.size(3);
1254
+ int M = values.size(3);
1255
+
1256
+ const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
1257
+
1258
+ dim3 blockDim(M, 1, 1);
1259
+ dim3 gridDim(blocks_per_sequence, N, H);
1260
+ const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
1261
+
1262
+ causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
1263
+ queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1264
+ keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1265
+ values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1266
+ product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1267
+ N, H, L, E, M
1268
+ );
1269
+ }
1270
+
1271
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1272
+
1273
+ void causal_dot_product(const torch::Tensor queries,
1274
+ const torch::Tensor keys,
1275
+ const torch::Tensor values,
1276
+ torch::Tensor product) {
1277
+ #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
1278
+ int fallback = nvidia::lmha_fwd(queries, keys, values, product);
1279
+ #else
1280
+ int fallback = 1;
1281
+ #endif
1282
+ if( fallback ) {
1283
+ causal_dot_product_(queries, keys, values, product);
1284
+ }
1285
+ }
1286
+
1287
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1288
+
1289
+ #define M_BLOCK_SIZE 4
1290
+
1291
+ // we need shared memory to store
1292
+ // kv
1293
+ // Backward direction
1294
+ // kv_backwards
1295
+ // Shared memory usage
1296
+ __global__ void causal_dot_backward_query_key_kernel(
1297
+ const float_accessor queries,
1298
+ const float_accessor keys,
1299
+ const float_accessor values,
1300
+ const float_accessor grad_out,
1301
+ float_accessor grad_queries,
1302
+ float_accessor grad_keys,
1303
+ int N,
1304
+ int H,
1305
+ int L,
1306
+ int E,
1307
+ int M
1308
+ ) {
1309
+ int n = blockIdx.y;
1310
+ int h = blockIdx.z;
1311
+
1312
+ int m_start = blockIdx.x * M_BLOCK_SIZE;
1313
+ int e = threadIdx.x % E;
1314
+
1315
+ extern __shared__ float shared_mem[];
1316
+ const int shared_kv_size = M_BLOCK_SIZE * E;
1317
+ float* shared_kv = shared_mem;
1318
+ float* shared_kv_bw = shared_mem + shared_kv_size;
1319
+
1320
+ for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
1321
+ shared_kv[m_local * E + e] = 0;
1322
+ shared_kv_bw[m_local * E + e] = 0;
1323
+ }
1324
+
1325
+ for (int l=0; l<L; l++) {
1326
+ float res = 0, res_bw = 0;
1327
+ int l_b = L - l - 1;
1328
+ for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
1329
+ shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m_start + m_local];
1330
+ shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m_start + m_local];
1331
+ res += grad_out[n][h][l][m_start + m_local] * shared_kv[m_local*E + e];
1332
+ res_bw += values[n][h][l_b][m_start + m_local] * shared_kv_bw[m_local*E + e];
1333
+ }
1334
+ atomicAdd(
1335
+ &grad_queries[n][h][l][e],
1336
+ res
1337
+ );
1338
+ atomicAdd(
1339
+ &grad_keys[n][h][l_b][e],
1340
+ res_bw
1341
+ );
1342
+ }
1343
+ }
1344
+
1345
+
1346
+ __global__ void causal_dot_backward_value_kernel(
1347
+ const float_accessor queries,
1348
+ const float_accessor keys,
1349
+ const float_accessor values,
1350
+ const float_accessor grad_out,
1351
+ float_accessor grad_keys,
1352
+ float_accessor grad_values,
1353
+ int N,
1354
+ int H,
1355
+ int L,
1356
+ int E,
1357
+ int M
1358
+ ) {
1359
+ int n = blockIdx.y;
1360
+ int h = blockIdx.z;
1361
+
1362
+ int e_start = blockIdx.x * E_BLOCK_SIZE;
1363
+ int m = threadIdx.x % M;
1364
+
1365
+ extern __shared__ float shared_mem[];
1366
+ float* shared_kv = shared_mem;
1367
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1368
+ shared_kv[m + e_local * M] = 0;
1369
+ }
1370
+
1371
+ for (int l = 0; l < L; l++) {
1372
+ int l_b = L - l -1;
1373
+ float res = 0;
1374
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1375
+ shared_kv[e_local*M + m] += queries[n][h][l_b][e_start + e_local] * grad_out[n][h][l_b][m];
1376
+ res += keys[n][h][l_b][e_start + e_local] * shared_kv[e_local*M + m];
1377
+ }
1378
+ atomicAdd(
1379
+ &grad_values[n][h][l_b][m],
1380
+ res
1381
+ );
1382
+ }
1383
+ }
1384
+
1385
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1386
+
1387
+ void causal_dot_backward_(const torch::Tensor queries,
1388
+ const torch::Tensor keys,
1389
+ const torch::Tensor values,
1390
+ const torch::Tensor grad_out,
1391
+ torch::Tensor grad_queries,
1392
+ torch::Tensor grad_keys,
1393
+ torch::Tensor grad_values) {
1394
+
1395
+ // Make sure that we are using the correct GPU device
1396
+ torch::DeviceGuard _guard(queries.device());
1397
+
1398
+ int N = queries.size(0);
1399
+ int H = queries.size(1);
1400
+ int L = queries.size(2);
1401
+ int E = queries.size(3);
1402
+ int M = values.size(3);
1403
+
1404
+ const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;
1405
+
1406
+ dim3 blockDim(E, 1, 1);
1407
+ dim3 gridDim(blocks_per_sequence, N, H);
1408
+ const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);
1409
+
1410
+ causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
1411
+ queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1412
+ keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1413
+ values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1414
+ grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1415
+ grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1416
+ grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1417
+ N, H, L, E, M
1418
+ );
1419
+
1420
+ const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
1421
+
1422
+ dim3 blockDimv(M, 1, 1);
1423
+ dim3 gridDimv(blocks_per_sequence_value, N, H);
1424
+ const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
1425
+ causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
1426
+ queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1427
+ keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1428
+ values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1429
+ grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1430
+ grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1431
+ grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1432
+ N, H, L, E, M
1433
+ );
1434
+ }
1435
+
1436
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1437
+
1438
+ void causal_dot_backward(const torch::Tensor queries,
1439
+ const torch::Tensor keys,
1440
+ const torch::Tensor values,
1441
+ const torch::Tensor grad_out,
1442
+ torch::Tensor grad_queries,
1443
+ torch::Tensor grad_keys,
1444
+ torch::Tensor grad_values) {
1445
+ #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
1446
+ int fallback = nvidia::lmha_bwd(queries,
1447
+ keys,
1448
+ values,
1449
+ grad_out,
1450
+ grad_queries,
1451
+ grad_keys,
1452
+ grad_values);
1453
+ #else
1454
+ int fallback = 1;
1455
+ #endif
1456
+ if( fallback ) {
1457
+ // Make sure that the gradient tensors are 0. This is needed because the
1458
+ // bwd pass might have partially executed and filled in some values in
1459
+ // grad_queries or grad_keys.
1460
+ //
1461
+ // This adds a small overhead every time we have to fall back to the old
1462
+ // kernel for the backward pass.
1463
+ grad_queries.zero_();
1464
+ grad_keys.zero_();
1465
+ causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
1466
+ }
1467
+ }
1468
+
1469
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1470
+
1471
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1472
+ m.def(
1473
+ "causal_dot_product",
1474
+ &causal_dot_product,
1475
+ "Compute the weighted sum of values but attending only to previous "
1476
+ "values."
1477
+ );
1478
+ m.def(
1479
+ "causal_dot_backward",
1480
+ &causal_dot_backward,
1481
+ "Compute the gradients for the causal dot product."
1482
+ );
1483
+ }
csrc/causal_attention_kv_cuda.cu ADDED
@@ -0,0 +1,1483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //
2
+ // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ // Written by Angelos Katharopoulos <[email protected]>,
4
+ // Apoorv Vyas <[email protected]>
5
+ //
6
+
7
+ //
8
+ // For modifications made inside namespace nvidia (authored by jdemouth):
9
+ //
10
+ // Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
11
+ //
12
+ // Permission is hereby granted, free of charge, to any person obtaining a copy of
13
+ // this software and associated documentation files (the "Software"), to deal in
14
+ // the Software without restriction, including without limitation the rights to
15
+ // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
16
+ // the Software, and to permit persons to whom the Software is furnished to do so,
17
+ // subject to the following conditions:
18
+ //
19
+ // The above copyright notice and this permission notice shall be included in all
20
+ // copies or substantial portions of the Software.
21
+ //
22
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
24
+ // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
25
+ // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
26
+ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
27
+ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28
+ //
29
+
30
+ #include <torch/extension.h>
31
+ #include <assert.h>
32
+ #include <stdio.h>
33
+
34
+ #define ENABLE_NVIDIA_OPTIMIZATIONS
35
+
36
+ #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
37
+ namespace nvidia {
38
+
39
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ constexpr int THREADS_PER_WARP = 32;
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ static inline __device__ __host__ int div_up(int m, int n) {
50
+ return (m + n-1) / n;
51
+ }
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ static inline __device__ __host__ int round_up(int m, int n) {
56
+ return div_up(m, n) * n;
57
+ }
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ template< typename T >
62
+ struct Lmha_params {
63
+
64
+ // The output buffer. Dimensions [B, H, L, M].
65
+ T *out;
66
+
67
+ // The input Qs. Dimensions [B, H, L, E].
68
+ const T *q;
69
+ // The input Ks. Dimensions [B, H, L, E].
70
+ const T *k;
71
+ // The input Vs. Dimensions [B, H, L, M].
72
+ const T *v;
73
+
74
+ // The different dimensions.
75
+ int B, L, H, E, M;
76
+
77
+ // The strides for the different tensors.
78
+ int q_stride_B, q_stride_H, q_stride_L;
79
+ int k_stride_B, k_stride_H, k_stride_L;
80
+ int v_stride_B, v_stride_H, v_stride_L;
81
+ int o_stride_B, o_stride_H, o_stride_L;
82
+ };
83
+
84
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
85
+
86
+ template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
87
+ __global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
88
+ void lmha_low_occupancy_kernel(Lmha_params<float> params) {
89
+
90
+ // The number of threads per block.
91
+ constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
92
+ // The number of rows per thread.
93
+ constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
94
+ // The number of steps per iteration.
95
+ constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;
96
+
97
+ // Make sure E is a multiple of the warp size.
98
+ static_assert(E % THREADS_PER_WARP == 0, "");
99
+
100
+ // Shared memory to store V/O.
101
+ __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
102
+ // Shared memory buffer to performance the reductions.
103
+ __shared__ float smem_reds[E * WARPS];
104
+
105
+ // The sequence processed by that block.
106
+ const int bi = blockIdx.z;
107
+ // The head processed by that block.
108
+ const int hi = blockIdx.y;
109
+ // The hidden cell in the V/output buffers.
110
+ const int vi = blockIdx.x;
111
+
112
+ // The linear index of the thread.
113
+ const int tidx = threadIdx.x;
114
+
115
+ // Decompose the block in warp/lane.
116
+ const int warp = tidx / THREADS_PER_WARP;
117
+ const int lane = tidx % THREADS_PER_WARP;
118
+
119
+ // The base offset loaded by the thread in Q and K.
120
+ int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
121
+ int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
122
+
123
+ // If we walk backward, account for the extra offset.
124
+ if( GO_BACKWARD ) {
125
+ offset_q += (params.L-1)*params.q_stride_L;
126
+ offset_k += (params.L-1)*params.k_stride_L;
127
+ }
128
+
129
+ // Position the warp at the beginning of the proper timestep.
130
+ if( GO_BACKWARD ) {
131
+ offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
132
+ offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
133
+ } else {
134
+ offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
135
+ offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
136
+ }
137
+
138
+ // Determine the base pointers for Q and K.
139
+ const float *ptr_q = &params.q[offset_q];
140
+ const float *ptr_k = &params.k[offset_k];
141
+
142
+ // Is a given row valid?
143
+ int valid_qk[ROWS_PER_THREAD];
144
+ #pragma unroll
145
+ for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
146
+ valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
147
+ }
148
+
149
+ // The offset to the position loaded by the thread in V.
150
+ int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
151
+ int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
152
+
153
+ // If we walk backward, account for the extra offset.
154
+ if( GO_BACKWARD ) {
155
+ offset_v += (params.L-1)*params.v_stride_L;
156
+ offset_o += (params.L-1)*params.o_stride_L;
157
+ }
158
+
159
+ // We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
160
+ if( GO_BACKWARD ) {
161
+ offset_v -= tidx*params.v_stride_L;
162
+ offset_o -= tidx*params.o_stride_L;
163
+ } else {
164
+ offset_v += tidx*params.v_stride_L;
165
+ offset_o += tidx*params.o_stride_L;
166
+ }
167
+
168
+ // Determine the base pointer for V.
169
+ const float *ptr_v = &params.v[offset_v];
170
+ // The output pointer.
171
+ float *ptr_o = &params.out[offset_o];
172
+
173
+ // The running KVs.
174
+ float running_kv[ROWS_PER_THREAD];
175
+ #pragma unroll
176
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
177
+ running_kv[ri] = 0.f;
178
+ }
179
+
180
+ // Iterate over the timesteps. TODO: Use params.loop_count!!!
181
+ for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
182
+
183
+ // Each thread loads a matrix of elements.
184
+ float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
185
+
186
+ // Trigger the memory loads for Q and K.
187
+ #pragma unroll
188
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
189
+ #pragma unroll
190
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
191
+
192
+ // For Q/K, each warp loads from various timesteps.
193
+ int ti = iter + warp*COLS_PER_THREAD;
194
+ if( GO_BACKWARD ) {
195
+ ti = params.L - 1 - ti;
196
+ }
197
+
198
+ // Is it a valid access?
199
+ int valid;
200
+ if( GO_BACKWARD ) {
201
+ valid = valid_qk[ri] && ti - ci >= 0;
202
+ } else {
203
+ valid = valid_qk[ri] && ti + ci < params.L;
204
+ }
205
+
206
+ // The extra offset to add.
207
+ if( GO_BACKWARD ) {
208
+ offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
209
+ offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
210
+ } else {
211
+ offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
212
+ offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
213
+ }
214
+
215
+ // Load Q/K if they are valid.
216
+ q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
217
+ k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
218
+ }
219
+ }
220
+
221
+ // For the V tensor, we assign contiguous thread to different loads. So, ti is different.
222
+ int ti = iter + tidx;
223
+ if( GO_BACKWARD ) {
224
+ ti = params.L - 1 - ti;
225
+ }
226
+
227
+ // Is it a valid access?
228
+ int valid_vo = tidx < COLS_PER_ITER;
229
+ if( GO_BACKWARD ) {
230
+ valid_vo &= ti >= 0;
231
+ } else {
232
+ valid_vo &= ti < params.L;
233
+ }
234
+
235
+ // Trigger the loads for V.
236
+ float ldg_v = valid_vo ? *ptr_v : 0.f;
237
+
238
+ // Move the load pointers.
239
+ if( GO_BACKWARD ) {
240
+ ptr_q -= COLS_PER_ITER*params.q_stride_L;
241
+ ptr_k -= COLS_PER_ITER*params.k_stride_L;
242
+ ptr_v -= COLS_PER_ITER*params.v_stride_L;
243
+ } else {
244
+ ptr_q += COLS_PER_ITER*params.q_stride_L;
245
+ ptr_k += COLS_PER_ITER*params.k_stride_L;
246
+ ptr_v += COLS_PER_ITER*params.v_stride_L;
247
+ }
248
+
249
+ // Store to shared memory.
250
+ if( tidx < COLS_PER_ITER ) {
251
+ smem_v[tidx] = ldg_v;
252
+ }
253
+
254
+ // Make sure V is in shared memory.
255
+ __syncthreads();
256
+
257
+ // Read V from shared memory.
258
+ float v[COLS_PER_THREAD];
259
+ #pragma unroll
260
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
261
+ v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
262
+ }
263
+
264
+ // Each thread computes local K*V products.
265
+ float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
266
+ #pragma unroll
267
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
268
+ #pragma unroll
269
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
270
+ kv[ri][ci] = 0.f;
271
+ }
272
+ }
273
+
274
+ // Update the K*V^T product.
275
+ #pragma unroll
276
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
277
+ #pragma unroll
278
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
279
+ kv[ri][ci] += k[ri][ci] * v[ci];
280
+ }
281
+ }
282
+
283
+ // We must perform the prefix sums within the thread-block. Start with the thread.
284
+ #pragma unroll
285
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
286
+ #pragma unroll
287
+ for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
288
+ kv[ri][ci] += kv[ri][ci-1];
289
+ }
290
+ }
291
+
292
+ // Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
293
+ #pragma unroll
294
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
295
+ smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
296
+ }
297
+
298
+ // Make sure the data is in shared memory.
299
+ __syncthreads();
300
+
301
+ // Each thread deals with one or more column(s) of the matrix.
302
+ constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
303
+ #pragma unroll
304
+ for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
305
+ if( idx < E ) {
306
+ float sum = smem_reds[idx];
307
+ #pragma unroll
308
+ for( int jj = 1; jj < WARPS; ++jj ) {
309
+ smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
310
+ }
311
+ }
312
+ }
313
+
314
+ // Make sure the reductions are stored in shared memory.
315
+ __syncthreads();
316
+
317
+ // Each thread updates his partial products.
318
+ #pragma unroll
319
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
320
+ float sum = running_kv[ri];
321
+ if( warp > 0 ) {
322
+ sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
323
+ }
324
+ #pragma unroll
325
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
326
+ kv[ri][ci] += sum;
327
+ }
328
+ }
329
+
330
+ // Compute the partial output values for that thread.
331
+ float sum[COLS_PER_THREAD];
332
+ #pragma unroll
333
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
334
+ sum[ci] = q[0][ci] * kv[0][ci];
335
+ #pragma unroll
336
+ for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
337
+ sum[ci] += q[ri][ci] * kv[ri][ci];
338
+ }
339
+ }
340
+
341
+ // Run the parallel reductions inside the warp.
342
+ #pragma unroll
343
+ for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
344
+ #pragma unroll
345
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
346
+ sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
347
+ }
348
+ }
349
+
350
+ // Store the final output to shared memory.
351
+ if( lane == 0 ) {
352
+ #pragma unroll
353
+ for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
354
+ smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
355
+ }
356
+ }
357
+
358
+ // Make sure the data is in shared memory.
359
+ __syncthreads();
360
+
361
+ // Store the output.
362
+ if( valid_vo ) {
363
+ *ptr_o = smem_o[tidx];
364
+ }
365
+
366
+ // Each thread updates his running kv.
367
+ #pragma unroll
368
+ for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
369
+ running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
370
+ }
371
+
372
+ // Move to next location.
373
+ if( GO_BACKWARD ) {
374
+ ptr_o -= COLS_PER_ITER*params.o_stride_L;
375
+ } else {
376
+ ptr_o += COLS_PER_ITER*params.o_stride_L;
377
+ }
378
+ }
379
+ }
380
+
381
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
382
+
383
+ template< int E, bool GO_BACKWARD, int WARPS >
384
+ int lmha_low_occupancy_(const Lmha_params<float> &params) {
385
+
386
+ // Make sure we are not going to launch an invalid grid.
387
+ if( params.H > 65535 || params.B > 65535 ) {
388
+ return 1;
389
+ }
390
+
391
+ // Prepare the grid and trigger the CUDA kernel.
392
+ dim3 grid;
393
+ grid.x = params.M;
394
+ grid.y = params.H;
395
+ grid.z = params.B;
396
+ lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
397
+ return 0;
398
+ }
399
+
400
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
401
+
402
+ template< int E, bool GO_BACKWARD >
403
+ int lmha_low_occupancy_(const Lmha_params<float> &params, int blocks) {
404
+ if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
405
+ return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
406
+ } else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
407
+ return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
408
+ } else {
409
+ return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
410
+ }
411
+ return 1;
412
+ }
413
+
414
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
415
+
416
+ template< int E, typename Params >
417
+ static inline __device__ __host__ int smem_buffer_elts_(const Params &params) {
418
+ int M = round_up(params.M, 4);
419
+ return 2*E + 2*M;
420
+ }
421
+
422
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
423
+
424
+ template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
425
+ __global__
426
+ void lmha_kernel(Lmha_params<float> params) {
427
+
428
+ // Make sure E is a multiple of 4.
429
+ static_assert(E % 4 == 0, "");
430
+
431
+ // The amount of shared memory per buffer (2 buffers for double-buffering).
432
+ const int smem_buffer_elts = smem_buffer_elts_<E>(params);
433
+ // The M dimension for shared memory.
434
+ const int M = round_up(params.M, 4);
435
+
436
+ // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
437
+ extern __shared__ float smem_[];
438
+
439
+ // The various shared memory buffers.
440
+ float *smem_q = &smem_[0*E];
441
+ float *smem_k = &smem_[1*E];
442
+ float *smem_v = &smem_[2*E];
443
+ float *smem_o = &smem_[2*E + M];
444
+
445
+ // The index of the shared memory buffer (for double-buffering).
446
+ int smem_curr = 0;
447
+
448
+ // The sequence processed by that block.
449
+ const int bi = blockIdx.y;
450
+ // The head processed by that block.
451
+ const int hi = blockIdx.x;
452
+
453
+ // The linear index of the thread.
454
+ const int tidx = threadIdx.x;
455
+
456
+ // The offset to the position loaded by the thread in Q.
457
+ int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
458
+ // The offset to the position loaded by the thread in K.
459
+ int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
460
+
461
+ // If we walk backward, account for the extra offset.
462
+ if( GO_BACKWARD ) {
463
+ offset_q += (params.L-1)*params.q_stride_L;
464
+ offset_k += (params.L-1)*params.k_stride_L;
465
+ }
466
+
467
+ // Determine the base pointers for Q and K.
468
+ const float *ptr_q = &params.q[offset_q];
469
+ const float *ptr_k = &params.k[offset_k];
470
+
471
+ // The offset to the position loaded by the thread in V and O.
472
+ int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
473
+ int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
474
+
475
+ // If we walk backward, account for the extra offset.
476
+ if( GO_BACKWARD ) {
477
+ offset_v += (params.L-1)*params.v_stride_L;
478
+ offset_o += (params.L-1)*params.o_stride_L;
479
+ }
480
+
481
+ // Determine the base pointers for V.
482
+ const float *ptr_v = &params.v[offset_v];
483
+
484
+ // Is it an active Q/K thread?
485
+ const int active_qk = tidx < params.E;
486
+
487
+ // Trigger the memory loads for Q and K.
488
+ float ldg_q = 0.f, ldg_k = 0.f;
489
+ if( active_qk ) {
490
+ ldg_q = *ptr_q;
491
+ ldg_k = *ptr_k;
492
+ }
493
+
494
+ // Is it an active V thread?
495
+ const int active_v = tidx < params.M;
496
+
497
+ // Trigger the memory loads for V.
498
+ float ldg_v = 0.f;
499
+ if( active_v ) {
500
+ ldg_v = *ptr_v;
501
+ }
502
+
503
+ // Move the load pointers.
504
+ if( GO_BACKWARD ) {
505
+ ptr_q -= params.q_stride_L;
506
+ ptr_k -= params.k_stride_L;
507
+ ptr_v -= params.v_stride_L;
508
+ } else {
509
+ ptr_q += params.q_stride_L;
510
+ ptr_k += params.k_stride_L;
511
+ ptr_v += params.v_stride_L;
512
+ }
513
+
514
+ // The number of FLOAT4s per head.
515
+ constexpr int FLOAT4s_PER_HEAD = E / 4;
516
+ // The number of FLOAT4s per thread.
517
+ constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
518
+
519
+ // The storage for the K*V^T values.
520
+ float4 kv[FLOAT4s_PER_THREAD];
521
+ #pragma unroll
522
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
523
+ kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
524
+ }
525
+
526
+ // The output pointer.
527
+ float *out_ptr = &params.out[offset_o];
528
+
529
+ // Store to shared memory Q and K.
530
+ if( tidx < E ) {
531
+ smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
532
+ smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
533
+ }
534
+
535
+ // Store to shared memory V. All threads store valid values.
536
+ if( tidx < M ) {
537
+ smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
538
+ }
539
+
540
+ // The position of the thread in the V dimension.
541
+ int vo = tidx / THREADS_PER_HEAD;
542
+ int vi = tidx % THREADS_PER_HEAD;
543
+
544
+ // Iterate over the timesteps.
545
+ for( int ti = 0; ti < params.L; ++ti ) {
546
+
547
+ // Is it the last iteration?
548
+ int is_last = ti == params.L - 1;
549
+
550
+ // Trigger the next loads for Q and K.
551
+ if( !is_last && active_qk ) {
552
+ ldg_q = *ptr_q;
553
+ ldg_k = *ptr_k;
554
+ }
555
+
556
+ // Trigger the next loads for V.
557
+ if( !is_last && active_v ) {
558
+ ldg_v = *ptr_v;
559
+ }
560
+
561
+ // Move the load pointers.
562
+ if( GO_BACKWARD ) {
563
+ ptr_q -= params.q_stride_L;
564
+ ptr_k -= params.k_stride_L;
565
+ ptr_v -= params.v_stride_L;
566
+ } else {
567
+ ptr_q += params.q_stride_L;
568
+ ptr_k += params.k_stride_L;
569
+ ptr_v += params.v_stride_L;
570
+ }
571
+
572
+ // Make sure the data is in shared memory.
573
+ __syncthreads();
574
+
575
+ // Each thread loads 4 values from K.
576
+ float4 k[FLOAT4s_PER_THREAD];
577
+ #pragma unroll
578
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
579
+ int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
580
+ k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
581
+ }
582
+
583
+ // Each thread loads a single V value.
584
+ float v = 0.f;
585
+ if( vo < params.M ) {
586
+ v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
587
+ }
588
+
589
+ // Update the K*V^T product.
590
+ #pragma unroll
591
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
592
+ kv[ii].x += k[ii].x * v;
593
+ kv[ii].y += k[ii].y * v;
594
+ kv[ii].z += k[ii].z * v;
595
+ kv[ii].w += k[ii].w * v;
596
+ }
597
+
598
+ // Load the Q values from shared memory.
599
+ float4 q[FLOAT4s_PER_THREAD];
600
+ #pragma unroll
601
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
602
+ int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
603
+ q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
604
+ }
605
+
606
+ // Compute the partial output value for that thread.
607
+ float sum = 0.f;
608
+ #pragma unroll
609
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
610
+ sum += q[ii].x * kv[ii].x;
611
+ sum += q[ii].y * kv[ii].y;
612
+ sum += q[ii].z * kv[ii].z;
613
+ sum += q[ii].w * kv[ii].w;
614
+ }
615
+
616
+ // Finalize the computation of the sum (if we have more than 1 thread per head).
617
+ if( THREADS_PER_HEAD > 1 ) {
618
+
619
+ // Finalize the sum for each head.
620
+ #pragma unroll
621
+ for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
622
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
623
+ }
624
+
625
+ // Store to shared memory.
626
+ if( vo < M && vi == 0 ) {
627
+ smem_o[smem_curr*smem_buffer_elts + vo] = sum;
628
+ }
629
+
630
+ // Make sure the data is in shared memory.
631
+ __syncthreads();
632
+
633
+ // Active threads read the data to store.
634
+ if( active_v ) {
635
+ sum = smem_o[smem_curr*smem_buffer_elts + tidx];
636
+ }
637
+
638
+ } // THREADS_PER_HEAD > 1.
639
+
640
+ // Store the output. All the threads are active.
641
+ if( active_v ) {
642
+ *out_ptr = sum;
643
+ }
644
+
645
+ // Move to next location.
646
+ if( GO_BACKWARD ) {
647
+ out_ptr -= params.o_stride_L;
648
+ } else {
649
+ out_ptr += params.o_stride_L;
650
+ }
651
+
652
+ // Move the shared memory buffer.
653
+ smem_curr = (smem_curr + 1) % 2;
654
+
655
+ // Store to shared memory for Q and K.
656
+ if( !is_last && tidx < E ) {
657
+ smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
658
+ smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
659
+ }
660
+
661
+ // Store to shared memory for V.
662
+ if( !is_last && tidx < M ) {
663
+ smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
664
+ }
665
+ }
666
+ }
667
+
668
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
669
+
670
+ template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
671
+ int lmha_(const Lmha_params<float> &params) {
672
+ // The M dimension rounded up to 4.
673
+ int M = round_up(params.M, 4);
674
+
675
+ // The number of threads in the block.
676
+ int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
677
+ if( block > 512 || params.B > 65535 ) {
678
+ return 1;
679
+ }
680
+
681
+ // Prepare the kernel.
682
+ dim3 grid(params.H, params.B);
683
+ size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
684
+ lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
685
+ return 0;
686
+ }
687
+
688
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
689
+
690
+ template< bool GO_BACKWARD >
691
+ int lmha(const Lmha_params<float> &params) {
692
+ int blocks = params.B * params.H;
693
+ int res = 1;
694
+ if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
695
+ if( params.E <= 32 ) {
696
+ res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
697
+ } else if( params.E <= 64 ) {
698
+ res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
699
+ } else if( params.E <= 128 ) {
700
+ res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
701
+ } else if( params.E <= 256 ) {
702
+ res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
703
+ }
704
+ } else {
705
+ if( params.E <= 32 ) {
706
+ res = lmha_< 32, 1, GO_BACKWARD>(params);
707
+ } else if( params.E <= 48 ) {
708
+ res = lmha_< 48, 1, GO_BACKWARD>(params);
709
+ } else if( params.E <= 64 ) {
710
+ res = lmha_< 64, 1, GO_BACKWARD>(params);
711
+ } else if( params.E <= 128 ) {
712
+ res = lmha_<128, 2, GO_BACKWARD>(params);
713
+ } else if( params.E <= 256 ) {
714
+ res = lmha_<256, 4, GO_BACKWARD>(params);
715
+ }
716
+ }
717
+ return res;
718
+ }
719
+
720
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
721
+
722
+ template< typename T >
723
+ inline void set_params(Lmha_params<T> &params,
724
+ const torch::Tensor q,
725
+ const torch::Tensor k,
726
+ const torch::Tensor v,
727
+ torch::Tensor o) {
728
+
729
+ // Define the pointers.
730
+ params.out = o.data_ptr<T>();
731
+ params.q = q.data_ptr<T>();
732
+ params.k = k.data_ptr<T>();
733
+ params.v = v.data_ptr<T>();
734
+
735
+ // Define the strides.
736
+ params.q_stride_B = (int) q.stride(0);
737
+ params.q_stride_H = (int) q.stride(1);
738
+ params.q_stride_L = (int) q.stride(2);
739
+ params.k_stride_B = (int) k.stride(0);
740
+ params.k_stride_H = (int) k.stride(1);
741
+ params.k_stride_L = (int) k.stride(2);
742
+ params.v_stride_B = (int) v.stride(0);
743
+ params.v_stride_H = (int) v.stride(1);
744
+ params.v_stride_L = (int) v.stride(2);
745
+ params.o_stride_B = (int) o.stride(0);
746
+ params.o_stride_H = (int) o.stride(1);
747
+ params.o_stride_L = (int) o.stride(2);
748
+
749
+ // Extract the dimensions.
750
+ int N = q.size(0);
751
+ int H = q.size(1);
752
+ int L = q.size(2);
753
+ int E = q.size(3);
754
+ int M = v.size(3);
755
+
756
+ params.B = N;
757
+ params.L = L;
758
+ params.H = H;
759
+ params.E = E;
760
+ params.M = M;
761
+ }
762
+
763
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
764
+
765
+ int lmha_fwd(const torch::Tensor queries,
766
+ const torch::Tensor keys,
767
+ const torch::Tensor values,
768
+ torch::Tensor product) {
769
+
770
+ // Make sure that we are using the correct GPU device
771
+ torch::DeviceGuard _guard(queries.device());
772
+
773
+ // Make sure the inner-most dimension of the tensors is packed.
774
+ assert(queries.stride(3) == 1);
775
+ assert(keys .stride(3) == 1);
776
+ assert(values .stride(3) == 1);
777
+ assert(product.stride(3) == 1);
778
+
779
+ // Extract the dimensions.
780
+ int N = queries.size(0);
781
+ int H = queries.size(1);
782
+ int L = queries.size(2);
783
+ int E = queries.size(3);
784
+ int M = values.size (3);
785
+
786
+ // The structure of params.
787
+ Lmha_params<float> params;
788
+ set_params(params, queries, keys, values, product);
789
+
790
+ // Launch the kernel.
791
+ return lmha<false>(params);
792
+ }
793
+
794
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
795
+
796
+ template< typename T >
797
+ struct Lmha_bwd_params {
798
+
799
+ // The output buffer for K. Dimensions [B, H, L, D].
800
+ T *out_k;
801
+ // The output buffer for V. Dimensions [B, H, L, D].
802
+ T *out_v;
803
+
804
+ // The input Qs. Dimensions [B, H, L, D].
805
+ const T *q;
806
+ // The input Ks. Dimensions [B, H, L, D].
807
+ const T *k;
808
+ // The input Vs. Dimensions [B, H, L, D].
809
+ const T *v;
810
+ // The input Gs. Dimensions [B, H, L, D].
811
+ const T *g;
812
+
813
+ // The dimensions.
814
+ int B, L, H, M, E;
815
+
816
+ // The strides for the input tensors.
817
+ int q_stride_B, q_stride_L, q_stride_H;
818
+ int k_stride_B, k_stride_L, k_stride_H;
819
+ int v_stride_B, v_stride_L, v_stride_H;
820
+ int g_stride_B, g_stride_L, g_stride_H;
821
+
822
+ // The strides for the outputs.
823
+ int out_k_stride_B, out_k_stride_L, out_k_stride_H;
824
+ int out_v_stride_B, out_v_stride_L, out_v_stride_H;
825
+ };
826
+
827
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
828
+
829
+ template< int D, int THREADS_PER_HEAD >
830
+ __global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
831
+ void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
832
+
833
+ // Make sure D is a multiple of 4.
834
+ static_assert(D % 4 == 0, "");
835
+
836
+ // The shared memory buffers.
837
+ __shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
838
+
839
+ // The index of the shared memory buffer (for double-buffering).
840
+ int smem_curr = 0;
841
+
842
+ // The sequence processed by that block.
843
+ const int bi = blockIdx.y;
844
+ // The head processed by that block.
845
+ const int hi = blockIdx.x;
846
+
847
+ // The linear index of the thread.
848
+ const int tidx = threadIdx.x;
849
+
850
+ // Split the threads into two slices.
851
+ int so = tidx / (D*THREADS_PER_HEAD);
852
+ int si = tidx % (D*THREADS_PER_HEAD);
853
+
854
+ // The strides for B/L/H for the Q/G tensors.
855
+ int qg_stride_B, qg_stride_L, qg_stride_H;
856
+ if( so == 0 ) {
857
+ qg_stride_B = params.q_stride_B;
858
+ qg_stride_L = params.q_stride_L;
859
+ qg_stride_H = params.q_stride_H;
860
+ } else {
861
+ qg_stride_B = params.g_stride_B;
862
+ qg_stride_L = params.g_stride_L;
863
+ qg_stride_H = params.g_stride_H;
864
+ }
865
+
866
+ // The strides for B/L/H for the K/V tensors.
867
+ int kv_stride_B, kv_stride_L, kv_stride_H;
868
+ if( so == 0 ) {
869
+ kv_stride_B = params.k_stride_B;
870
+ kv_stride_L = params.k_stride_L;
871
+ kv_stride_H = params.k_stride_H;
872
+ } else {
873
+ kv_stride_B = params.v_stride_B;
874
+ kv_stride_L = params.v_stride_L;
875
+ kv_stride_H = params.v_stride_H;
876
+ }
877
+
878
+ // The hidden size.
879
+ int hidden_size_per_head = 0;
880
+ if( so == 0 ) {
881
+ hidden_size_per_head = params.E;
882
+ } else {
883
+ hidden_size_per_head = params.M;
884
+ }
885
+
886
+ // Where to start reading from.
887
+ int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
888
+ int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
889
+
890
+ // We walk backward, account for the extra offset.
891
+ offset_qg += (params.L-1)*qg_stride_L;
892
+ offset_kv += (params.L-1)*kv_stride_L;
893
+
894
+ // Determine the base pointers for Q, K, V and G.
895
+ const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
896
+ const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
897
+
898
+ // Is it an active thread?
899
+ const int active = si < hidden_size_per_head;
900
+
901
+ // Trigger the memory loads for Q, K, V and G.
902
+ float ldg_qg = 0.f, ldg_kv = 0.f;
903
+ if( active ) {
904
+ ldg_qg = *ptr_qg;
905
+ ldg_kv = *ptr_kv;
906
+ }
907
+
908
+ // Move the load pointers (backward).
909
+ ptr_qg -= qg_stride_L;
910
+ ptr_kv -= kv_stride_L;
911
+
912
+ // The number of FLOAT4s per head.
913
+ constexpr int FLOAT4s_PER_HEAD = D / 4;
914
+ // The number of FLOAT4s per thread.
915
+ constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
916
+
917
+ // The storage for the G*Q^T or Q^T*G values.
918
+ float4 gq[FLOAT4s_PER_THREAD];
919
+ #pragma unroll
920
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
921
+ gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
922
+ }
923
+
924
+ // The strides for B/L/H for the K/V tensors.
925
+ int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
926
+ if( so == 0 ) {
927
+ out_kv_stride_B = params.out_k_stride_B;
928
+ out_kv_stride_L = params.out_k_stride_L;
929
+ out_kv_stride_H = params.out_k_stride_H;
930
+ } else {
931
+ out_kv_stride_B = params.out_v_stride_B;
932
+ out_kv_stride_L = params.out_v_stride_L;
933
+ out_kv_stride_H = params.out_v_stride_H;
934
+ }
935
+
936
+ // Where to start reading from.
937
+ int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
938
+
939
+ // We walk backward, account for the extra offset.
940
+ offset_out_kv += (params.L-1)*out_kv_stride_L;
941
+
942
+ // The output pointer.
943
+ float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
944
+
945
+ // Store to shared memory.
946
+ if( si < D ) {
947
+ smem_[smem_curr].qg[so*D + si] = ldg_qg;
948
+ smem_[smem_curr].kv[so*D + si] = ldg_kv;
949
+ }
950
+
951
+ // The position of the thread in the output dimension.
952
+ int oo = si / THREADS_PER_HEAD % D;
953
+ int oi = si % THREADS_PER_HEAD * 4;
954
+
955
+ // Iterate over the timesteps.
956
+ for( int ti = 0; ti < params.L; ++ti ) {
957
+
958
+ // Is it the last iteration?
959
+ int is_last = ti == params.L - 1;
960
+
961
+ // Trigger the next loads.
962
+ if( !is_last && active ) {
963
+ ldg_qg = *ptr_qg;
964
+ ldg_kv = *ptr_kv;
965
+ }
966
+
967
+ // Move the load pointers.
968
+ ptr_qg -= qg_stride_L;
969
+ ptr_kv -= kv_stride_L;
970
+
971
+ // Make sure the data is in shared memory.
972
+ __syncthreads();
973
+
974
+ // Each thread loads 4 values from G or Q.
975
+ float4 g[FLOAT4s_PER_THREAD];
976
+ #pragma unroll
977
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
978
+ float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
979
+ g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
980
+ }
981
+
982
+ // Each thread loads a single from Q or G value.
983
+ float q = smem_[smem_curr].qg[so*D + oo];
984
+
985
+ // Update the G*Q^T or Q*G^T product.
986
+ #pragma unroll
987
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
988
+ gq[ii].x += g[ii].x * q;
989
+ gq[ii].y += g[ii].y * q;
990
+ gq[ii].z += g[ii].z * q;
991
+ gq[ii].w += g[ii].w * q;
992
+ }
993
+
994
+ // Load the V or K values from shared memory.
995
+ float4 v[FLOAT4s_PER_THREAD];
996
+ #pragma unroll
997
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
998
+ float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
999
+ v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
1000
+ }
1001
+
1002
+ // Compute the partial output value for that thread.
1003
+ float sum = 0.f;
1004
+ #pragma unroll
1005
+ for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
1006
+ sum += v[ii].x * gq[ii].x;
1007
+ sum += v[ii].y * gq[ii].y;
1008
+ sum += v[ii].z * gq[ii].z;
1009
+ sum += v[ii].w * gq[ii].w;
1010
+ }
1011
+
1012
+ // Finalize the computation of the sum (if we have more than 1 thread per head).
1013
+ if( THREADS_PER_HEAD > 1 ) {
1014
+
1015
+ // Finalize the sum for each head.
1016
+ #pragma unroll
1017
+ for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
1018
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
1019
+ }
1020
+
1021
+ // Store to shared memory.
1022
+ if( oi == 0 ) {
1023
+ smem_[smem_curr].out_kv[so*D + oo] = sum;
1024
+ }
1025
+
1026
+ // Make sure the data is in shared memory.
1027
+ __syncthreads();
1028
+
1029
+ // Active threads read the data to store.
1030
+ if( si < hidden_size_per_head ) {
1031
+ sum = smem_[smem_curr].out_kv[so*D + si];
1032
+ }
1033
+
1034
+ } // THREADS_PER_HEAD > 1.
1035
+
1036
+ // Store the output. All the threads are active.
1037
+ if( si < hidden_size_per_head ) {
1038
+ *ptr_out_kv = sum;
1039
+ }
1040
+
1041
+ // Move to next location.
1042
+ ptr_out_kv -= out_kv_stride_L;
1043
+
1044
+ // Move the shared memory buffer.
1045
+ smem_curr = (smem_curr + 1) % 2;
1046
+
1047
+ // Store to shared memory for Q and K.
1048
+ if( !is_last && si < D ) {
1049
+ smem_[smem_curr].qg[so*D + si] = ldg_qg;
1050
+ smem_[smem_curr].kv[so*D + si] = ldg_kv;
1051
+ }
1052
+ }
1053
+ }
1054
+
1055
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1056
+
1057
+ template< int D, int THREADS_PER_HEAD >
1058
+ int lmha_bwd_(const Lmha_bwd_params<float> &params) {
1059
+ int block = D*THREADS_PER_HEAD*2;
1060
+ if( block >= 1024 || params.B > 65535 ) {
1061
+ return 1;
1062
+ }
1063
+ dim3 grid(params.H, params.B);
1064
+ lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
1065
+ return 0;
1066
+ }
1067
+
1068
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1069
+
1070
+ int lmha_bwd(const Lmha_bwd_params<float> &params) {
1071
+ int blocks = params.B * params.H;
1072
+ if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
1073
+ return 1;
1074
+ }
1075
+
1076
+ int hidden_size_per_head = max(params.E, params.M);
1077
+ int res = 1;
1078
+ if( hidden_size_per_head <= 32 ) {
1079
+ res = lmha_bwd_< 32, 1>(params);
1080
+ } else if( hidden_size_per_head <= 64 ) {
1081
+ res = lmha_bwd_< 64, 1>(params);
1082
+ } else if( hidden_size_per_head <= 128 ) {
1083
+ res = lmha_bwd_<128, 2>(params);
1084
+ } else if( hidden_size_per_head <= 256 ) {
1085
+ res = lmha_bwd_<256, 4>(params);
1086
+ }
1087
+ return res;
1088
+ }
1089
+
1090
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1091
+
1092
+ int lmha_bwd(const torch::Tensor queries,
1093
+ const torch::Tensor keys,
1094
+ const torch::Tensor values,
1095
+ const torch::Tensor grad_out,
1096
+ torch::Tensor grad_queries,
1097
+ torch::Tensor grad_keys,
1098
+ torch::Tensor grad_values) {
1099
+
1100
+ // Make sure that we are using the correct GPU device
1101
+ torch::DeviceGuard _guard(queries.device());
1102
+
1103
+ // Make sure the inner-most dimension of the tensors is packed.
1104
+ assert(queries .stride(3) == 1);
1105
+ assert(keys .stride(3) == 1);
1106
+ assert(values .stride(3) == 1);
1107
+ assert(grad_out .stride(3) == 1);
1108
+ assert(grad_queries.stride(3) == 1);
1109
+ assert(grad_keys .stride(3) == 1);
1110
+ assert(grad_values .stride(3) == 1);
1111
+
1112
+ // Extract the dimensions.
1113
+ int N = queries.size(0);
1114
+ int H = queries.size(1);
1115
+ int L = queries.size(2);
1116
+ int E = queries.size(3);
1117
+ int M = values.size (3);
1118
+
1119
+ // Gradient on Q.
1120
+
1121
+ // The structure of params.
1122
+ Lmha_params<float> params;
1123
+ set_params(params, grad_out, values, keys, grad_queries);
1124
+
1125
+ // Launch the kernel.
1126
+ int res = lmha<false>(params);
1127
+ if( res ) {
1128
+ return res;
1129
+ }
1130
+
1131
+ // Gradient on K and V together.
1132
+
1133
+ Lmha_bwd_params<float> bwd_params;
1134
+ bwd_params.out_k = grad_keys.data_ptr<float>();
1135
+ bwd_params.out_v = grad_values.data_ptr<float>();
1136
+ bwd_params.q = queries.data_ptr<float>();
1137
+ bwd_params.k = keys.data_ptr<float>();
1138
+ bwd_params.v = values.data_ptr<float>();
1139
+ bwd_params.g = grad_out.data_ptr<float>();
1140
+
1141
+ bwd_params.B = N;
1142
+ bwd_params.L = L;
1143
+ bwd_params.H = H;
1144
+ bwd_params.E = E;
1145
+ bwd_params.M = M;
1146
+
1147
+ bwd_params.q_stride_B = queries.stride(0);
1148
+ bwd_params.q_stride_H = queries.stride(1);
1149
+ bwd_params.q_stride_L = queries.stride(2);
1150
+ bwd_params.k_stride_B = keys.stride(0);
1151
+ bwd_params.k_stride_H = keys.stride(1);
1152
+ bwd_params.k_stride_L = keys.stride(2);
1153
+ bwd_params.v_stride_B = values.stride(0);
1154
+ bwd_params.v_stride_H = values.stride(1);
1155
+ bwd_params.v_stride_L = values.stride(2);
1156
+ bwd_params.g_stride_B = grad_out.stride(0);
1157
+ bwd_params.g_stride_H = grad_out.stride(1);
1158
+ bwd_params.g_stride_L = grad_out.stride(2);
1159
+
1160
+ bwd_params.out_k_stride_B = grad_keys.stride(0);
1161
+ bwd_params.out_k_stride_H = grad_keys.stride(1);
1162
+ bwd_params.out_k_stride_L = grad_keys.stride(2);
1163
+ bwd_params.out_v_stride_B = grad_values.stride(0);
1164
+ bwd_params.out_v_stride_H = grad_values.stride(1);
1165
+ bwd_params.out_v_stride_L = grad_values.stride(2);
1166
+
1167
+ // Try to run the fused kernel.
1168
+ int fallback = lmha_bwd(bwd_params);
1169
+
1170
+ // If it failed, fallback on separate kernels for K and V.
1171
+ if( fallback ) {
1172
+
1173
+ // Gradient on K.
1174
+
1175
+ // Launch the kernel.
1176
+ set_params(params, values, grad_out, queries, grad_keys);
1177
+ res = lmha<true>(params);
1178
+ if( res ) {
1179
+ return res;
1180
+ }
1181
+
1182
+ // Gradient on V.
1183
+
1184
+ // Launch the kernel.
1185
+ set_params(params, keys, queries, grad_out, grad_values);
1186
+ return lmha<true>(params);
1187
+ }
1188
+
1189
+ // It worked...
1190
+ return 0;
1191
+ }
1192
+
1193
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1194
+
1195
+ } // namespace nvidia
1196
+ #endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
1197
+
1198
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1199
+
1200
+ typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
1201
+
1202
+ #define E_BLOCK_SIZE 8
1203
+
1204
+ __global__ void causal_dot_product_kernel(
1205
+ const float_accessor queries,
1206
+ const float_accessor keys,
1207
+ const float_accessor values,
1208
+ float_accessor result,
1209
+ const int N,
1210
+ const int H,
1211
+ const int L,
1212
+ const int E,
1213
+ const int M
1214
+ ) {
1215
+ int n = blockIdx.y;
1216
+ int h = blockIdx.z;
1217
+
1218
+ int e_start = blockIdx.x * E_BLOCK_SIZE;
1219
+ int m = threadIdx.x % M;
1220
+
1221
+ extern __shared__ float shared_mem[];
1222
+ float* shared_kv = shared_mem;
1223
+
1224
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1225
+ shared_kv[m + e_local * M] = 0;
1226
+ }
1227
+
1228
+ for (int t=0; t<L; t++) {
1229
+ float res = 0;
1230
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1231
+ shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
1232
+ res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
1233
+ }
1234
+ atomicAdd(
1235
+ &result[n][h][t][m],
1236
+ res
1237
+ );
1238
+ }
1239
+ }
1240
+
1241
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1242
+
1243
+ void causal_dot_product_(const torch::Tensor queries,
1244
+ const torch::Tensor keys,
1245
+ const torch::Tensor values,
1246
+ torch::Tensor product) {
1247
+ // Make sure that we are using the correct GPU device
1248
+ torch::DeviceGuard _guard(queries.device());
1249
+
1250
+ int N = queries.size(0);
1251
+ int H = queries.size(1);
1252
+ int L = queries.size(2);
1253
+ int E = queries.size(3);
1254
+ int M = values.size(3);
1255
+
1256
+ const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
1257
+
1258
+ dim3 blockDim(M, 1, 1);
1259
+ dim3 gridDim(blocks_per_sequence, N, H);
1260
+ const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
1261
+
1262
+ causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
1263
+ queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1264
+ keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1265
+ values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1266
+ product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1267
+ N, H, L, E, M
1268
+ );
1269
+ }
1270
+
1271
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1272
+
1273
+ void causal_dot_product(const torch::Tensor queries,
1274
+ const torch::Tensor keys,
1275
+ const torch::Tensor values,
1276
+ torch::Tensor product) {
1277
+ #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
1278
+ int fallback = nvidia::lmha_fwd(queries, keys, values, product);
1279
+ #else
1280
+ int fallback = 1;
1281
+ #endif
1282
+ if( fallback ) {
1283
+ causal_dot_product_(queries, keys, values, product);
1284
+ }
1285
+ }
1286
+
1287
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1288
+
1289
+ #define M_BLOCK_SIZE 4
1290
+
1291
+ // we need shared memory to store
1292
+ // kv
1293
+ // Backward direction
1294
+ // kv_backwards
1295
+ // Shared memory usage
1296
+ __global__ void causal_dot_backward_query_key_kernel(
1297
+ const float_accessor queries,
1298
+ const float_accessor keys,
1299
+ const float_accessor values,
1300
+ const float_accessor grad_out,
1301
+ float_accessor grad_queries,
1302
+ float_accessor grad_keys,
1303
+ int N,
1304
+ int H,
1305
+ int L,
1306
+ int E,
1307
+ int M
1308
+ ) {
1309
+ int n = blockIdx.y;
1310
+ int h = blockIdx.z;
1311
+
1312
+ int m_start = blockIdx.x * M_BLOCK_SIZE;
1313
+ int e = threadIdx.x % E;
1314
+
1315
+ extern __shared__ float shared_mem[];
1316
+ const int shared_kv_size = M_BLOCK_SIZE * E;
1317
+ float* shared_kv = shared_mem;
1318
+ float* shared_kv_bw = shared_mem + shared_kv_size;
1319
+
1320
+ for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
1321
+ shared_kv[m_local * E + e] = 0;
1322
+ shared_kv_bw[m_local * E + e] = 0;
1323
+ }
1324
+
1325
+ for (int l=0; l<L; l++) {
1326
+ float res = 0, res_bw = 0;
1327
+ int l_b = L - l - 1;
1328
+ for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
1329
+ shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m_start + m_local];
1330
+ shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m_start + m_local];
1331
+ res += grad_out[n][h][l][m_start + m_local] * shared_kv[m_local*E + e];
1332
+ res_bw += values[n][h][l_b][m_start + m_local] * shared_kv_bw[m_local*E + e];
1333
+ }
1334
+ atomicAdd(
1335
+ &grad_queries[n][h][l][e],
1336
+ res
1337
+ );
1338
+ atomicAdd(
1339
+ &grad_keys[n][h][l_b][e],
1340
+ res_bw
1341
+ );
1342
+ }
1343
+ }
1344
+
1345
+
1346
+ __global__ void causal_dot_backward_value_kernel(
1347
+ const float_accessor queries,
1348
+ const float_accessor keys,
1349
+ const float_accessor values,
1350
+ const float_accessor grad_out,
1351
+ float_accessor grad_keys,
1352
+ float_accessor grad_values,
1353
+ int N,
1354
+ int H,
1355
+ int L,
1356
+ int E,
1357
+ int M
1358
+ ) {
1359
+ int n = blockIdx.y;
1360
+ int h = blockIdx.z;
1361
+
1362
+ int e_start = blockIdx.x * E_BLOCK_SIZE;
1363
+ int m = threadIdx.x % M;
1364
+
1365
+ extern __shared__ float shared_mem[];
1366
+ float* shared_kv = shared_mem;
1367
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1368
+ shared_kv[m + e_local * M] = 0;
1369
+ }
1370
+
1371
+ for (int l = 0; l < L; l++) {
1372
+ int l_b = L - l -1;
1373
+ float res = 0;
1374
+ for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
1375
+ shared_kv[e_local*M + m] += queries[n][h][l_b][e_start + e_local] * grad_out[n][h][l_b][m];
1376
+ res += keys[n][h][l_b][e_start + e_local] * shared_kv[e_local*M + m];
1377
+ }
1378
+ atomicAdd(
1379
+ &grad_values[n][h][l_b][m],
1380
+ res
1381
+ );
1382
+ }
1383
+ }
1384
+
1385
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1386
+
1387
+ void causal_dot_backward_(const torch::Tensor queries,
1388
+ const torch::Tensor keys,
1389
+ const torch::Tensor values,
1390
+ const torch::Tensor grad_out,
1391
+ torch::Tensor grad_queries,
1392
+ torch::Tensor grad_keys,
1393
+ torch::Tensor grad_values) {
1394
+
1395
+ // Make sure that we are using the correct GPU device
1396
+ torch::DeviceGuard _guard(queries.device());
1397
+
1398
+ int N = queries.size(0);
1399
+ int H = queries.size(1);
1400
+ int L = queries.size(2);
1401
+ int E = queries.size(3);
1402
+ int M = values.size(3);
1403
+
1404
+ const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;
1405
+
1406
+ dim3 blockDim(E, 1, 1);
1407
+ dim3 gridDim(blocks_per_sequence, N, H);
1408
+ const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);
1409
+
1410
+ causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
1411
+ queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1412
+ keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1413
+ values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1414
+ grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1415
+ grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1416
+ grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1417
+ N, H, L, E, M
1418
+ );
1419
+
1420
+ const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
1421
+
1422
+ dim3 blockDimv(M, 1, 1);
1423
+ dim3 gridDimv(blocks_per_sequence_value, N, H);
1424
+ const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
1425
+ causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
1426
+ queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1427
+ keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1428
+ values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1429
+ grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1430
+ grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1431
+ grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
1432
+ N, H, L, E, M
1433
+ );
1434
+ }
1435
+
1436
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1437
+
1438
+ void causal_dot_backward(const torch::Tensor queries,
1439
+ const torch::Tensor keys,
1440
+ const torch::Tensor values,
1441
+ const torch::Tensor grad_out,
1442
+ torch::Tensor grad_queries,
1443
+ torch::Tensor grad_keys,
1444
+ torch::Tensor grad_values) {
1445
+ #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
1446
+ int fallback = nvidia::lmha_bwd(queries,
1447
+ keys,
1448
+ values,
1449
+ grad_out,
1450
+ grad_queries,
1451
+ grad_keys,
1452
+ grad_values);
1453
+ #else
1454
+ int fallback = 1;
1455
+ #endif
1456
+ if( fallback ) {
1457
+ // Make sure that the gradient tensors are 0. This is needed because the
1458
+ // bwd pass might have partially executed and filled in some values in
1459
+ // grad_queries or grad_keys.
1460
+ //
1461
+ // This adds a small overhead every time we have to fall back to the old
1462
+ // kernel for the backward pass.
1463
+ grad_queries.zero_();
1464
+ grad_keys.zero_();
1465
+ causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
1466
+ }
1467
+ }
1468
+
1469
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1470
+
1471
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1472
+ m.def(
1473
+ "causal_dot_product",
1474
+ &causal_dot_product,
1475
+ "Compute the weighted sum of values but attending only to previous "
1476
+ "values."
1477
+ );
1478
+ m.def(
1479
+ "causal_dot_backward",
1480
+ &causal_dot_backward,
1481
+ "Compute the gradients for the causal dot product."
1482
+ );
1483
+ }
csrc/setup.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ import torch
8
+ from setuptools import setup
9
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
10
+ import subprocess
11
+
12
+ def get_last_arch_torch():
13
+ arch = torch.cuda.get_arch_list()[-1]
14
+ print(f"Found arch: {arch} from existing torch installation")
15
+ return arch
16
+
17
+ def get_cuda_bare_metal_version(cuda_dir):
18
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
19
+ output = raw_output.split()
20
+ release_idx = output.index("release") + 1
21
+ release = output[release_idx].split(".")
22
+ bare_metal_major = release[0]
23
+ bare_metal_minor = release[1][0]
24
+ return raw_output, bare_metal_major, bare_metal_minor
25
+
26
+ def append_nvcc_threads(nvcc_extra_args):
27
+ _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
28
+ if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
29
+ return nvcc_extra_args + ["--threads", "4"]
30
+ return nvcc_extra_args
31
+
32
+ arch = get_last_arch_torch()
33
+ sm_num = arch[-2:]
34
+ cc_flag = ['--generate-code=arch=compute_90,code=compute_90'] # for H100
35
+ # cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
36
+ # cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
37
+ # cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
38
+ # cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
39
+
40
+ setup(
41
+ name='causal_attention_cuda_cpp',
42
+ ext_modules=[
43
+ CUDAExtension('causal_attention_cuda', [
44
+ # 'causal_attention.cpp',
45
+ 'causal_attention_cuda.cu',
46
+ ],
47
+ extra_compile_args={'cxx': ['-O3'],
48
+ 'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag)
49
+ })
50
+ ],
51
+ cmdclass={
52
+ 'build_ext': BuildExtension
53
+ })
src/__init__.py ADDED
File without changes
src/dataloaders/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Load dataloaders
3
+ """
4
+ import importlib
5
+
6
+
7
+ def load_data(dataset_config: dict, dataloader_config: dict):
8
+ """Return dataloaders from dataset_config"""
9
+ try:
10
+ dataset_module = importlib.import_module(f'dataloaders.{dataset_config["name"]}')
11
+ except Exception:
12
+ try:
13
+ dataset_module = importlib.import_module(f'src.dataloaders.{dataset_config["name"]}')
14
+ except Exception as e2:
15
+ print(e2)
16
+ try: # e.g., tasks like GLUE where name is benchmark and path specifies the dataset / task
17
+ dataset_module = importlib.import_module(f'dataloaders.{dataset_config["path"]}')
18
+ except Exception as e3:
19
+ print(f'Error from {dataset_config}')
20
+ raise e3
21
+ _load_data = getattr(dataset_module, 'load_data')
22
+ return _load_data(**dataset_config, **dataloader_config)
src/dataloaders/alpaca_clean.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Alpaca training dataloaders
3
+
4
+ We adopt the original prompt template; goes something like:
5
+ ```
6
+ Below is an instruction that describes a task.
7
+ Write a response that appropriately completes the request.
8
+ ### Instruction:
9
+ {instruction}
10
+
11
+ ### Response:
12
+ {response}
13
+ ```
14
+ See `PROMPT_DICT` for more.
15
+ """
16
+ from functools import partial
17
+ from os.path import join
18
+
19
+ from datasets import load_metric, load_dataset
20
+
21
+ from .utils import (
22
+ get_lm_loader, get_seq2seq_loader,
23
+ convert_to_hf_dataset,
24
+ get_tokenizer_from_config,
25
+ download_scrolls_metric as download_metric
26
+ )
27
+ from .utils.packing import ConcatDataset
28
+
29
+
30
+ PROMPT_DICT = {
31
+ "prompt_input": (
32
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
33
+ "Write a response that appropriately completes the request.\n\n"
34
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
35
+ ),
36
+ "prompt_no_input": (
37
+ "Below is an instruction that describes a task. "
38
+ "Write a response that appropriately completes the request.\n\n"
39
+ "### Instruction:\n{instruction}\n\n### Response:\n"
40
+ ),
41
+ }
42
+
43
+
44
+ def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
45
+ preprocess_config: dict, **loader_kwargs: any):
46
+ """
47
+ Shared function to load dataset from experiment config
48
+ -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml
49
+ """
50
+ # Misc. setup
51
+ cache_dir = dataset_config['cache_dir']
52
+ input_len = dataset_config['chunk_size']
53
+ concat_data = dataset_config['concat_data']
54
+
55
+ tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
56
+ tokenizer_name = tokenizer_name.split('/')[-1]
57
+ # save_path = join(cache_dir, f'{name}_{tokenizer_name}')
58
+
59
+ # Setup tokenizer
60
+ tokenizer = get_tokenizer_from_config(pretrained_model_config)
61
+ if tokenizer.pad_token is None:
62
+ tokenizer.pad_token = tokenizer.eos_token
63
+ print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
64
+
65
+ tokenizer.padding_side = 'left' # for decoder-only generation
66
+ # Get initial data
67
+ ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs']
68
+ dataset = load_dataset(
69
+ **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}
70
+ )
71
+ if dataset_config['name'] == 'samsum': # hack
72
+ dataset = dataset.rename_column('dialogue', 'input')
73
+ dataset = dataset.rename_column('summary', 'output')
74
+ _instruction = 'Summarize this dialogue.'
75
+ for split in dataset.keys():
76
+ dataset[split] = dataset[split].add_column(
77
+ 'instruction', [_instruction] * len(dataset[split])
78
+ )
79
+ train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test']
80
+ dataset = train_set # hack to work with below code
81
+ else:
82
+ dataset = dataset['train']
83
+ train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir)
84
+ val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
85
+ test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
86
+
87
+ # Convert to dicts of {input_ids, attention_mask, labels}
88
+ train_set = train_set.map(
89
+ partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
90
+ remove_columns=list(dataset.features),) # load_from_cache_file=False)
91
+ val_set = val_set.map(
92
+ partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
93
+ remove_columns=list(dataset.features),) # load_from_cache_file=False)
94
+ test_set = test_set.map(
95
+ partial(template_and_tokenize, tokenizer=tokenizer, include_label=False),
96
+ remove_columns=list(dataset.features),) # load_from_cache_file=False)
97
+
98
+ # Chunk together train and val sets
99
+ if concat_data:
100
+ train_set = ConcatDataset(train_set, chunk_size=input_len)
101
+ val_set = ConcatDataset(val_set, chunk_size=input_len)
102
+
103
+ # Get dataloaders
104
+ dataloaders = {
105
+ 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
106
+ 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
107
+ 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
108
+ }
109
+ # Evaluation metric
110
+ try:
111
+ metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
112
+ except Exception as e:
113
+ print(f'Error loading metric: {e}')
114
+ metric = None
115
+
116
+ # Finishing touches
117
+ for k, v in dataloaders.items(): # Make tokenizer accessible
118
+ dataloaders[k].dataset.tokenizer = tokenizer
119
+ dataloaders[k].dataset.metric = metric
120
+ return dataloaders
121
+
122
+
123
+ def template_and_tokenize(sample, tokenizer, include_label: bool = True):
124
+ """
125
+ Format dataset context and answers into single-sequence prompts
126
+ """
127
+ if sample.get('input', '') == '':
128
+ prompt = PROMPT_DICT["prompt_no_input"].format_map(sample)
129
+ else:
130
+ prompt = PROMPT_DICT["prompt_input"].format_map(sample)
131
+
132
+ prompt = tokenizer.encode(prompt, add_special_tokens=True)
133
+ if include_label:
134
+ answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
135
+ add_special_tokens=False)
136
+ target = None
137
+ else:
138
+ answer = []
139
+ target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
140
+ add_special_tokens=False)
141
+ input_ids = prompt + answer
142
+ attn_mask = [1] * len(input_ids)
143
+
144
+ sample = {
145
+ "input_ids": input_ids,
146
+ "attention_mask" : attn_mask,
147
+ "labels": [-100] * len(prompt) + answer if include_label else target,
148
+ }
149
+ return sample
src/dataloaders/alpaca_clean_instruct.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Alpaca Clean dataset with Llama3-Instruct prompt formatting
3
+ """
4
+
5
+ from functools import partial
6
+ from os.path import join
7
+
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+ import torch
12
+ from torch.utils.data import Dataset, DataLoader
13
+
14
+ from datasets import load_metric, load_dataset
15
+ from transformers import AutoTokenizer
16
+ from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding
17
+
18
+ from .utils import (
19
+ get_lm_loader, get_seq2seq_loader,
20
+ convert_to_hf_dataset,
21
+ get_tokenizer_from_config,
22
+ download_scrolls_metric as download_metric
23
+ )
24
+ from .utils.packing import ConcatDataset
25
+
26
+
27
+ SYSTEM_PROMPT = "You are a helpful AI assistant who always responds to appropriately complete a user's request."
28
+
29
+
30
+ def encode_response(response: str, tokenizer) -> list[int]:
31
+ tokens = tokenizer.encode(response.strip(), add_special_tokens=False)
32
+ # For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
33
+ tokens.append(tokenizer.eos_token_id)
34
+ try: # Llama 3 Instruct
35
+ tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
36
+ except KeyError:
37
+ pass
38
+ return tokens
39
+
40
+
41
+ def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
42
+ preprocess_config: dict, **loader_kwargs: any):
43
+
44
+ # Misc. setup
45
+ cache_dir = dataset_config['cache_dir']
46
+ input_len = dataset_config['chunk_size']
47
+ concat_data = dataset_config['concat_data']
48
+ load_from_cache_file = False # False if want to retokenize dataset
49
+
50
+ # Hard-code system prompt handling
51
+ if 'istral' in pretrained_model_config['pretrained_model_name_or_path']:
52
+ system_prompt = ''
53
+ else:
54
+ system_prompt = SYSTEM_PROMPT
55
+
56
+ tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
57
+ tokenizer_name = tokenizer_name.split('/')[-1]
58
+ save_path = join(cache_dir, f'{name}_{tokenizer_name}')
59
+
60
+ # Setup tokenizer
61
+ tokenizer = get_tokenizer_from_config(pretrained_model_config)
62
+ if tokenizer.pad_token is None:
63
+ tokenizer.pad_token = tokenizer.eos_token
64
+ print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
65
+
66
+ tokenizer.padding_side = 'left' # for decoder-only generation
67
+
68
+ # Get initial data
69
+ ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs', 'system_prompt', 'name']
70
+ train_set = load_dataset(
71
+ **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
72
+ split='train[100:-100]',
73
+ )
74
+ val_set = load_dataset( # we just use this dataset as a validation set
75
+ **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
76
+ split='train[:100]+train[-100:]',
77
+ )
78
+ test_set = load_dataset(
79
+ **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
80
+ split='train[:100]+train[-100:]',
81
+ )
82
+
83
+ # Convert to dicts of {input_ids, attention_mask, labels}
84
+ train_set = train_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
85
+ include_label=True, system_prompt=system_prompt),
86
+ remove_columns=list(train_set.features),
87
+ load_from_cache_file=load_from_cache_file)
88
+ val_set = val_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
89
+ include_label=True, system_prompt=system_prompt),
90
+ remove_columns=list(val_set.features),
91
+ load_from_cache_file=load_from_cache_file)
92
+ test_set = test_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
93
+ include_label=False, system_prompt=system_prompt),
94
+ remove_columns=list(test_set.features),
95
+ load_from_cache_file=load_from_cache_file)
96
+
97
+ # Chunk together train and val sets
98
+ if concat_data:
99
+ train_set = ConcatDataset(train_set, chunk_size=input_len)
100
+ val_set = ConcatDataset(val_set, chunk_size=input_len)
101
+
102
+ # Get dataloaders
103
+ dataloaders = {
104
+ 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
105
+ 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
106
+ 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
107
+ }
108
+ # Evaluation metric
109
+ metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
110
+
111
+ # Finishing touches
112
+ for k, v in dataloaders.items(): # Make tokenizer accessible
113
+ dataloaders[k].dataset.tokenizer = tokenizer
114
+ dataloaders[k].dataset.metric = metric
115
+ return dataloaders
116
+
117
+
118
+ def template_and_tokenize(sample, tokenizer, include_label: bool = True,
119
+ system_prompt: str = None):
120
+ if system_prompt is None:
121
+ system_prompt = SYSTEM_PROMPT
122
+
123
+ prompt = sample['instruction']
124
+ if sample['input'] != '':
125
+ prompt += f"\n\n{sample['input']}"
126
+
127
+ messages = [
128
+ {"role": "system", "content": system_prompt},
129
+ ] if system_prompt != '' else []
130
+ messages.append({"role": "user", "content": prompt})
131
+ prompt_ids = tokenizer.apply_chat_template(
132
+ messages, tokenize=True, add_generation_prompt=True,
133
+ )
134
+ if include_label:
135
+ answer = encode_response(sample['output'], tokenizer)
136
+ else:
137
+ answer = []
138
+ target = encode_response(sample['output'], tokenizer)
139
+
140
+ input_ids = prompt_ids + answer
141
+ attn_mask = [1] * len(input_ids)
142
+ sample = {
143
+ "input_ids": input_ids,
144
+ "attention_mask" : attn_mask,
145
+ "labels": [-100] * len(prompt_ids) + answer if include_label else target,
146
+ }
147
+ return sample
148
+
src/dataloaders/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Helper functions dataset setup and loading
3
+ """
4
+ from .setup import *
src/dataloaders/utils/llama3.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data utils for Llama3
3
+ """
4
+
5
+ def encode_header(message: str, tokenizer) -> list[int]:
6
+ tokens = []
7
+ tokens.append(tokenizer.get_added_vocab()["<|start_header_id|>"])
8
+ tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
9
+ tokens.append(tokenizer.get_added_vocab()["<|end_header_id|>"])
10
+ tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
11
+ return tokens
12
+
13
+
14
+ def encode_message(message: str, tokenizer, include_header: bool = True) -> list[int]:
15
+ tokens = encode_header(message, tokenizer) if include_header else []
16
+ tokens.extend(
17
+ tokenizer.encode(message["content"].strip(), add_special_tokens=False)
18
+ )
19
+ tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
20
+ return tokens
21
+
22
+
23
+ def template_and_tokenize(sample, tokenizer, include_label: bool = True,
24
+ system_prompt: str = None):
25
+ if system_prompt is not None:
26
+ dialog = [{'role': 'system', 'content': system_prompt}]
27
+ else:
28
+ dialog = []
29
+
30
+ chat = []
31
+ instruction = sample['instruction']
32
+ if sample['input'] != '':
33
+ instruction += f"\n\n{sample['input']}"
34
+ dialog.extend([
35
+ {'role': 'user', 'content': instruction},
36
+ {'role': 'assistant', 'content': sample['output']},
37
+ ])
38
+
39
+ prompt = []
40
+ prompt.append(tokenizer.get_added_vocab()["<|begin_of_text|>"])
41
+ for message in dialog[:-1]:
42
+ prompt.extend(encode_message(message, tokenizer))
43
+
44
+ if include_label:
45
+ answer = encode_message(dialog[-1], tokenizer)
46
+ answer.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
47
+ else:
48
+ answer = []
49
+ target = encode_message(dialog[-1], tokenizer, include_header=False)
50
+ target.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
51
+ # Add the start of an assistant message for the model to complete.
52
+ prompt.extend(encode_header({"role": "assistant", "content": ""}, tokenizer))
53
+
54
+ input_ids = prompt + answer
55
+ attn_mask = [1] * len(input_ids)
56
+
57
+ sample = {
58
+ "input_ids": input_ids,
59
+ "attention_mask" : attn_mask,
60
+ "labels": [-100] * len(prompt) + answer if include_label else target,
61
+ }
62
+ return sample
src/dataloaders/utils/packing.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+ """
4
+ Copied from https://github.com/meta-llama/llama-recipes/blob/9b3dabcaac78980eae40005bbc8b1a8276c82af3/src/llama_recipes/data/concatenator.py#L1
5
+ """
6
+ import random
7
+ from itertools import chain
8
+ from tqdm import tqdm
9
+
10
+
11
+ from torch.utils.data import Dataset
12
+
13
+
14
+ class Concatenator(object):
15
+ def __init__(self, chunk_size=2048):
16
+ self.chunk_size=chunk_size
17
+ self.residual = {"input_ids": [], "attention_mask": []}
18
+
19
+ def __call__(self, batch):
20
+ concatenated_samples = {
21
+ k: v + list(chain(*batch[k])) for k, v in self.residual.items()
22
+ }
23
+
24
+ total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
25
+
26
+ if total_length >= self.chunk_size:
27
+ chunk_num = total_length // self.chunk_size
28
+ result = {
29
+ k: [
30
+ v[i : i + self.chunk_size]
31
+ for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
32
+ ]
33
+ for k, v in concatenated_samples.items()
34
+ }
35
+ self.residual = {
36
+ k: v[(chunk_num * self.chunk_size) :]
37
+ for k, v in concatenated_samples.items()
38
+ }
39
+ else:
40
+ result = concatenated_samples
41
+ self.residual = {k: [] for k in concatenated_samples.keys()}
42
+
43
+ result["labels"] = result["input_ids"].copy()
44
+
45
+ return result
46
+
47
+ class ConcatDataset(Dataset):
48
+ """
49
+ Concatenates or packs samples of a dataset into chunks of size `chunk_size`
50
+ """
51
+ def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None:
52
+ self.dataset = dataset
53
+ self.chunk_size = chunk_size
54
+ self.samples = []
55
+ buffer = {
56
+ "input_ids": [],
57
+ "attention_mask": [],
58
+ "labels": [],
59
+ }
60
+ random.seed(seed)
61
+ for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
62
+ buffer = {k: v + sample[k] for k,v in buffer.items()}
63
+
64
+ while len(next(iter(buffer.values()))) > self.chunk_size:
65
+ self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
66
+ buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
67
+ # Slow hack, but filter out any samples without valid labels (all -100)
68
+ self.filtered_samples = []
69
+ for s in self.samples:
70
+ if sum(s['labels']) != chunk_size * -100:
71
+ self.filtered_samples.append(s)
72
+ if len(self.filtered_samples) < len(self.samples):
73
+ print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}')
74
+ print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples')
75
+
76
+ def __getitem__(self, idx):
77
+ return self.filtered_samples[idx]
78
+
79
+ def __len__(self):
80
+ return len(self.filtered_samples)
src/dataloaders/utils/setup.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helper functions dataset setup and loading
3
+ """
4
+ import os
5
+ from os.path import join
6
+ import shutil
7
+ import numpy as np
8
+
9
+ from torch.utils.data import Dataset, DataLoader
10
+
11
+ from datasets import Dataset as HFDataset
12
+ from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, LlamaTokenizer
14
+ from transformers import DataCollatorForSeq2Seq
15
+ # from transformers import DefaultDataCollator, DataCollatorWithPadding
16
+
17
+
18
+ def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer,
19
+ split: str, **loader_kwargs: any):
20
+ """
21
+ Get dataloader for seq2seq tasks (evaluation)
22
+ """
23
+ tokenizer.padding_side = 'right'
24
+ collate_fn = DataCollatorForSeq2Seq(
25
+ tokenizer, label_pad_token_id=-100, return_tensors='pt')
26
+ return DataLoader(
27
+ dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
28
+
29
+
30
+ def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer,
31
+ split: str, max_length: int = None, **loader_kwargs: any):
32
+ """
33
+ Get dataloader for language modeling (training)
34
+ -> Currently this ends up being the same as get_seq2seq_loader
35
+ """
36
+ # collate_fn = DefaultDataCollator(return_tensors='pt')
37
+ # collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,
38
+ # max_length=max_length, return_tensors='pt')
39
+ collate_fn = DataCollatorForSeq2Seq(
40
+ tokenizer, label_pad_token_id=-100, return_tensors='pt')
41
+ return DataLoader(
42
+ dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
43
+
44
+
45
+ def convert_to_hf_dataset(dataset, cache_dir: str):
46
+ """
47
+ Convert iterable dataset to HuggingFace HFDataset object
48
+ """
49
+ def gen():
50
+ for _, sample in enumerate(dataset):
51
+ yield sample # dataset[idx]
52
+ return HFDataset.from_generator(gen, cache_dir=cache_dir)
53
+
54
+
55
+ def get_tokenizer_from_config(model_config):
56
+ """
57
+ Get pretrained tokenizer based on (pretrained) model config
58
+ """
59
+ # Get tokenizer
60
+ if 'llama' in model_config['pretrained_model_name_or_path']:
61
+ try: # if we store locally
62
+ model_path = join(model_config['cache_dir'],
63
+ model_config['pretrained_model_name_or_path'])
64
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
65
+ except Exception as e:
66
+ try:
67
+ tokenizer = AutoTokenizer.from_pretrained(**model_config)
68
+ print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e)
69
+ print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)")
70
+ except Exception as e2:
71
+ print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2)
72
+ # tokenizer = LlamaTokenizer.from_pretrained(**model_config) # v4.43 errors with `*** TypeError: not a string`
73
+ elif 'Mistral-7B-Instruct-v0.3' in model_config['pretrained_model_name_or_path']:
74
+ tokenizer = LlamaTokenizer.from_pretrained(**model_config) # hack where AutoTokenizer doesn't recognize
75
+ elif 'Mistral-7B' in model_config['pretrained_model_name_or_path']:
76
+ tokenizer = AutoTokenizer.from_pretrained(**model_config)
77
+ else:
78
+ tokenizer = AutoTokenizer.from_pretrained(**model_config)
79
+ return tokenizer
80
+
81
+
82
+ def add_special_tokens_to_dataset(dataset, tokenizer):
83
+ """
84
+ Add special tokens as attributes to a dataset object
85
+ """
86
+ token_map = {k: v for k, v in tokenizer.special_tokens_map.items()}
87
+ special_ids = tokenizer.all_special_ids
88
+ for idx, k in enumerate(tokenizer.special_tokens_map.keys()):
89
+ token_map[f'{k}_id'] = special_ids[idx]
90
+ for k, v in token_map.items():
91
+ setattr(dataset, k, v)
92
+ return dataset
93
+
94
+
95
+ def train_test_split(samples: any, train_size: int, test_size: int, seed: int):
96
+ """
97
+ Split samples into train and test sets
98
+ """
99
+ try:
100
+ assert len(samples) == train_size + test_size
101
+ except Exception as e:
102
+ print(len(samples), train_size + test_size)
103
+ raise e
104
+ arange = np.arange(len(samples))
105
+ np.random.seed(seed)
106
+ test_idx = np.random.choice(arange, size=test_size, replace=False)
107
+ train_idx = np.setdiff1d(arange, test_idx)
108
+ return samples[train_idx], samples[test_idx]
109
+
110
+
111
+ def download_scrolls_metric():
112
+ """
113
+ Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset
114
+ """
115
+ scrolls_metric_path = hf_hub_download(
116
+ repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset"
117
+ )
118
+ updated_scrolls_metric_path = (
119
+ os.path.dirname(scrolls_metric_path) +
120
+ os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
121
+ )
122
+ shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
123
+ return updated_scrolls_metric_path
src/finetune.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Finetuning functions to do post-distillation
3
+ """
4
+ from os.path import join
5
+ from omegaconf import OmegaConf
6
+
7
+ import torch
8
+ from torch.nn import Module
9
+
10
+ from src.utils.setup import update_config_from_args
11
+ from src.dataloaders import load_data
12
+ from src.trainer import get_trainer, get_optimizer, get_scheduler
13
+
14
+
15
+ def prepare_finetune_configs(args, model_config: dict,
16
+ finetune_config_name: str = None,
17
+ finetune_checkpoint_name: str = None,
18
+ config_dir='./configs/experiment'):
19
+ """
20
+ Prepare finetuning configs
21
+ """
22
+ # Load finetuning config
23
+ finetune_config = (finetune_config_name if finetune_config_name is not None else
24
+ finetune_checkpoint_name.split('-f=')[-1].split('-')[0])
25
+ finetune_config_path = join(config_dir, f'{finetune_config}.yaml')
26
+ finetune_config = OmegaConf.load(finetune_config_path)
27
+ finetune_config = update_config_from_args(finetune_config, args,
28
+ ignore_args=['lr', 'weight_decay'])
29
+ # Update data tokenizer to match model
30
+ if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None:
31
+ for k in ['pretrained_model_name_or_path', 'cache_dir']:
32
+ finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k]
33
+ # Set finetuning args
34
+ for arg, argv in finetune_config.trainer.items():
35
+ if arg != 'name':
36
+ setattr(args, arg, argv)
37
+ for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
38
+ setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config)))
39
+ return finetune_config, args
40
+
41
+
42
+ def get_finetuner(model: Module, finetune_config: dict, device: torch.device,
43
+ args: any, wandb: any, initial_eval: bool = False):
44
+ """
45
+ Initialize finetuning trainer
46
+ """
47
+ model.to(device) # if using a fused optimizer
48
+ model.train()
49
+
50
+ # Initialize optimizer and scheduler
51
+ optimizer = get_optimizer(model=model, **finetune_config.optimizer)
52
+ scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler)
53
+
54
+ dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader)
55
+ train_loader = dataloaders[finetune_config.trainer.train_split]
56
+ eval_loader = dataloaders[finetune_config.trainer.val_split]
57
+
58
+ OurTrainer = get_trainer(finetune_config.trainer.name)
59
+ trainer = OurTrainer(model=model,
60
+ args=args,
61
+ train_loader=train_loader,
62
+ eval_loader=eval_loader,
63
+ optimizer_and_scheduler=(optimizer, scheduler),
64
+ device=device,
65
+ wandb=wandb,
66
+ checkpoint_suffix='_ft',
67
+ **finetune_config.trainer)
68
+ return trainer
src/model/__init__.py ADDED
File without changes
src/model/convert_model.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Attention conversion helpers
3
+ """
4
+ from functools import partial
5
+ from tqdm import tqdm
6
+ import torch.nn as nn
7
+
8
+
9
+ def convert_attention(model: nn.Module,
10
+ attention_config: dict,
11
+ train_attention: bool = False,
12
+ remove_base_attn: bool = True,):
13
+ """
14
+ Call to convert all attention layers
15
+ """
16
+ softmax_attns = []
17
+ if 'softmax_attentions' in attention_config:
18
+ softmax_attns = attention_config['softmax_attentions']
19
+ if attention_config.attention_type != 'softmax':
20
+ layers = traverse_layers(model)
21
+ for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')):
22
+ if layer_idx not in softmax_attns:
23
+ layer.self_attn = convert_llama_attention(
24
+ layer, attention_config, layers, train_attention, remove_base_attn,
25
+ )
26
+ layer.self_attn.converted = True
27
+ else: # Freeze any preserved softmax attention layers
28
+ for p in layer.parameters():
29
+ p.requires_grad = False
30
+ else:
31
+ print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions')
32
+ return model
33
+
34
+
35
+ def toggle_attention(llama_model: nn.Module, train: bool = False):
36
+ """
37
+ Make attentions trainable if train is True
38
+ -> Set train_attention = False when finetuning
39
+ """
40
+ for layer in traverse_layers(llama_model):
41
+ layer.self_attn.train_attention = train
42
+ return llama_model
43
+
44
+
45
+ def remove_base_attention(llama_model: nn.Module):
46
+ """
47
+ Remove teacher attention after distillation (if we keep it)
48
+ """
49
+ for layer in traverse_layers(llama_model):
50
+ if getattr(layer.self_attn, 'base_attn', False):
51
+ del layer.self_attn.base_attn
52
+ return llama_model
53
+
54
+
55
+ def traverse_layers(model: nn.Module, verbose: bool = False):
56
+ """
57
+ Return list of model layers
58
+ """
59
+ try:
60
+ layers = model.model.layers
61
+ if verbose:
62
+ print('-> Loading from model.model.layers')
63
+ except AttributeError as e: # if base model
64
+ if verbose:
65
+ print(e)
66
+ try:
67
+ layers = model.layers
68
+ if verbose:
69
+ print('-> Loading from model.layers')
70
+ except AttributeError as e1: # If we make a PEFT model
71
+ if verbose:
72
+ print(e1)
73
+ layers = model.base_model.model.model.layers
74
+ if verbose:
75
+ print('-> Loading from model.base_model.model.model.layers')
76
+ return layers
77
+
78
+
79
+ def convert_llama_attention(layer: nn.Module,
80
+ attention_config: dict,
81
+ layers: list[nn.Module], # list of layers
82
+ train_attention: bool = False,
83
+ remove_base_attn: bool = True):
84
+ """
85
+ Converts a single layer's attention layer as specified by attention_config
86
+ """
87
+ return get_attention(**attention_config)(
88
+ base_attn=layer.self_attn,
89
+ layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
90
+ max_layer_idx=len(layers) - 1,
91
+ train_attention=train_attention,
92
+ remove_base_attn=remove_base_attn,
93
+ )
94
+
95
+
96
+ def get_attention(attention_type: str, **kwargs: any):
97
+ """
98
+ Get the linear attention class; either purely linear or linear with sliding window
99
+ -> 'linear' == 'lolcats_llama'
100
+ -> 'linear and sliding_window' == 'lolcats_llama_window_*'
101
+ """
102
+ kwargs['attention_type'] = attention_type
103
+
104
+ if attention_type == 'lolcats_llama':
105
+ from .linear_attention import LolcatsLinearAttention
106
+ return partial(LolcatsLinearAttention, **kwargs)
107
+
108
+ elif attention_type == 'lolcats_llama_window_tk':
109
+ from .linear_attention import LolcatsTKWindowAttention
110
+ return partial(LolcatsTKWindowAttention, **kwargs)
111
+
112
+ elif attention_type == 'lolcats_llama_window_sw':
113
+ from .linear_attention import LolcatsSlidingWindowAttention
114
+ return partial(LolcatsSlidingWindowAttention, **kwargs)
115
+
116
+ elif attention_type == 'lolcats_llama_window_sw_linear':
117
+ from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
118
+ return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
119
+
120
+ ## Experimental chunked linear attentions below
121
+ elif attention_type == 'lolcats_long_llama_window_tk':
122
+ from .linear_attention import LolcatsTKWindowLongAttention
123
+ return partial(LolcatsTKWindowLongAttention, **kwargs)
124
+
125
+ elif attention_type == 'lolcats_long_llama_window_sw':
126
+ from .linear_attention import LolcatsSlidingWindowLongAttention
127
+ return partial(LolcatsSlidingWindowLongAttention, **kwargs)
128
+
129
+ ## TK generation build (requires Thunderkittens)
130
+ elif attention_type == 'lolcats_llama_window_tk_gen':
131
+ from .linear_attention import LolcatsWindowAttentionTKGen
132
+ return partial(LolcatsWindowAttentionTKGen, **kwargs)
133
+
134
+ else:
135
+ print(f'-> attention_type {attention_type} not handled... returning None')
136
+ return None
137
+
138
+
139
+ def get_attention_cache(attention_type: str, past_key_values: any = None):
140
+ """
141
+ Determine how we store past keys and values when generating
142
+ """
143
+ if attention_type is None:
144
+ return past_key_values
145
+
146
+ # print(f'Returning attention cache based on attention_type == {attention_type}')
147
+ elif 'lolcats_llama_window_tk_gen' in attention_type:
148
+ from .linear_attention import LinearAttentionTKWindowGenerationCache
149
+ return LinearAttentionTKWindowGenerationCache()
150
+
151
+ elif 'llama_window_tk' in attention_type:
152
+ from .linear_attention import LinearAttentionTKWindowCache
153
+ return LinearAttentionTKWindowCache()
154
+
155
+ elif 'llama_window_sw' in attention_type:
156
+ from .linear_attention import LinearAttentionSlidingWindowCache
157
+ return LinearAttentionSlidingWindowCache()
158
+
159
+ elif 'llama_window_sw_linear' in attention_type:
160
+ from .linear_attention import LinearAttentionSlidingWindowCache
161
+ return LinearAttentionSlidingWindowCache()
162
+
163
+ ## TK generation build (requires Thunderkittens)
164
+ elif attention_type == 'lolcats_llama_window_tk_gen':
165
+ from .linear_attention.linear_window_attention_tk_gen import LinearAttentionTKWindowGenerationCache
166
+ return LinearAttentionTKWindowGenerationCache()
167
+
168
+ elif 'softmax' in attention_type:
169
+ return past_key_values
170
+
171
+ else:
172
+ from .linear_attention import LinearAttentionState
173
+ return LinearAttentionState()
src/model/feature_map.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Learnable linear attention feature map classes and functions
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def init_feature_map(name: str, mlp: nn.Module, **kwargs: dict):
10
+ """
11
+ Initialize feature map final activation for linear attention
12
+ """
13
+ return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
14
+
15
+
16
+ def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
17
+ """
18
+ Initialize feature map final activation for linear attention
19
+ """
20
+ if name == 'softmax_dim' and fullspace:
21
+ return SoftmaxDim(**kwargs)
22
+ elif name == 'softmax_dim' and not fullspace:
23
+ return SoftmaxDimHalfspace(**kwargs)
24
+ elif name == 'exp_dim' and fullspace:
25
+ return Exp(**kwargs)
26
+ elif name == 'exp_dim' and not fullspace:
27
+ return ExpHalfspace(**kwargs)
28
+ elif name == 'pos_elu':
29
+ return PosELU(**kwargs)
30
+ elif name == 'relu':
31
+ return ReLU(**kwargs)
32
+
33
+ else:
34
+ raise NotImplementedError
35
+
36
+
37
+ def init_learned_kernel(name: str, **kwargs: any):
38
+ """
39
+ Initialize feature map MLP for linear attention
40
+ """
41
+ if name == 'untied_head_einsum':
42
+ return FeatureMapMLP(**kwargs)
43
+ elif name == 'untied_head_adapter':
44
+ return FeatureMapAdapter(**kwargs)
45
+ else:
46
+ raise NotImplementedError
47
+
48
+
49
+ class FeatureMap(nn.Module):
50
+ """
51
+ Final 'activation' of feature map. Can probably be combined with
52
+ `FeatureMapMLP` below
53
+
54
+ Full feature map is like f(xW + b)
55
+ -> This is the `f` part
56
+ """
57
+ def __init__(self,
58
+ activation_name: str,
59
+ head_dim_idx: int = -1,
60
+ eps: float = 1e-12,
61
+ mlp: nn.Module = None,
62
+ fullspace: bool = True,):
63
+ super().__init__()
64
+ self.head_dim_idx = head_dim_idx
65
+ self.eps = eps
66
+ self.mlp = mlp if mlp is not None else nn.Identity()
67
+ self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
68
+
69
+ def forward(self, x: torch.Tensor, *mlp_args: any, **mlp_kwargs: any):
70
+ """
71
+ Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
72
+ """
73
+ return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
74
+
75
+ def q_map(self, *args: any, **kwargs: any):
76
+ """
77
+ Use for inference in case q and k feature maps differ
78
+ """
79
+ return self.forward(*args, **kwargs)
80
+
81
+ def k_map(self, *args: any, **kwargs: any):
82
+ """
83
+ Use for inference in case q and k feature maps differ
84
+ """
85
+ return self.forward(*args, **kwargs)
86
+
87
+
88
+ # -----------------------
89
+ # Feature map activations
90
+ # -----------------------
91
+ class FeatureMapAct(nn.Module):
92
+ """
93
+ Base class for feature map activations
94
+ """
95
+ def __init__(self, eps: float = 1e-12):
96
+ super().__init__()
97
+ self.eps = eps
98
+
99
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
100
+ """
101
+ x.shape is (batch_size, n_heads, seq_len, head_dim)
102
+ """
103
+ return x
104
+
105
+
106
+ class PosELU(FeatureMapAct):
107
+ """
108
+ 1 + ELU activation as in https://arxiv.org/abs/2006.16236
109
+ """
110
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
111
+ return (1 + F.elu(x)).clamp(min=self.eps)
112
+
113
+
114
+ class ReLU(FeatureMapAct):
115
+ """
116
+ ReLU activation as in https://arxiv.org/abs/2103.13076
117
+ """
118
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
119
+ return F.relu(x).clamp(min=self.eps)
120
+
121
+
122
+ class SoftmaxDim(FeatureMapAct):
123
+ """
124
+ Softmax activation as in https://arxiv.org/abs/2402.04347
125
+ """
126
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
127
+ return torch.cat([
128
+ torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)
129
+ ], dim=-1).clamp(min=self.eps)
130
+
131
+
132
+ class SoftmaxDimHalfspace(FeatureMapAct):
133
+ """
134
+ Softmax activation as in https://arxiv.org/abs/2402.04347
135
+ """
136
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
137
+ return torch.softmax(x, dim=-1).clamp(min=self.eps)
138
+
139
+
140
+ class Exp(FeatureMapAct):
141
+ """
142
+ Exp activation as in https://arxiv.org/abs/2402.04347
143
+ """
144
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
145
+ x_max = torch.amax(x, dim=-1, keepdim=True)
146
+ x_min = torch.amin(x, dim=-1, keepdim=True)
147
+ return torch.cat([
148
+ torch.exp(x - x_max), torch.exp(-x + x_min)
149
+ ], dim=-1).clamp(min=self.eps)
150
+
151
+
152
+ class ExpHalfspace(FeatureMapAct):
153
+ """
154
+ Exp activation as in https://arxiv.org/abs/2402.04347
155
+ """
156
+ def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
157
+ x_max = torch.amax(x, dim=-1, keepdim=True)
158
+ return torch.exp(x - x_max).clamp(min=self.eps)
159
+
160
+
161
+ # ----------------
162
+ # Feature map MLPs
163
+ # ----------------
164
+
165
+ class FeatureMapMLP(nn.Module):
166
+ """
167
+ Learnable MLP in feature map.
168
+
169
+ Full feature map is like f(xW + b)
170
+ -> This is the `W` and (optional) `b` part
171
+ """
172
+ def __init__(self,
173
+ num_heads: int,
174
+ head_dim: int, # input dim
175
+ feature_dim: int, # output dim
176
+ dtype: torch.dtype,
177
+ device: torch.device,
178
+ skip_connection: bool = False,
179
+ bias: bool = False,
180
+ zero_init: bool = False,
181
+ normal_init: bool = False,):
182
+ super().__init__()
183
+ self.num_heads = num_heads
184
+ self.head_dim = head_dim
185
+ self.feature_dim = feature_dim
186
+ self.dtype = dtype
187
+ self.device = device
188
+ self.skip_connection = skip_connection
189
+ self.bias = bias
190
+ self.zero_init = zero_init
191
+ self.normal_init = normal_init
192
+ self.init_weights_()
193
+
194
+ if self.zero_init: # Zero-out weights or set as identity post-initialization
195
+ self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
196
+
197
+ if self.normal_init:
198
+ with torch.no_grad():
199
+ nn.init.normal_(self.layer)
200
+
201
+ if self.skip_connection:
202
+ assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}'
203
+ assert self.head_dim == self.feature_dim, assertion_fail
204
+
205
+ def init_weights_(self):
206
+ """
207
+ Initialize (W)eights and (b)iases
208
+ """
209
+ self.layer = nn.Parameter(torch.zeros(
210
+ (self.num_heads, self.head_dim, self.feature_dim),
211
+ dtype=self.dtype, device=self.device,
212
+ ))
213
+ nn.init.kaiming_uniform_(self.layer)
214
+
215
+ if self.bias:
216
+ self.bias = nn.Parameter(torch.zeros(
217
+ (1, self.num_heads, 1, 1), # self.feature_dim),
218
+ dtype=self.dtype, device=self.device,
219
+ ))
220
+ nn.init.kaiming_uniform_(self.bias)
221
+ else:
222
+ self.bias = 0. # hack
223
+
224
+ def zero_init_with_skip_(self):
225
+ """
226
+ Initialize weights to zero matrix if skip connection
227
+ """
228
+ with torch.no_grad():
229
+ nn.init.zeros_(self.layer)
230
+
231
+ def zero_init_(self):
232
+ """
233
+ Initialize weights to identity matrix if no skip connection
234
+ """
235
+ with torch.no_grad():
236
+ for i in range(self.layer.shape[0]):
237
+ try:
238
+ nn.init.eye_(self.layer[i])
239
+ except RuntimeError:
240
+ with torch.no_grad():
241
+ dtype = self.layer[i].dtype
242
+ weight = torch.eye(*self.layer[i].shape,
243
+ requires_grad=self.layer[i].requires_grad,
244
+ device=self.layer[i].device)
245
+ self.layer[i] = weight.to(dtype=dtype)
246
+
247
+ def forward(self, x: torch.Tensor):
248
+ """
249
+ Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
250
+ """
251
+ _x = torch.einsum('hdf,bhld->bhlf', self.layer, x) + self.bias
252
+ return x + _x if self.skip_connection else _x
253
+
254
+
255
+ class FeatureMapAdapter(FeatureMapMLP):
256
+ """
257
+ Learnable Feature map with bottleneck adapter
258
+ as in https://arxiv.org/abs/1902.00751
259
+
260
+ We don't use but could be fun to try
261
+ """
262
+ def __init__(self, hidden_dim: int, *args, **kwargs):
263
+ kwargs['skip_connection'] = True
264
+ kwargs['bias'] = True
265
+ kwargs['zero_init'] = True
266
+ self.hidden_dim = hidden_dim
267
+ super().__init__(*args, **kwargs)
268
+
269
+ def init_weights_(self):
270
+ """
271
+ Initialize (W)eights and (b)iases
272
+ """
273
+ kwargs = {'dtype': self.dtype, 'device': self.device}
274
+ self.layer0 = nn.Parameter(
275
+ torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
276
+ )
277
+ self.layer1 = nn.Parameter(
278
+ torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
279
+ )
280
+ nn.init.kaiming_uniform_(self.layer0)
281
+ nn.init.kaiming_uniform_(self.layer1)
282
+
283
+ self.bias0 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs))
284
+ self.bias1 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs))
285
+ nn.init.kaiming_uniform_(self.bias0)
286
+ nn.init.kaiming_uniform_(self.bias1)
287
+
288
+ def zero_init_with_skip_(self):
289
+ with torch.no_grad():
290
+ nn.init.zeros_(self.layer0)
291
+ nn.init.zeros_(self.layer1)
292
+ nn.init.zeros_(self.bias0)
293
+ nn.init.zeros_(self.bias1)
294
+
295
+ def zero_init_(self):
296
+ assert NotImplementedError
297
+
298
+ def forward(self, x: torch.Tensor):
299
+ """
300
+ Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
301
+ -> Down-project, apply nonlinearity, up-project; add skip connection
302
+ """
303
+ _x = torch.einsum('hde,bhld->bhle', self.layer0, x) + self.bias0
304
+ _x = F.relu(_x)
305
+ _x = torch.einsum('hef,bhle->bhlf', self.layer1, _x) + self.bias1
306
+ return x + _x if self.skip_connection else _x
src/model/linear_attention/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Linear and linear attention + sliding window classes
3
+ """
4
+ from .linear_attention import (
5
+ LolcatsLinearAttention, LinearAttentionState
6
+ )
7
+ from .linear_window_attention_tk import (
8
+ LolcatsTKWindowAttention, LinearAttentionTKWindowCache
9
+ )
10
+ from .linear_window_attention_sw import (
11
+ LolcatsSlidingWindowAttention, LinearAttentionSlidingWindowCache
12
+ )
13
+ # Experimental chunk linear attentions
14
+ from .linear_window_attention_tk_long import (
15
+ LolcatsTKWindowLongAttention,
16
+ )
17
+ from .linear_window_attention_sw_long import (
18
+ LolcatsSlidingWindowLongAttention,
19
+ )
20
+ from .linear_window_attention_tk_gen import (
21
+ LolcatsWindowAttentionTKGen,
22
+ LinearAttentionTKWindowGenerationCache
23
+ )
src/model/linear_attention/linear_attention.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Linear attention classes
3
+ """
4
+ from typing import List, Tuple, Optional
5
+ import copy
6
+ import torch
7
+ import torch.nn as nn
8
+ from omegaconf import OmegaConf, DictConfig
9
+
10
+ from transformers.cache_utils import Cache # starting at Transformers v4.36
11
+
12
+ # Causal linear attention dot product CUDA kernel from fast-transformers
13
+ try:
14
+ from csrc import causal_dot_product as fast_causal_dot_product
15
+ except ImportError:
16
+ fast_causal_dot_product = None
17
+
18
+ from src.model.feature_map import init_feature_map, init_learned_kernel
19
+ from src.model.rotary import get_rotary_embeddings, apply_rotary_pos_emb
20
+ from .utils import repeat_kv
21
+
22
+
23
+ # -------------------
24
+ # Attention functions
25
+ # -------------------
26
+
27
+ def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
28
+ """
29
+ Causal linear attention dot product
30
+ - If available, use CUDA kernel from fast-transformers
31
+ """
32
+ if fast_causal_dot_product is None:
33
+ kv = torch.einsum('bhlf,bhld->bhlfd', k, v)
34
+ return torch.einsum('bhlf,bhlfd->bhld', q, kv.cumsum(dim=2))
35
+ return fast_causal_dot_product(q, k, v)
36
+
37
+ def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
38
+ fp32_attention: bool = False, eps: float = 1e-12,
39
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
40
+ """
41
+ Compute linear attention with CUDA kernel implementation from fast-transformers
42
+ - https://github.com/idiap/fast-transformers
43
+ - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
44
+ v is shape (b, h, l, head_dim)
45
+ """
46
+ dtype = q.dtype
47
+ # Causal mask already applied
48
+ y = causal_dot_product(q.contiguous().to(dtype=torch.float32),
49
+ k.contiguous().to(dtype=torch.float32),
50
+ v.contiguous().to(dtype=torch.float32))
51
+ if fp32_attention:
52
+ y = (y / (torch.einsum(
53
+ "bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)
54
+ ) + eps)[..., None]).to(dtype=dtype)
55
+ else:
56
+ y = y.to(dtype=dtype)
57
+ k = k.float().cumsum(dim=2).to(dtype=dtype)
58
+ y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
59
+ return y, None, None
60
+
61
+
62
+ def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
63
+ causal: bool = True, fp32_attention: bool = True,
64
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
65
+ """
66
+ Standard softmax attention; only compute outputs if v is not None
67
+ -> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
68
+ """
69
+ y = None
70
+ a = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5)
71
+ if causal: # Apply causal mask
72
+ m, n = a.shape[-2:]
73
+ causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1)
74
+ a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
75
+ if fp32_attention:
76
+ a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
77
+ else:
78
+ a = torch.softmax(a, dim=-1)
79
+ if v is not None:
80
+ y = torch.einsum('bhmn,bhnd->bhmd', a, v)
81
+ return y, a, None
82
+
83
+
84
+ def quadratic_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor = None,
85
+ causal: bool = True, fp32_attention: bool = False, eps: float = 1e-12,
86
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
87
+ """
88
+ Compute attention with feature maps by instantiating L x L matrix of attention weights
89
+ -> Use for attention distillation
90
+ -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
91
+ """
92
+ y = None
93
+ dtype = q.dtype
94
+ if fp32_attention:
95
+ q, k = q.float(), k.float()
96
+ a = torch.einsum('bhmd,bhnd->bhmn', q, k) # note we don't scale, tho we could
97
+ if causal: # Apply causal mask
98
+ m, n = a.shape[-2:]
99
+ causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1)
100
+ a = a.masked_fill(causal_mask, 0)
101
+ # Normalize to compute attention
102
+ a = a / (a.sum(dim=-1, keepdim=True) + eps)
103
+ a = a.to(dtype=dtype) if fp32_attention else a
104
+ if torch.isnan(a).sum() > 0:
105
+ breakpoint()
106
+ if v is not None:
107
+ y = torch.einsum('bhmn,bhnd->bhmd', a, v)
108
+ return y, a, None
109
+
110
+
111
+ # ---------------------
112
+ # Attention layer class
113
+ # ---------------------
114
+
115
+ class LolcatsLinearAttention(nn.Module):
116
+ """
117
+ LoLCATs attention implementation initialized from a
118
+ `LlamaAttention` or `MistralAttention` object (base_attn)
119
+
120
+ Most of the arguments are directly tied to argparse args
121
+ - For now we don't support padding.
122
+ """
123
+ def __init__(self,
124
+ base_attn: nn.Module, # like LlamaAttention
125
+ feature_map: str,
126
+ feature_map_kwargs: dict,
127
+ layer_idx: Optional[int] = None,
128
+ max_layer_idx: Optional[int] = None,
129
+ learned_kernel: Optional[str] = None,
130
+ learned_kernel_kwargs: Optional[dict] = None,
131
+ tie_qk_kernels: Optional[bool] = False,
132
+ rotary_config: Optional[dict] = None,
133
+ train_attention: Optional[bool] = False,
134
+ remove_base_attn: Optional[bool] = True,
135
+ attention_type: Optional[str] = 'lolcats_llama',
136
+ mask_value: int = 0,
137
+ eps: float = 1e-12,
138
+ fp32_attention: bool = False,
139
+ track_state_grads: bool = False,
140
+ rank: Optional[int] = 0,
141
+ **kwargs: any) -> None:
142
+ super().__init__()
143
+ self.base_config = getattr(base_attn, 'config', None)
144
+ if self.base_config is not None:
145
+ self.base_config = self.base_config.to_dict()
146
+ self.attention_type = attention_type
147
+ self.mask_value = mask_value
148
+ self.eps = eps
149
+ self.layer_idx = (layer_idx if layer_idx is not None else base_attn.layer_idx)
150
+ self.max_layer_idx = max_layer_idx
151
+ self.tie_qk_kernels = tie_qk_kernels
152
+ self.train_attention = train_attention
153
+ self.base_inference = False
154
+ self.fp32_attention = fp32_attention
155
+ self.track_state_grads = track_state_grads
156
+ if rank == 0: # multi-gpu
157
+ if fp32_attention and layer_idx == 0:
158
+ print(f'-> fp32_attention is {fp32_attention}')
159
+ if layer_idx == 0 and feature_map_kwargs is not None:
160
+ for k, v in feature_map_kwargs.items():
161
+ print(f'-> {k}: {v}')
162
+ if layer_idx == 0 and learned_kernel_kwargs is not None:
163
+ for k, v in learned_kernel_kwargs.items():
164
+ print(f'-> {k}: {v}')
165
+
166
+ self.remove_base_attn = remove_base_attn
167
+
168
+ # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
169
+ self.rotary_config = rotary_config
170
+ if isinstance(self.rotary_config, DictConfig): # ensure dict
171
+ self.rotary_config = OmegaConf.to_container(self.rotary_config)
172
+
173
+ self.rotary_emb = None
174
+ if self.base_config is not None and self.rotary_config is None:
175
+ self.rotary_emb = base_attn.rotary_emb
176
+
177
+ self.init_weights_(base_attn, remove_base_attn)
178
+ self.init_feature_map_(feature_map, feature_map_kwargs,
179
+ learned_kernel, learned_kernel_kwargs)
180
+
181
+ def init_feature_map_(self,
182
+ feature_map: str,
183
+ feature_map_kwargs: dict,
184
+ learned_kernel: str = None,
185
+ learned_kernel_kwargs: dict = None):
186
+ """
187
+ Initialize MLP-based feature map
188
+ """
189
+ self.fmap_gqa = False # Turn True if specified below
190
+ if learned_kernel is not None:
191
+ # Ensure dict
192
+ learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
193
+ learned_kernel_kwargs['num_heads'] = self.num_heads
194
+ learned_kernel_kwargs['head_dim'] = self.head_dim
195
+ learned_kernel_kwargs['dtype'] = self.q_proj.weight.dtype
196
+ learned_kernel_kwargs['device'] = self.q_proj.weight.device
197
+ # Create MLP
198
+ mlp_learned_kernel = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)
199
+ # Add "activation"; see src.models.feature_map.py
200
+ self.feature_map_q = init_feature_map(name=feature_map,
201
+ mlp=mlp_learned_kernel,
202
+ **feature_map_kwargs)
203
+ if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
204
+ self.feature_map_k = self.feature_map_q
205
+ else:
206
+ self.feature_map_k = copy.deepcopy(self.feature_map_q)
207
+
208
+ def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
209
+ """
210
+ Initialize module layers, weights, positional dependencies, etc.
211
+ from original softmax attention layer (base_attn)
212
+ """
213
+ # Make other attributes accessible
214
+ self.attention_dropout = 0 # We don't use dropout
215
+ self.hidden_size = base_attn.hidden_size
216
+ self.num_heads = base_attn.num_heads
217
+ self.head_dim = base_attn.head_dim
218
+ self.num_key_value_heads = base_attn.num_key_value_heads
219
+ self.num_key_value_groups = base_attn.num_key_value_groups
220
+
221
+ self.q_shape = [self.num_heads, self.head_dim]
222
+ self.k_shape = [self.num_key_value_heads, self.head_dim]
223
+ self.v_shape = [self.num_key_value_heads, self.head_dim]
224
+ device = base_attn.q_proj.weight.device
225
+ # Rotary embeddings
226
+ if self.rotary_emb is None:
227
+ self.max_position_embeddings = base_attn.max_position_embeddings
228
+ scaling_factor = getattr(base_attn.rotary_emb, 'scaling_factor', 1.)
229
+ if self.rotary_config is None:
230
+ self.rotary_emb = get_rotary_embeddings(
231
+ rope_scaling_type=None,
232
+ head_dim=self.head_dim,
233
+ max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings,
234
+ rope_theta=base_attn.rotary_emb.base,
235
+ rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor,
236
+ device=device,
237
+ )
238
+ else:
239
+ if 'device' not in self.rotary_config:
240
+ self.rotary_config['device'] = device
241
+ self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
242
+
243
+ # Copy original model projection layers
244
+ self.q_proj = base_attn.q_proj
245
+ self.k_proj = base_attn.k_proj
246
+ self.v_proj = base_attn.v_proj
247
+ self.o_proj = base_attn.o_proj
248
+ try: # If wanting to use FA2 for ground-truth inference
249
+ self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask
250
+ except AttributeError:
251
+ pass
252
+
253
+ if self.remove_base_attn or remove_base_attn:
254
+ del base_attn # We don't need to keep these around
255
+ else:
256
+ self.base_attn = base_attn # For some training runs helpful to just call
257
+
258
+ def process_qkv(self,
259
+ hidden_states: torch.Tensor,
260
+ attention_mask: Optional[torch.Tensor] = None,
261
+ position_ids: Optional[torch.LongTensor] = None,
262
+ past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,): # "legacy" cache approach
263
+ """
264
+ Compute queries, keys, and values
265
+ """
266
+ b, l, _ = hidden_states.size()
267
+ q = self.q_proj(hidden_states)
268
+ k = self.k_proj(hidden_states)
269
+ v = self.v_proj(hidden_states)
270
+ kv_seq_len = k.shape[-2]
271
+
272
+ # Shape is (batch_size, seq_len, num_heads, head_dim)
273
+ q = q.view(b, l, *self.q_shape).transpose(1, 2)
274
+ k = k.view(b, l, *self.k_shape).transpose(1, 2)
275
+ v = v.view(b, l, *self.v_shape).transpose(1, 2)
276
+
277
+ if past_key_value is not None: # and k.shape[2] > q.shape[2]: # e.g., when generating
278
+ past_key_value.window_size = getattr(self, 'decode_window_size', None) # self.decode_window_size
279
+ if isinstance(past_key_value, Cache): # In Transformers v4.36+ this is a DynamicCache object
280
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
281
+ else:
282
+ kv_seq_len += past_key_value[0].shape[-2]
283
+
284
+ # Apply rotary embeddings and repeat for GQA
285
+ if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
286
+ kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
287
+ try: # As in Transformers v4.36
288
+ cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
289
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
290
+ except TypeError: # As in Transformers v4.39+
291
+ cos, sin = self.rotary_emb(v, position_ids)
292
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
293
+
294
+ k = repeat_kv(k, self.num_key_value_groups)
295
+ v = repeat_kv(v, self.num_key_value_groups)
296
+ return q, k, v, kv_seq_len
297
+
298
+ def forward(self,
299
+ hidden_states: torch.Tensor,
300
+ attention_mask: Optional[torch.Tensor] = None,
301
+ position_ids: Optional[torch.LongTensor] = None,
302
+ past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # "legacy" cache approach
303
+ output_attentions: bool = False,
304
+ use_cache: bool = False,
305
+ **kwargs,
306
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
307
+ """
308
+ Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
309
+ - Consistent with HuggingFace Transformers for easy use with their pretrained models
310
+ """
311
+ b, l, _ = hidden_states.size()
312
+ q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
313
+ position_ids, past_key_value)
314
+ if self.base_inference:
315
+ with torch.no_grad():
316
+ # 1. Compute "ground-truth" attention output and weights
317
+ y_true, _, _ = softmax_attention(q, k, v, causal=True)
318
+ y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
319
+ y_true = self.o_proj(y_true)
320
+ attn_weights = (None, None)
321
+
322
+ elif self.train_attention: # Distilling / learning attentions
323
+ # Note for now we assume no padding when distilling; attention masks only enforce causality
324
+ assert output_attentions is True, f'When training feature maps, output_attentions should be True but is {output_attentions}'
325
+ with torch.no_grad():
326
+ # 1. Compute "ground-truth" attention output and weights
327
+ _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
328
+ y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
329
+ y_true = self.o_proj(y_true)
330
+
331
+ # 2. Compute "predicted" attention (just weights)
332
+ q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
333
+ y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
334
+ attn_weights = ((attn_pred, attn_true), (y_pred, _y_true)) # Save both attention weights so we can supervise.
335
+
336
+ else: # Finetuning
337
+ q, k = self.feature_map_q(q), self.feature_map_k(k)
338
+ # Apply prefill mask
339
+ if attention_mask is not None and q.shape[2] > 1:
340
+ if len(attention_mask.shape) == 4:
341
+ lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None] # b, 1, k_len, 1
342
+ else:
343
+ lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
344
+ k = k.masked_fill(~lin_attn_mask, 0)
345
+
346
+ if past_key_value is not None: # Initialize states
347
+ if len(past_key_value.kv_states) == self.layer_idx:
348
+ b, h, _, f = k.shape
349
+ past_key_value.kv_states.append(
350
+ torch.zeros(b, h, f, self.head_dim, dtype=q.dtype, device=q.device)
351
+ )
352
+ past_key_value.k_states.append(
353
+ torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
354
+ )
355
+ # Generating
356
+ if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
357
+ assert use_cache is True
358
+ kv_state, k_state = past_key_value.update(k, v, self.layer_idx,
359
+ accumulate_in_fp32=self.fp32_attention)
360
+ if self.fp32_attention:
361
+ q = q.float()
362
+ y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state.float()) /
363
+ torch.einsum('bhlf,bhlf->bhl', q, k_state.float())[..., None]).to(dtype=k.dtype)
364
+ else:
365
+ y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state) /
366
+ torch.einsum('bhlf,bhlf->bhl', q, k_state)[..., None])
367
+ else:
368
+ kv_state = past_key_value.kv_states[self.layer_idx]
369
+ k_state = past_key_value.k_states[self.layer_idx]
370
+ y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps) # Ordinarily the states are ignored
371
+ past_key_value.update(k.detach(), v.detach(), self.layer_idx,
372
+ accumulate_in_fp32=self.fp32_attention)
373
+ # doing some unnecessary recomputation here
374
+ else:
375
+ y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
376
+
377
+ # Concatenate heads and apply output projection
378
+ y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
379
+ y_true = self.o_proj(y_true)
380
+ attn_weights = None
381
+
382
+ return y_true, attn_weights, past_key_value
383
+
384
+
385
+ class LinearAttentionState(Cache):
386
+ """
387
+ Handle the KV and K states for linear attention
388
+ - Adopts HF Transformers `past_key_values` convention
389
+ - Inherits from `Cache` class
390
+ - Modified from transformers.cache_utils.DynamicCache (v4.36)
391
+ """
392
+ def __init__(self) -> None:
393
+ self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
394
+ self._seen_tokens_by_layer: List[int] = []
395
+ self.kv_states: List[torch.Tensor] = []
396
+ self.k_states: List[torch.Tensor] = []
397
+
398
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
399
+ """
400
+ Returns the sequence length of the cached states. A layer index can be optionally passed.
401
+ """
402
+ if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
403
+ self._seen_tokens_by_layer.append(0)
404
+ return self._seen_tokens_by_layer[layer_idx]
405
+
406
+ def get_max_length(self) -> Optional[int]:
407
+ """
408
+ Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
409
+ """
410
+ return None
411
+
412
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
413
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
414
+ # Cache without size limit -> all cache is usable
415
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
416
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
417
+ max_length = self.get_max_length()
418
+ previous_seq_length = self.get_seq_length(layer_idx)
419
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
420
+ return max_length - new_seq_length
421
+ return previous_seq_length
422
+
423
+ def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
424
+ layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
425
+ accumulate_in_fp32: bool = True, **kwargs: any,
426
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
427
+
428
+ with torch.no_grad ():
429
+ if layer_idx == 0:
430
+ self._seen_tokens += key_states.shape[-2]
431
+ dtype = key_states.dtype
432
+ if accumulate_in_fp32:
433
+ key_states, value_states = key_states.float(), value_states.float()
434
+
435
+ kv_state = torch.einsum('bhlf,bhld->bhfd', key_states, value_states).detach()
436
+ k_state = key_states.sum(dim=-2, keepdim=True).detach() # b, h, 1, f; note the 1
437
+ # Update the cache
438
+ if len(self.k_states) <= layer_idx: # Initializing kv and k states
439
+ print('if len(self.k_states) <= layer_idx: # Initializing kv and k states')
440
+ self.kv_states.append(kv_state.to(dtype))
441
+ self.k_states.append(k_state.to(dtype))
442
+ else:
443
+ kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
444
+ k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
445
+ self.kv_states[layer_idx] = kv_state
446
+ self.k_states[layer_idx] = k_state
447
+ self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
448
+ return self.kv_states[layer_idx], self.k_states[layer_idx]
449
+
450
+ def to_legacy_cache(self):
451
+ """Hack, but just return self"""
452
+ return self
453
+
454
+ def reorder_cache(self, beam_idx: torch.LongTensor):
455
+ """
456
+ Reorders the cache for beam search, given the selected beam indices.
457
+ -> Copied from transformers/src/transformers/cache_utils.py
458
+ """
459
+ raise NotImplementedError('Reordering cache not implemented for LinearAttentionState')
src/model/linear_attention/linear_window_attention_sw.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Subquadratic attention combining sliding window and linear attentions
3
+ - Using "standard" sliding windows
4
+ - Didactically computes outputs with n^2 attention weights for now
5
+ - Copied + adapted from linear_window_attention_tk.py for single-file reference
6
+
7
+ For each layer:
8
+ - We first compute (softmax) attention over sliding windows
9
+ - We then compute standard linear attention to "fill in" the earlier parts
10
+ - We combine to model the entire sequence
11
+ """
12
+ from typing import List, Tuple, Optional, Callable
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from transformers.cache_utils import Cache
19
+
20
+ from .linear_attention import (
21
+ LolcatsLinearAttention, LinearAttentionState,
22
+ softmax_attention
23
+ )
24
+
25
+ # ----------------------
26
+ # Sliding window helpers
27
+ # ----------------------
28
+ def get_masks(window_size: int, q_len: int, k_len: int,
29
+ device: torch.device) -> tuple[torch.Tensor]:
30
+ """
31
+ Return masks for softmax and linear attention terms
32
+ -> 1 is include, 0 is ignore
33
+ """
34
+ kwargs = {'device': device, 'dtype': int}
35
+ causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len)
36
+ linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size)
37
+ window_mask = causal_mask - linear_mask
38
+ # Return softmax mask (window), linear attention mask
39
+ # -> shapes broadcast over (b, h, q_len, k_len)
40
+ return window_mask[None, None, ...], linear_mask[None, None, ...]
41
+
42
+
43
+ def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
44
+ f_q: torch.Tensor, f_k: torch.Tensor,
45
+ v: torch.Tensor,
46
+ window_factor: torch.Tensor,
47
+ linear_factor: torch.Tensor,
48
+ window_size: int,
49
+ kv_state: torch.Tensor = None,
50
+ k_state: torch.Tensor = None,
51
+ eps: float = 1e-12,
52
+ mask_value: float=-1e8):
53
+ """
54
+ Hybrid attention combining sliding window and linear attentions
55
+ """
56
+
57
+ mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)
58
+
59
+ # 1. Sliding window (softmax attention)
60
+ a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
61
+ a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
62
+ # torch.softmax(a_sm, dim=-1), but we account for the max when combining
63
+ a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
64
+ a_sm = window_factor * torch.exp(a_sm - a_sm_max)
65
+ sum_sm = a_sm.sum(dim=-1, keepdim=True)
66
+
67
+ # 2. Under window (linear attention)
68
+ a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
69
+ a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
70
+ sum_ln = a_ln.sum(dim=-1, keepdim=True)
71
+
72
+ # 3. Combine
73
+ a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
74
+ # Allow outputs to also depend on prior kv_state and k_state
75
+ y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
76
+ if kv_state is not None: # Combine with prior kv_state and k_state
77
+ y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
78
+ sum_ln += linear_factor * torch.einsum(
79
+ 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
80
+ y = (y / (sum_sm + sum_ln)).to(q.dtype)
81
+ return y, a # attention weights only for the last chunk
82
+
83
+
84
+ # ---------------------
85
+ # Attention layer class
86
+ # ---------------------
87
+ class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
88
+ """
89
+ Lolcats attention combining sliding window and linear attention
90
+ """
91
+ def __init__(self,
92
+ window_size: int = 64,
93
+ decode_window_size: int = None,
94
+ affine_attention_factors: bool = False,
95
+ init_window_factor: float = 0,
96
+ train_window_factor: bool = True,
97
+ state_grad_enabled: bool = False,
98
+ **kwargs):
99
+ self.window_size = window_size
100
+ self.decode_window_size = (
101
+ decode_window_size if decode_window_size is not None else window_size
102
+ )
103
+ self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
104
+ super().__init__(**kwargs)
105
+ self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_sw'
106
+ # Determine how we compute attentions
107
+ self.quadratic_attention = hybrid_attention_quadratic
108
+ self.attention_type = kwargs['attention_type'] # 'hedgehog_long_llama_window_sw'
109
+ # Learnable factor for combining attentions
110
+ self.affine_attention_factors = affine_attention_factors
111
+ device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
112
+ if train_window_factor:
113
+ self.window_factors = nn.Parameter(
114
+ init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
115
+ else:
116
+ self.register_buffer(
117
+ "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
118
+ )
119
+ # Whether we use original flash attention 2 inference (use during attention transfer)
120
+ self.base_inference = False
121
+ self.state_grad_enabled = state_grad_enabled
122
+
123
+ def forward(self,
124
+ hidden_states: torch.Tensor,
125
+ attention_mask: Optional[torch.Tensor] = None,
126
+ position_ids: Optional[torch.LongTensor] = None,
127
+ past_key_value: Optional[Cache] = None,
128
+ output_attentions: bool = False,
129
+ use_cache: bool = False,
130
+ **kwargs,
131
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
132
+ """
133
+ Forward pass with the option to compute attention weights multiple ways
134
+ if self.train_attention is True
135
+ -> Consistent with HuggingFace Transformers for easy use with their pretrained models
136
+ """
137
+ b, l, _ = hidden_states.size()
138
+ q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
139
+ position_ids, past_key_value)
140
+ f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap
141
+
142
+ if self.train_attention:
143
+ # 1. Compute "ground-truth" attention output and weights
144
+ with torch.no_grad():
145
+ _y_true, a_true = softmax_attention(q, k, v)[:2]
146
+ y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
147
+ y_true = self.o_proj(y_true)
148
+
149
+ # 2. Compute "predicted" attention outputs
150
+ # compute attn weights under sliding window
151
+ window_factors = F.sigmoid(self.window_factors)
152
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
153
+ y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
154
+ window_factors, linear_factors,
155
+ window_size=self.window_size)
156
+ attn_weights = ((a_pred, a_true), (y_pred, _y_true))
157
+ else:
158
+ attn_weights = None
159
+ # attention_mask = None # For now this is always True
160
+ if past_key_value is None: # Regular training
161
+ window_factors = F.sigmoid(self.window_factors)
162
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
163
+ y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
164
+ window_factors, linear_factors,
165
+ window_size=self.window_size)
166
+ attn_weights = a_pred
167
+ else:
168
+ past_key_value.window_size = self.decode_window_size
169
+ if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
170
+ assert use_cache is True
171
+ _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
172
+ self.feature_map_k,
173
+ dtype=q.dtype)
174
+ k_cache, v_cache, f_kv_state, f_k_state = _kv
175
+
176
+ # Sliding window + linear attention decode
177
+ window_factors = F.sigmoid(self.window_factors)
178
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
179
+
180
+ # Softmax attention terms
181
+ a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
182
+ a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
183
+ a_sm = window_factors * torch.exp(a_sm - a_sm_max)
184
+ sum_sm = a_sm.sum(dim=-1, keepdim=True)
185
+
186
+ # Combine with linear attention terms
187
+ y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
188
+ + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
189
+ sum_ln = linear_factors * torch.einsum(
190
+ 'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
191
+ y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
192
+
193
+ else: # Stateful training
194
+ try:
195
+ kv_state = past_key_value.kv_states[self.layer_idx]
196
+ k_state = past_key_value.k_states[self.layer_idx]
197
+ except IndexError:
198
+ kv_state, k_state = None, None
199
+ window_factors = F.sigmoid(self.window_factors)
200
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
201
+ y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v,
202
+ window_factors, linear_factors,
203
+ window_size=self.window_size,
204
+ kv_state=kv_state,
205
+ k_state=k_state)
206
+ # Save and update KV cache and states
207
+ # past_key_value.update(k, v.detach(), self.layer_idx,
208
+ # fmap_key_states=f_k.detach(),
209
+ # accumulate_in_fp32=True)
210
+ past_key_value.update(k, v, self.layer_idx,
211
+ fmap_key_states=f_k,
212
+ accumulate_in_fp32=True)
213
+ # Concatenate heads and apply output projection
214
+ y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
215
+ y_true = self.o_proj(y_true)
216
+ return y_true, attn_weights, past_key_value
217
+
218
+
219
+ class LinearAttentionSlidingWindowCache(LinearAttentionState):
220
+ """
221
+ Class for `past_key_values`
222
+ -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
223
+ -> Modified from transformers.cache_utils.DynamicCache (v4.36)
224
+ """
225
+ def __init__(self, window_size: int = 64) -> None:
226
+ super().__init__()
227
+ self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
228
+ self._seen_tokens_by_layer: List[int] = []
229
+ self.kv_states: List[torch.Tensor] = []
230
+ self.k_states: List[torch.Tensor] = []
231
+
232
+ # Account for sliding windows
233
+ self.decode_kv_states: List[torch.Tensor] = []
234
+ self.decode_k_states: List[torch.Tensor] = []
235
+ self.k_cache: List[torch.Tensor] = []
236
+ self.v_cache: List[torch.Tensor] = []
237
+ self.window_size = window_size
238
+
239
+ def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
240
+ layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
241
+ accumulate_in_fp32: bool = False,
242
+ fmap_key_states: torch.Tensor = None, # should not be None
243
+ grad_enabled: bool = False,
244
+ **kwargs: any,
245
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
246
+ """
247
+ Update KV, K states; and KV cache during training
248
+ - For decoding, use `self.decode_kv_states` to keep track of KV states
249
+ up to sliding window terms
250
+ - For (chunked) training, use `self.kv_states` to keep track of KV states
251
+ up to end of sequence
252
+ - Likewise for `self.decode_k_states` and `self.k_states`
253
+ """
254
+ with torch.set_grad_enabled(grad_enabled):
255
+ if layer_idx == 0:
256
+ self._seen_tokens += key_states.shape[-2]
257
+
258
+ dtype = key_states.dtype
259
+ if accumulate_in_fp32:
260
+ # key_states = key_states.float()
261
+ fmap_key_states = fmap_key_states.float()
262
+ value_states = value_states.float()
263
+
264
+ # Decoding KV state (KV terms up to last window_size)
265
+ decode_kv_state = torch.einsum(
266
+ 'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
267
+ )
268
+ # KV state
269
+ kv_state = decode_kv_state + torch.einsum(
270
+ 'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
271
+ )
272
+ # shape is b, h, 1, f; note the 1
273
+ decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
274
+ k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
275
+
276
+ # Update the cache
277
+ if len(self.k_states) <= layer_idx: # Initializing kv and k states
278
+ self.kv_states.append(kv_state.to(dtype))
279
+ self.k_states.append(k_state.to(dtype))
280
+
281
+ self.decode_kv_states.append(decode_kv_state.to(dtype))
282
+ self.decode_k_states.append(decode_k_state.to(dtype))
283
+
284
+ self.k_cache.append(key_states[:, :, -self.window_size:, :])
285
+ self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
286
+ # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
287
+ else:
288
+ # Update kv and k states recurrently
289
+ kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
290
+ k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
291
+ self.kv_states[layer_idx] = kv_state
292
+ self.k_states[layer_idx] = k_state
293
+
294
+ decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
295
+ + decode_kv_state).to(dtype)
296
+ decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
297
+ + decode_k_state).to(dtype)
298
+ self.decode_kv_states[layer_idx] = decode_kv_state
299
+ self.decode_k_states[layer_idx] = decode_k_state
300
+
301
+ self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
302
+ self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
303
+ self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
304
+
305
+ return self.kv_states[layer_idx], self.k_states[layer_idx]
306
+
307
+ def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
308
+ layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
309
+ """
310
+ Update the decoding KV and K states, and KV cache, during decodeing
311
+ """
312
+ with torch.no_grad():
313
+ k_cache = self.k_cache[layer_idx]
314
+ v_cache = self.v_cache[layer_idx]
315
+
316
+ if k_cache.shape[-2] < self.window_size: # build window-size cache
317
+ self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
318
+ self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
319
+ else:
320
+ # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
321
+ # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
322
+ # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
323
+ # else:
324
+ # f_k_state = feature_map_k(k_cache[:, :, :1, :])
325
+ # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
326
+ k_state = feature_map_k(k_cache[:, :, :1, :])
327
+ v_state = v_cache[:, :, :1, :]
328
+ kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
329
+ self.decode_kv_states[layer_idx] += kv_state
330
+ self.decode_k_states[layer_idx] += k_state
331
+
332
+ self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
333
+ self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
334
+
335
+ if layer_idx == 0:
336
+ self._seen_tokens += keys.shape[-2]
337
+ self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
338
+ return (self.k_cache[layer_idx], self.v_cache[layer_idx],
339
+ self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
src/model/linear_attention/linear_window_attention_sw_linear.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Subquadratic attention combining sliding window and linear attentions
3
+ - Using "standard" sliding windows
4
+ - Didactically computes outputs with n^2 attention weights for now
5
+ - Copied + adapted from linear_window_attention_tk.py for single-file reference
6
+
7
+ For each layer:
8
+ - We first compute (softmax) attention over sliding windows
9
+ - We then compute standard linear attention to "fill in" the earlier parts
10
+ - We combine to model the entire sequence
11
+ """
12
+ from typing import List, Tuple, Optional, Callable
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from transformers.cache_utils import Cache
19
+ try:
20
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
21
+ except ModuleNotFoundError:
22
+ _flash_attention_forward = None # Transformers v4.36
23
+
24
+ # Causal linear attention dot product CUDA kernel from fast-transformers
25
+ from csrc import causal_dot_product
26
+
27
+ from src.model.rotary import apply_rotary_pos_emb
28
+ from .linear_attention import (
29
+ LolcatsLinearAttention, LinearAttentionState,
30
+ softmax_attention
31
+ )
32
+
33
+ # ----------------------
34
+ # Sliding window helpers
35
+ # ----------------------
36
+ def get_masks(window_size: int, q_len: int, k_len: int,
37
+ device: torch.device) -> tuple[torch.Tensor]:
38
+ """
39
+ Return masks for softmax and linear attention terms
40
+ -> 1 is include, 0 is ignore
41
+ """
42
+ kwargs = {'device': device, 'dtype': int}
43
+ causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0))
44
+ linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size)
45
+ window_mask = causal_mask - linear_mask
46
+ # Return softmax mask (window), linear attention mask
47
+ # -> shapes broadcast over (b, h, q_len, k_len)
48
+ return window_mask[None, None, ...], linear_mask[None, None, ...]
49
+
50
+
51
+ def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
52
+ f_q: torch.Tensor, f_k: torch.Tensor,
53
+ v: torch.Tensor,
54
+ window_factor: torch.Tensor,
55
+ linear_factor: torch.Tensor,
56
+ window_size: int,
57
+ kv_state: torch.Tensor = None,
58
+ k_state: torch.Tensor = None,
59
+ eps: float = 1e-12,
60
+ mask_value: float=-1e8):
61
+ """
62
+ Hybrid attention combining sliding window and linear attentions
63
+ """
64
+
65
+ mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)
66
+
67
+ # 1. Sliding window (softmax attention)
68
+ a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
69
+ a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
70
+ # torch.softmax(a_sm, dim=-1), but we account for the max when combining
71
+ a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
72
+ a_sm = window_factor * torch.exp(a_sm - a_sm_max)
73
+ sum_sm = a_sm.sum(dim=-1, keepdim=True)
74
+
75
+ # 2. Under window (linear attention)
76
+ a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
77
+ a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
78
+ sum_ln = a_ln.sum(dim=-1, keepdim=True)
79
+
80
+ # 3. Combine
81
+ a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
82
+ # Allow outputs to also depend on prior kv_state and k_state
83
+ y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
84
+ if kv_state is not None: # Combine with prior kv_state and k_state
85
+ y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
86
+ sum_ln += linear_factor * torch.einsum(
87
+ 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
88
+ y = (y / (sum_sm + sum_ln)).to(q.dtype)
89
+ return y, a # attention weights only for the last chunk
90
+
91
+
92
+ # ------------------------------
93
+ # Hybrid window attention linear
94
+ # ------------------------------
95
+ def under_window_linear_attention(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor,
96
+ window_size: int, linear_factor: float, eps: float=1e-12):
97
+ """Compute hybrid window attention dot product with linear complexity in q_len"""
98
+ dtype = f_q.dtype
99
+ w = window_size
100
+ f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
101
+ v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
102
+ qkv = linear_factor * causal_dot_product(f_q.contiguous().to(dtype=torch.float32),
103
+ f_k.contiguous().to(dtype=torch.float32),
104
+ v.contiguous().to(dtype=torch.float32)).to(dtype=dtype)
105
+ sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
106
+ sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
107
+ sum_qk[sum_qk == 0] += eps
108
+ return qkv, sum_qk
109
+
110
+
111
+ def sliding_window_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
112
+ window_size: int, window_factor: float, mask_value: float=-1e8):
113
+ """
114
+ Compute sliding window softmax attention without materializing
115
+ O(seq_len^2) attention weights
116
+ """
117
+ d = q.shape[-1]
118
+ # Compute windows for keys
119
+ window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
120
+ k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
121
+ v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
122
+
123
+ # Compute windowed_softmax(qk); causal in its construction
124
+ a_sm = torch.einsum('bhld,bhldw->bhlw', q, k) * (d ** -0.5)
125
+ a_sm[a_sm == 0] = -torch.finfo(q.dtype).max # heuristic for zeroing out padding above
126
+ a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
127
+ a_sm = window_factor * torch.exp(a_sm - a_sm_max)
128
+ sum_sm = a_sm.sum(dim=-1, keepdim=True)
129
+ return torch.einsum('bhlw,bhldw->bhld', a_sm, v), sum_sm
130
+ # return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
131
+
132
+
133
+ def hybrid_attention_linear(q: torch.Tensor, k: torch.Tensor,
134
+ f_q: torch.Tensor, f_k: torch.Tensor,
135
+ v: torch.Tensor,
136
+ window_factor: torch.Tensor = None,
137
+ linear_factor: torch.Tensor = None,
138
+ window_size: int = 64,
139
+ kv_state: torch.Tensor = None,
140
+ k_state: torch.Tensor = None,
141
+ eps: float = 1e-12,
142
+ mask_value: float=-1e8):
143
+ """
144
+ Alternative hybrid attention combining sliding window and linear attentions
145
+ -> Uses O(n) memory if n is sequence length by padding and unfolding windows
146
+ """
147
+ window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
148
+ # 1. Sliding window (softmax attention)
149
+ with torch.no_grad():
150
+ qkv_sm, sum_qk_sm = sliding_window_softmax_attention(q, k, v, window_size, window_factor, mask_value)
151
+
152
+ # 2. Under window (linear attention)
153
+ qkv_ln, sum_qk_ln = under_window_linear_attention(f_q, f_k, v, window_size, linear_factor, eps)
154
+
155
+ # 3. Combine
156
+ y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
157
+ return y, None
158
+
159
+
160
+ # ---------------------
161
+ # Attention layer class
162
+ # ---------------------
163
+ class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
164
+ """
165
+ Lolcats attention combining sliding window and linear attention
166
+ """
167
+ def __init__(self,
168
+ window_size: int = 64,
169
+ decode_window_size: int = None,
170
+ affine_attention_factors: bool = False,
171
+ init_window_factor: float = 0,
172
+ train_window_factor: bool = True,
173
+ state_grad_enabled: bool = False,
174
+ **kwargs):
175
+ self.window_size = window_size
176
+ self.decode_window_size = (
177
+ decode_window_size if decode_window_size is not None else window_size
178
+ )
179
+ self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
180
+ super().__init__(**kwargs)
181
+ # Determine how we compute attentions
182
+ self.linear_attention = hybrid_attention_linear
183
+ self.attention_type = 'lolcats_llama_window_sw'
184
+ # Learnable factor for combining attentions
185
+ self.affine_attention_factors = affine_attention_factors
186
+ device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
187
+ if train_window_factor:
188
+ self.window_factors = nn.Parameter(
189
+ init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
190
+ else:
191
+ self.register_buffer(
192
+ "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
193
+ )
194
+ # Whether we use original flash attention 2 inference (use during attention transfer)
195
+ self.base_inference = False
196
+ self.state_grad_enabled = state_grad_enabled
197
+
198
+ def forward(self,
199
+ hidden_states: torch.Tensor,
200
+ attention_mask: Optional[torch.Tensor] = None,
201
+ position_ids: Optional[torch.LongTensor] = None,
202
+ past_key_value: Optional[Cache] = None,
203
+ output_attentions: bool = False,
204
+ use_cache: bool = False,
205
+ **kwargs,
206
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
207
+ """
208
+ Forward pass with the option to compute attention weights multiple ways
209
+ if self.train_attention is True
210
+ -> Consistent with HuggingFace Transformers for easy use with their pretrained models
211
+ """
212
+ b, l, _ = hidden_states.size()
213
+
214
+ if self.train_attention and self.base_inference:
215
+ with torch.no_grad():
216
+ _y_true = flash_attention_2(self, # self.base_attn,
217
+ hidden_states=hidden_states,
218
+ attention_mask=None,
219
+ position_ids=position_ids,
220
+ past_key_value=None,
221
+ output_attentions=False,
222
+ use_cache=False)[0]
223
+ # _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
224
+ y_true = _y_true.reshape(b, l, -1).contiguous()
225
+ y_true = self.o_proj(y_true)
226
+ # layer_io = (hidden_states, _y_true) # hack
227
+ layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
228
+ return y_true, layer_io, None
229
+
230
+ else:
231
+ q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
232
+ position_ids, past_key_value)
233
+ f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap
234
+
235
+ attn_weights = None
236
+ # attention_mask = None # For now this is always True
237
+ if past_key_value is None: # Regular training
238
+ window_factors = F.sigmoid(self.window_factors)
239
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
240
+ y_true, a_pred = self.linear_attention(q, k, f_q, f_k, v,
241
+ window_factors, linear_factors,
242
+ window_size=self.window_size)
243
+ attn_weights = a_pred
244
+ else:
245
+ past_key_value.window_size = self.decode_window_size
246
+ if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
247
+ assert use_cache is True
248
+ _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
249
+ self.feature_map_k,
250
+ dtype=q.dtype)
251
+ k_cache, v_cache, f_kv_state, f_k_state = _kv
252
+
253
+ # Sliding window + linear attention decode
254
+ window_factors = F.sigmoid(self.window_factors)
255
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
256
+
257
+ # Softmax attention terms
258
+ a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
259
+ a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
260
+ a_sm = window_factors * torch.exp(a_sm - a_sm_max)
261
+ sum_sm = a_sm.sum(dim=-1, keepdim=True)
262
+
263
+ # Combine with linear attention terms
264
+ y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
265
+ + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
266
+ sum_ln = linear_factors * torch.einsum(
267
+ 'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
268
+ y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
269
+
270
+ else: # Stateful training
271
+ try:
272
+ kv_state = past_key_value.kv_states[self.layer_idx]
273
+ k_state = past_key_value.k_states[self.layer_idx]
274
+ except IndexError:
275
+ kv_state, k_state = None, None
276
+ window_factors = F.sigmoid(self.window_factors)
277
+ linear_factors = 1 - window_factors if self.affine_attention_factors else 1
278
+ y_true, _ = self.linear_attention(q, k, f_q, f_k, v,
279
+ window_factors, linear_factors,
280
+ window_size=self.window_size,
281
+ kv_state=kv_state,
282
+ k_state=k_state)
283
+ # Save and update KV cache and states
284
+ # past_key_value.update(k, v.detach(), self.layer_idx,
285
+ # fmap_key_states=f_k.detach(),
286
+ # accumulate_in_fp32=True)
287
+ past_key_value.update(k, v, self.layer_idx,
288
+ fmap_key_states=f_k,
289
+ accumulate_in_fp32=True)
290
+ # Concatenate heads and apply output projection
291
+ _y_true = y_true.transpose(1, 2).contiguous()
292
+ y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
293
+
294
+ if self.train_attention:
295
+ attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
296
+ return y_true, attn_weights, past_key_value
297
+
298
+
299
+ class LinearAttentionSlidingWindowCache(LinearAttentionState):
300
+ """
301
+ Class for `past_key_values`
302
+ -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
303
+ -> Modified from transformers.cache_utils.DynamicCache (v4.36)
304
+ """
305
+ def __init__(self, window_size: int = 64) -> None:
306
+ super().__init__()
307
+ self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
308
+ self._seen_tokens_by_layer: List[int] = []
309
+ self.kv_states: List[torch.Tensor] = []
310
+ self.k_states: List[torch.Tensor] = []
311
+
312
+ # Account for sliding windows
313
+ self.decode_kv_states: List[torch.Tensor] = []
314
+ self.decode_k_states: List[torch.Tensor] = []
315
+ self.k_cache: List[torch.Tensor] = []
316
+ self.v_cache: List[torch.Tensor] = []
317
+ self.window_size = window_size
318
+
319
+ def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
320
+ layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
321
+ accumulate_in_fp32: bool = False,
322
+ fmap_key_states: torch.Tensor = None, # should not be None
323
+ grad_enabled: bool = False,
324
+ **kwargs: any,
325
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
326
+ """
327
+ Update KV, K states; and KV cache during training
328
+ - For decoding, use `self.decode_kv_states` to keep track of KV states
329
+ up to sliding window terms
330
+ - For (chunked) training, use `self.kv_states` to keep track of KV states
331
+ up to end of sequence
332
+ - Likewise for `self.decode_k_states` and `self.k_states`
333
+ """
334
+ with torch.set_grad_enabled(grad_enabled):
335
+ if layer_idx == 0:
336
+ self._seen_tokens += key_states.shape[-2]
337
+
338
+ dtype = key_states.dtype
339
+ if accumulate_in_fp32:
340
+ # key_states = key_states.float()
341
+ fmap_key_states = fmap_key_states.float()
342
+ value_states = value_states.float()
343
+
344
+ # Decoding KV state (KV terms up to last window_size)
345
+ decode_kv_state = torch.einsum(
346
+ 'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
347
+ )
348
+ # KV state
349
+ kv_state = decode_kv_state + torch.einsum(
350
+ 'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
351
+ )
352
+ # shape is b, h, 1, f; note the 1
353
+ decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
354
+ k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
355
+
356
+ # Update the cache
357
+ if len(self.k_states) <= layer_idx: # Initializing kv and k states
358
+ self.kv_states.append(kv_state.to(dtype))
359
+ self.k_states.append(k_state.to(dtype))
360
+
361
+ self.decode_kv_states.append(decode_kv_state.to(dtype))
362
+ self.decode_k_states.append(decode_k_state.to(dtype))
363
+
364
+ self.k_cache.append(key_states[:, :, -self.window_size:, :])
365
+ self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
366
+ # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
367
+ else:
368
+ # Update kv and k states recurrently
369
+ kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
370
+ k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
371
+ self.kv_states[layer_idx] = kv_state
372
+ self.k_states[layer_idx] = k_state
373
+
374
+ decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
375
+ + decode_kv_state).to(dtype)
376
+ decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
377
+ + decode_k_state).to(dtype)
378
+ self.decode_kv_states[layer_idx] = decode_kv_state
379
+ self.decode_k_states[layer_idx] = decode_k_state
380
+
381
+ self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
382
+ self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
383
+ self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
384
+
385
+ return self.kv_states[layer_idx], self.k_states[layer_idx]
386
+
387
+ def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
388
+ layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
389
+ """
390
+ Update the decoding KV and K states, and KV cache, during decodeing
391
+ """
392
+ with torch.no_grad():
393
+ k_cache = self.k_cache[layer_idx]
394
+ v_cache = self.v_cache[layer_idx]
395
+
396
+ if k_cache.shape[-2] < self.window_size: # build window-size cache
397
+ self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
398
+ self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
399
+ else:
400
+ # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
401
+ # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
402
+ # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
403
+ # else:
404
+ # f_k_state = feature_map_k(k_cache[:, :, :1, :])
405
+ # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
406
+ k_state = feature_map_k(k_cache[:, :, :1, :])
407
+ v_state = v_cache[:, :, :1, :]
408
+ kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
409
+ self.decode_kv_states[layer_idx] += kv_state
410
+ self.decode_k_states[layer_idx] += k_state
411
+
412
+ self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
413
+ self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
414
+
415
+ if layer_idx == 0:
416
+ self._seen_tokens += keys.shape[-2]
417
+ self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
418
+ return (self.k_cache[layer_idx], self.v_cache[layer_idx],
419
+ self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
420
+
421
+
422
+ # -----------------
423
+ # Flash Attention 2
424
+ # -----------------
425
+
426
+ def flash_attention_2(self,
427
+ hidden_states: torch.Tensor,
428
+ attention_mask: Optional[torch.LongTensor] = None,
429
+ position_ids: Optional[torch.LongTensor] = None,
430
+ past_key_value: Optional[Cache] = None,
431
+ output_attentions: bool = False,
432
+ use_cache: bool = False,
433
+ cache_position: Optional[torch.LongTensor] = None,
434
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
435
+ """
436
+ Wrapper for LlamaFlashAttention2
437
+ Copied and modified from HF Transformers v4.36 and v4.43 implementations
438
+ - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
439
+ - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
440
+ """
441
+ output_attentions = False
442
+
443
+ bsz, q_len, _ = hidden_states.size()
444
+
445
+ query_states = self.q_proj(hidden_states)
446
+ key_states = self.k_proj(hidden_states)
447
+ value_states = self.v_proj(hidden_states)
448
+
449
+ # Flash attention requires the input to have the shape
450
+ # batch_size x seq_length x head_dim x hidden_dim
451
+ # therefore we just need to keep the original shape
452
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
453
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
454
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
455
+
456
+ try: # As in Transformers v4.36
457
+ kv_seq_len = key_states.shape[-2]
458
+ if past_key_value is not None:
459
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
460
+ cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
461
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
462
+ except: # As in Transformers v4.39
463
+ cos, sin = self.rotary_emb(key_states, position_ids)
464
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
465
+
466
+ if past_key_value is not None:
467
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
468
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
469
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
470
+
471
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
472
+ # to be able to avoid many of these transpose/reshape/view.
473
+ query_states = query_states.transpose(1, 2)
474
+ key_states = key_states.transpose(1, 2)
475
+ value_states = value_states.transpose(1, 2)
476
+
477
+ dropout_rate = self.attention_dropout if self.training else 0.0
478
+
479
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
480
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
481
+ # cast them back in the correct dtype just to be sure everything works as expected.
482
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
483
+ # in fp32. (LlamaRMSNorm handles it correctly)
484
+
485
+ input_dtype = query_states.dtype
486
+ if input_dtype == torch.float32:
487
+ if torch.is_autocast_enabled():
488
+ target_dtype = torch.get_autocast_gpu_dtype()
489
+ # Handle the case where the model is quantized
490
+ elif hasattr(self.config, "_pre_quantization_dtype"):
491
+ target_dtype = self.config._pre_quantization_dtype
492
+ else:
493
+ target_dtype = self.q_proj.weight.dtype
494
+
495
+ logger.warning_once(
496
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
497
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
498
+ f" {target_dtype}."
499
+ )
500
+
501
+ query_states = query_states.to(target_dtype)
502
+ key_states = key_states.to(target_dtype)
503
+ value_states = value_states.to(target_dtype)
504
+
505
+ if getattr(self, '_flash_attention_forward', False):
506
+ attn_output = self._flash_attention_forward(
507
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
508
+ is_causal=True,
509
+ )
510
+ else:
511
+ attn_output = _flash_attention_forward(
512
+ query_states,
513
+ key_states,
514
+ value_states,
515
+ attention_mask,
516
+ q_len,
517
+ dropout=0, # dropout_rate,
518
+ sliding_window=getattr(self, "sliding_window", None),
519
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
520
+ is_causal=True,
521
+ )
522
+ return attn_output, past_key_value
src/model/linear_attention/linear_window_attention_sw_long.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoLCATs attention combining sliding window and linear attentions
3
+ - Using standard sliding window arrangement
4
+ - Training over long sequences with fixed memory with recurrent view
5
+ - During attention transfer, use Flash Attention to compute softmax attention outputs
6
+
7
+ For each layer:
8
+ - We first compute (softmax) attention over sliding windows
9
+ - We then compute standard linear attention to "fill in" the earlier parts
10
+ - We combine to model the entire sequence
11
+ """
12
+ from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
13
+ from .linear_window_attention_sw import hybrid_attention_quadratic
14
+
15
+
16
+ class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
17
+ """
18
+ Lolcats attention combining sliding window and linear attention
19
+ """
20
+ def __init__(self, remove_base_attn=True, **kwargs):
21
+ # keep self.base_attn for Flash Attention inference
22
+ super().__init__(remove_base_attn=True, **kwargs)
23
+ self.quadratic_attention = hybrid_attention_quadratic