koukyo1994 commited on
Commit
a69bb58
·
verified ·
1 Parent(s): 071cf38

add LlamaActionV2

Browse files
Files changed (1) hide show
  1. modeling_llama_action.py +17 -0
modeling_llama_action.py CHANGED
@@ -238,3 +238,20 @@ class LlamaActionForCausalLM(LlamaForCausalLM):
238
  "past_key_values": past_key_values,
239
  "use_cache": use_cache,
240
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  "past_key_values": past_key_values,
239
  "use_cache": use_cache,
240
  }
241
+
242
+
243
+ class LlamaActionV2ForCausalLM(LlamaActionForCausalLM):
244
+ config_class = LlamaActionConfig
245
+
246
+ def __init__(self, config: LlamaActionConfig):
247
+ super().__init__(config)
248
+
249
+ self.action_projection = nn.Sequential(
250
+ nn.Linear(config.action_dim, config.hidden_size),
251
+ nn.ReLU(),
252
+ nn.Linear(config.hidden_size, config.hidden_size),
253
+ nn.ReLU(),
254
+ nn.Linear(config.hidden_size, config.hidden_size),
255
+ )
256
+
257
+ self.post_init()