def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """ LightningModule hook that's used to restore things saved with on_save_checkpoint."""
    if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
        if get_checkpoint_version():
            assert (
                checkpoint['checkpoint_version'] == get_checkpoint_version()
            ), 'checkpoint version found in on_load_checkpoint differs from get_checkpoint_version()'
        else:
            set_checkpoint_version(checkpoint['checkpoint_version'])
            logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}")
    return None
def start_training(self, trainer: 'Trainer') -> None:
    """ PTL Hook that is called after DDP is initialized. """
    if self.lightning_module.has_megatron_encoder:
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # mpu grad clipping needs parameters to have the attribute model_parallel
            parameters = self.lightning_module.parameters()
            for p in parameters:
                if not hasattr(p, 'model_parallel'):
                    p.model_parallel = False

            if get_checkpoint_version() is not None:
                # megatron checkpoint already restored
                pass
            elif trainer.resume_from_checkpoint is not None:
                # PTL auto-resuming, need to update checkpoint name
                # update path based on model parallel rank
                filepath = trainer.resume_from_checkpoint
                dirname = os.path.dirname(os.path.dirname(filepath))
                basename = os.path.basename(filepath)
                filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                trainer.resume_from_checkpoint = filepath
                logging.info(f'Resuming training from checkpoint {trainer.resume_from_checkpoint}')
                # need to set checkpoint version for megatron-lm
                checkpoint_version = torch.load(trainer.resume_from_checkpoint).get('checkpoint_version', None)
                if checkpoint_version is not None:
                    set_checkpoint_version(checkpoint_version)
                else:
                    logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
                    set_checkpoint_version(0)
            else:
                self.lightning_module.restore_megatron_encoder_weights()
        else:
            if get_checkpoint_version() is not None:
                # megatron checkpoint already restored
                pass
            else:
                self.lightning_module.restore_megatron_encoder_weights()

        self.lightning_module.register_megatron_checkpoint_version()

    return super().start_training(trainer)
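# --- Hedged illustration (not part of the plugin above) ---
# Shows the path rewrite start_training applies when PTL auto-resumes a
# model-parallel run: the innermost directory is swapped for the mp_rank_XX
# directory of the current model-parallel rank. The path and rank values
# below are made-up examples, not taken from any real run.
import os

example_ckpt = '/results/checkpoints/mp_rank_00/last.ckpt'  # assumed checkpoint layout
example_rank = 1                                             # assumed model_parallel_rank

dirname = os.path.dirname(os.path.dirname(example_ckpt))     # '/results/checkpoints'
basename = os.path.basename(example_ckpt)                     # 'last.ckpt'
rank_specific = f'{dirname}/mp_rank_{example_rank:02d}/{basename}'
# rank_specific == '/results/checkpoints/mp_rank_01/last.ckpt'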
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """ LightningModule hook that's used to save things in addition to model weights. """
    if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
        checkpoint['checkpoint_version'] = get_checkpoint_version()
    return None
def start_testing(self, trainer: 'Trainer') -> None:
    """ PTL Hook that is called after DDP is initialized. """
    app_state = AppState()
    if app_state.model_parallel_size is not None:
        if self.lightning_module.has_megatron_encoder:
            # check megatron checkpoint version
            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is None:
                raise ValueError("Unable to find megatron checkpoint version.")
    return super().start_testing(trainer)
def start_training(self, trainer: 'Trainer') -> None:
    """ PTL Hook that is called after DDP is initialized. """
    if isinstance(self.lightning_module.bert_model, MegatronBertEncoder):
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # mpu grad clipping needs parameters to have the attribute model_parallel
            parameters = self.lightning_module.parameters()
            for p in parameters:
                if not hasattr(p, 'model_parallel'):
                    p.model_parallel = False

            # TODO: figure out how to override clip gradients again
            # Update PTL trainer to use our _clip_gradients
            # self._trainer.accelerator_backend._clip_gradients = self._clip_gradients

            if get_checkpoint_version():
                # Restored from .nemo, checkpoint_version will already be set
                pass
            elif trainer.resume_from_checkpoint is not None:
                # PTL auto-resuming, need to update checkpoint name
                # update path based on model parallel rank
                filepath = trainer.resume_from_checkpoint
                dirname = os.path.dirname(os.path.dirname(filepath))
                basename = os.path.basename(filepath)
                filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                trainer.resume_from_checkpoint = filepath
                logging.info(f'Resuming training from checkpoint {trainer.resume_from_checkpoint}')
                # need to set checkpoint version for megatron-lm
                checkpoint_version = torch.load(trainer.resume_from_checkpoint).get('checkpoint_version', None)
                if checkpoint_version is not None:
                    set_checkpoint_version(checkpoint_version)
                else:
                    logging.warning('Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.')
                    set_checkpoint_version(0)
            else:
                logging.info(
                    f"Restoring from pretrained model parallel checkpoint: {self.lightning_module.bert_model._restore_path}"
                )
                self.lightning_module.bert_model.restore_weights(self.lightning_module.bert_model._restore_path)

            self.lightning_module.register_megatron_checkpoint_version()

    return super().start_training(trainer)
def register_megatron_checkpoint_version(self):
    """ Adds checkpoint version to .nemo archive """
    if self.has_megatron_encoder:
        checkpoint_version = get_checkpoint_version()
        if checkpoint_version is None:
            raise ValueError('Unable to get megatron checkpoint version.')
        else:
            checkpoint_version_dict = {'checkpoint_version': checkpoint_version}
            checkpoint_version_path = 'megatron_checkpoint_version.json'
            checkpoint_version_src = os.path.join(NEMO_NLP_TMP, checkpoint_version_path)
            with open(checkpoint_version_src, 'w') as f:
                f.write(json.dumps(checkpoint_version_dict))
            self.register_artifact(checkpoint_version_path, checkpoint_version_src)
    else:
        raise ValueError('Registering Megatron checkpoint version but no Megatron encoder detected.')
def register_megatron_checkpoint_version(self):
    """ Adds checkpoint version to .nemo archive """
    if self.bert_model is None:
        raise ValueError('Instantiate self.bert_model before registering megatron checkpoint version.')
    else:
        # get encoder config and create source for artifact
        if isinstance(self.bert_model, MegatronBertEncoder):
            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is None:
                raise ValueError('Unable to get megatron checkpoint version.')
            else:
                checkpoint_version_dict = {'checkpoint_version': checkpoint_version}
                checkpoint_version_path = 'megatron_checkpoint_version.json'
                checkpoint_version_src = os.path.join(NEMO_NLP_TMP, checkpoint_version_path)
                with open(checkpoint_version_src, 'w') as f:
                    f.write(json.dumps(checkpoint_version_dict))
                self.register_artifact(checkpoint_version_path, checkpoint_version_src)
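# --- Hedged sketch (illustration only) ---
# Both register_megatron_checkpoint_version variants above write a one-key
# JSON artifact into the .nemo archive. Reading it back after unpacking the
# archive would look roughly like this; the extraction directory below is an
# assumed example path, not produced by the code above.
import json
import os

extracted_nemo_dir = '/tmp/extracted_nemo'  # assumed location of the unpacked .nemo archive
version_file = os.path.join(extracted_nemo_dir, 'megatron_checkpoint_version.json')
with open(version_file) as f:
    restored_version = json.load(f)['checkpoint_version']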
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
        checkpoint['checkpoint_version'] = get_checkpoint_version()
    return None
def forward(self, hidden_states, attention_mask, rotary_pos_emb=None, layer_past=None, get_key_value=False):
    # hidden_states: [sq, b, h]

    # =====================
    # Query, Key, and Value
    # =====================

    # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
    mixed_x_layer, _ = self.query_key_value(hidden_states)

    checkpoint_version = get_checkpoint_version()
    if checkpoint_version is not None:
        if checkpoint_version == 0:
            # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
        elif checkpoint_version == 1.0:
            # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

    # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
    new_tensor_shape = mixed_x_layer.size()[:-1] + \
        (self.num_attention_heads_per_partition,
         3 * self.hidden_size_per_attention_head)
    mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

    # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
    (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

    if exists(rotary_pos_emb):
        query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, rotary_pos_emb)

    # ==================================
    # Adjust key and value for inference
    # ==================================

    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
        value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
    if get_key_value:
        present = (key_layer, value_layer)

    if not self.sparse:
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),                # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ====================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ====================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
            attention_scores += rpe  # [1, np, sq, sk]

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
    else:
        # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
        query_layer, key_layer, value_layer = map(
            lambda t: t.permute(1, 2, 0, 3).contiguous(),
            (query_layer, key_layer, value_layer))
        # output shape [b, np(heads), sq, hn]
        attn_mask = attention_mask.to(query_layer.dtype) * -10000
        if exists(self.rpe):
            rpe = self.rpe(query_layer.size(0), key_layer.size(0))
        else:
            rpe = None
        context_layer = self.sparse_attn(query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    # =================
    # Output. [sq, b, h]
    # =================

    output, bias = self.dense(context_layer)

    if get_key_value:
        output = [output, present]

    return output, bias
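# --- Hedged sketch (illustration only) ---
# A standalone sketch of the layout reordering applied for
# checkpoint_version == 0 near the top of forward() via
# self._transpose_last_dim(mixed_x_layer, 3, True), assuming the usual
# Megatron convention: a fused QKV projection stored as [s, b, (3 * np * hn)]
# is rearranged to [s, b, (np * 3 * hn)] so the later per-head split into
# Q/K/V lines up. All sizes below are arbitrary example values.
import torch

s, b, heads, hn = 4, 2, 8, 16                  # seq len, batch, heads per partition, head dim
mixed = torch.randn(s, b, 3 * heads * hn)      # version-0 layout: q/k/v blocks vary slowest

reordered = (
    mixed.view(s, b, 3, heads, hn)             # [s, b, 3, np, hn]
         .transpose(-2, -3)                    # [s, b, np, 3, hn]
         .contiguous()
         .view(s, b, heads * 3 * hn)           # target layout: [s, b, (np * 3 * hn)]
)
assert reordered.shape == (s, b, heads * 3 * hn)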