Commit 45481b8 (verified) by nc-ai-consortium · Parent: e7a7014

Upload folder using huggingface_hub

configuration_vaetki.py CHANGED
@@ -3,7 +3,6 @@ from transformers.modeling_rope_utils import rope_config_validation
 
 
 class VaetkiConfig(PretrainedConfig):
-
     model_type = "vaetki"
     keys_to_ignore_at_inference = ["past_key_values"]
     base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
@@ -98,12 +97,10 @@ class VaetkiConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        # Validate the correctness of rotary position embeddings parameters
-        # BC: if there is a 'type' field, copy it it to 'rope_type'.
-        if self.rope_scaling is not None and "type" in self.rope_scaling:
-            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 
         if self.rope_scaling is not None:
+            if self.rope_scaling["rope_type"] == "rope":
+                self.rope_scaling["rope_type"] = "default"
             for key in ["beta_fast", "beta_slow", "factor"]:
                 if key in self.rope_scaling:
                     self.rope_scaling[key] = float(self.rope_scaling[key])
@@ -112,6 +109,7 @@ class VaetkiConfig(PretrainedConfig):
 
         if self.layer_types is None:
             self.layer_types = [
+                # FIXME: pattern needs to be changed to match the megatron transformer_config
                 "sliding_attention" if bool((i + 1) % 6) else "full_attention"
                 for i in range(self.num_hidden_layers)
             ]
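
In effect, the configuration change does two things: a rope_scaling dict whose "rope_type" is "rope" is now remapped to the stock "default" implementation (the old back-compat copy from "type" to "rope_type" is dropped, so the key must already be "rope_type"), while the YaRN-style numeric fields are still coerced to float; the sliding/full attention layer pattern itself is unchanged but now carries a FIXME. A minimal standalone sketch of the resulting behaviour; the helper names below are illustrative, not part of VaetkiConfig:

def normalize_rope_scaling(rope_scaling):
    # Mirrors the new __init__ block: fold a "rope" rope_type into "default"
    # and coerce the YaRN-style numeric fields to float.
    if rope_scaling is not None:
        if rope_scaling["rope_type"] == "rope":
            rope_scaling["rope_type"] = "default"
        for key in ["beta_fast", "beta_slow", "factor"]:
            if key in rope_scaling:
                rope_scaling[key] = float(rope_scaling[key])
    return rope_scaling

def default_layer_types(num_hidden_layers):
    # Unchanged default pattern: every 6th layer uses full attention,
    # the rest use sliding-window attention.
    return [
        "sliding_attention" if bool((i + 1) % 6) else "full_attention"
        for i in range(num_hidden_layers)
    ]

print(normalize_rope_scaling({"rope_type": "rope", "factor": "8"}))  # {'rope_type': 'default', 'factor': 8.0}
print(default_layer_types(6))  # 5x 'sliding_attention' followed by 'full_attention'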
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_vaetki.py CHANGED
@@ -13,7 +13,7 @@ from transformers.masking_utils import create_causal_mask, create_sliding_window
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
@@ -38,25 +38,19 @@ class VaetkiRMSNorm(nn.Module):
 
 
 class VaetkiRotaryEmbedding(nn.Module):
-    def __init__(self, config: VaetkiConfig, device=None):
+    def __init__(self, config: VaetkiConfig, rope_type="default", original_max_position_embeddings=None, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
+        self.rope_type = rope_type
         self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
+        self.original_max_seq_len = original_max_position_embeddings
 
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
 
     @torch.no_grad()
-    @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
@@ -258,10 +252,9 @@ class VaetkiAttention(nn.Module):
 
         self.scaling = self.qk_head_dim ** (-0.5)
         if self.config.rope_scaling is not None and not self.is_sliding:
-            # TODO: check yarn related logic
             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
-            scaling_factor = self.config.rope_scaling["factor"]
             if mscale_all_dim:
+                scaling_factor = self.config.rope_scaling["factor"]
                 mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                 self.scaling = self.scaling * mscale * mscale
 
@@ -408,8 +401,7 @@ class VaetkiPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["VaetkiDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_3 = True
-    _supports_flash_attn_2 = True
+    _supports_flash_attn = True
     _supports_sdpa = False
     _supports_flex_attn = False
     _supports_cache_class = True
@@ -445,13 +437,21 @@ class VaetkiModel(VaetkiPreTrainedModel):
             [VaetkiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = VaetkiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb_local = VaetkiRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+
+        self.rotary_emb_local = VaetkiRotaryEmbedding(config=config)
 
         config = copy.deepcopy(config)
         config.rope_theta = config.rope_theta_global
-        self.rotary_emb_global = VaetkiRotaryEmbedding(config=config)
-        self.rotary_emb_global.inv_freq /= 8.0 # TODO: Possibly change in the future
+        if self.config.rope_scaling is None:
+            rope_type = "default"
+            original_max_position_embeddings = config.max_position_embeddings
+        else:
+            rope_type = config.rope_scaling["rope_type"]
+            original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        self.rotary_emb_global = VaetkiRotaryEmbedding(config=config, rope_type=rope_type, original_max_position_embeddings=original_max_position_embeddings)
+        if rope_type == "default":
+            self.rotary_emb_global.inv_freq /= 8.0
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -571,7 +571,7 @@ class VaetkiForCausalLM(VaetkiPreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = VaetkiModel(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=torch.float32)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -586,7 +586,7 @@ class VaetkiForCausalLM(VaetkiPreTrainedModel, GenerationMixin):
         return self.lm_head
 
     def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+        self.lm_head = new_embeddings.to(torch.float32)
 
     def set_decoder(self, decoder):
         self.model = decoder
@@ -633,7 +633,7 @@ class VaetkiForCausalLM(VaetkiPreTrainedModel, GenerationMixin):
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(hidden_states[:, slice_indices, :].to(self.lm_head.weight.dtype))
 
         loss = None
         if labels is not None:
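
On the modeling side, VaetkiRotaryEmbedding no longer derives rope_type from the config; VaetkiModel now passes rope_type and original_max_position_embeddings in explicitly, and the hard-coded inv_freq /= 8.0 on the global rotary embedding is applied only when rope_type is "default". As a side note on what that division amounts to: scaling every inverse frequency down uniformly stretches every rotation period, which behaves like linear position interpolation by the same factor. A small illustrative check (base and head_dim below are placeholders, not the model's real values):

import torch

def default_inv_freq(base: float, head_dim: int) -> torch.Tensor:
    # Standard RoPE inverse frequencies: base ** (-2i / head_dim).
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

inv_freq = default_inv_freq(base=10000.0, head_dim=64)  # placeholder values
scaled = inv_freq / 8.0                                  # what the model applies when rope_type == "default"

pos = torch.tensor([1024.0])
# Rotation angles with the scaled frequencies equal the unscaled angles at pos / 8.
assert torch.allclose(pos * scaled, (pos / 8.0) * inv_freq)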
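The attention-scaling hunk only moves the rope_scaling["factor"] lookup inside the mscale_all_dim branch, so "factor" is required only when the YaRN mscale correction is actually applied. A sketch with placeholder numbers, assuming the common DeepSeek-style yarn_get_mscale definition (the model's own helper may differ):

import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Common YaRN attention-scale correction; assumed here, not taken from the diff.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

qk_head_dim = 128                                      # placeholder head dim
rope_scaling = {"factor": 8.0, "mscale_all_dim": 1.0}  # placeholder YaRN config

scaling = qk_head_dim ** (-0.5)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
if mscale_all_dim:
    scaling_factor = rope_scaling["factor"]            # now looked up only inside the branch
    mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
    scaling = scaling * mscale * mscale
print(scaling)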
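Finally, the LM head is now constructed in float32, new output embeddings are cast to float32 when set, and the sliced hidden states are cast up to the head's dtype before the logits matmul. A minimal sketch of that pattern, assuming a bf16 model body; sizes are illustrative:

import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 32                       # illustrative sizes
lm_head = nn.Linear(hidden_size, vocab_size, bias=False, dtype=torch.float32)

hidden_states = torch.randn(1, 4, hidden_size, dtype=torch.bfloat16)  # stand-in for decoder output
slice_indices = slice(-1, None)                        # logits_to_keep = 1

# Cast the sliced hidden states to the head's dtype so the matmul and logits stay in fp32.
logits = lm_head(hidden_states[:, slice_indices, :].to(lm_head.weight.dtype))
print(logits.dtype)  # torch.float32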