Example #1
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """ LightningModule hook that's used to restore things saved with on_save_checkpoint."""

        if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
            if get_checkpoint_version():
                assert (
                    checkpoint['checkpoint_version'] == get_checkpoint_version()
                ), 'checkpoint version found in on_load_checkpoint differs from get_checkpoint_version()'
            else:
                set_checkpoint_version(checkpoint['checkpoint_version'])
                logging.info(f"Setting Megatron checkpoint version: {checkpoint['checkpoint_version']}")
        return None
Example #2
    def start_training(self, trainer: 'Trainer') -> None:
        """ PTL Hook that is called after DPP is initialized. """

        if self.lightning_module.has_megatron_encoder:
            app_state = AppState()
            if app_state.model_parallel_size is not None:
                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self.lightning_module.parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                if get_checkpoint_version() is not None:
                    # megatron checkpoint already restored
                    pass
                elif trainer.resume_from_checkpoint is not None:
                    # PTL auto-resuming, need to update checkpoint name
                    # update path based on model parallel rank
                    filepath = trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    trainer.resume_from_checkpoint = filepath
                    logging.info(
                        f'Resuming training from checkpoint {trainer.resume_from_checkpoint}'
                    )
                    # need to set checkpoint version for megatron-lm
                    checkpoint_version = torch.load(
                        trainer.resume_from_checkpoint).get(
                            'checkpoint_version', None)
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                else:
                    self.lightning_module.restore_megatron_encoder_weights()
            else:
                if get_checkpoint_version() is not None:
                    # megatron checkpoint already restored
                    pass
                else:
                    self.lightning_module.restore_megatron_encoder_weights()

            self.lightning_module.register_megatron_checkpoint_version()

        return super().start_training(trainer)
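
The resume-path rewriting in the branch above strips the existing mp_rank_XX directory from the checkpoint path and rebuilds it for the current model-parallel rank. A minimal sketch of just that string manipulation, using an illustrative path and rank (both are assumptions, not values from the example):

import os

# Hypothetical checkpoint path PTL wants to resume from, and the current rank.
filepath = '/results/checkpoints/mp_rank_00/megatron--last.ckpt'
model_parallel_rank = 1

# Same transformation as in start_training above:
# drop the mp_rank_XX directory, keep the filename, rebuild the path for this rank.
dirname = os.path.dirname(os.path.dirname(filepath))  # '/results/checkpoints'
basename = os.path.basename(filepath)                 # 'megatron--last.ckpt'
filepath = f'{dirname}/mp_rank_{model_parallel_rank:02d}/{basename}'

print(filepath)  # /results/checkpoints/mp_rank_01/megatron--last.ckpt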
Example #3
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """ LightningModule hook that's used to save things in addition to model weights. """

        if hasattr(self, "bert_model") and isinstance(self.bert_model,
                                                      MegatronBertEncoder):
            checkpoint['checkpoint_version'] = get_checkpoint_version()
        return None
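
Example #1 and Example #3 are the two halves of the same round trip: on_save_checkpoint stores the Megatron checkpoint version in the Lightning checkpoint dict, and on_load_checkpoint reads it back and hands it to set_checkpoint_version. A minimal sketch of that round trip with a plain dict and torch.save/torch.load (the filename and version value are illustrative):

import torch

# Save side: what on_save_checkpoint effectively adds to the checkpoint dict.
checkpoint = {'state_dict': {}}
checkpoint['checkpoint_version'] = 0  # would come from get_checkpoint_version()
torch.save(checkpoint, 'example.ckpt')

# Load side: what on_load_checkpoint reads back.
restored = torch.load('example.ckpt')
restored_version = restored['checkpoint_version']  # would be passed to set_checkpoint_version()
assert restored_version == 0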
Example #4
    def start_testing(self, trainer: 'Trainer') -> None:
        """ PTL Hook that is called after DPP is initialized. """
        app_state = AppState()

        if app_state.model_parallel_size is not None:

            if self.has_megatron_encoder:
                # check megatron checkpoint version
                checkpoint_version = get_checkpoint_version()
                if checkpoint_version is None:
                    raise ValueError("Unable to find megatron checkpoint version.")

        return super().start_testing(trainer)
Example #5
    def start_training(self, trainer: 'Trainer') -> None:
        """ PTL Hook that is called after DPP is initialized. """

        if isinstance(self.lightning_module.bert_model, MegatronBertEncoder):
            app_state = AppState()
            if app_state.model_parallel_size is not None:
                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self.lightning_module.parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                # TODO: figure out how to override clip gradients again
                # Update PTL trainer to use our _clip_gradients
                # self._trainer.accelerator_backend._clip_gradients = self._clip_gradients

                if get_checkpoint_version():
                    # Restored from .nemo, checkpoint_version will already be set
                    pass
                elif trainer.resume_from_checkpoint is not None:
                    # PTL auto-resuming, need to update checkpoint name
                    # update path based on model parallel rank
                    filepath = trainer.resume_from_checkpoint
                    dirname = os.path.dirname(os.path.dirname(filepath))
                    basename = os.path.basename(filepath)
                    filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                    trainer.resume_from_checkpoint = filepath
                    logging.info(
                        f'Resuming training from checkpoint {trainer.resume_from_checkpoint}'
                    )
                    # need to set checkpoint version for megatron-lm
                    checkpoint_version = torch.load(
                        trainer.resume_from_checkpoint).get(
                            'checkpoint_version', None)
                    if checkpoint_version is not None:
                        set_checkpoint_version(checkpoint_version)
                    else:
                        logging.warning(
                            'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                        )
                        set_checkpoint_version(0)
                else:
                    logging.info(
                        f"Restoring from pretrained model parallel checkpoint: {self.lightning_module.bert_model._restore_path}"
                    )
                    self.lightning_module.bert_model.restore_weights(
                        self.lightning_module.bert_model._restore_path)

            self.lightning_module.register_megatron_checkpoint_version()

        return super().start_training(trainer)
Example #6
    def register_megatron_checkpoint_version(self):
        """ Adds checkpoint version to .nemo archive """
        if self.has_megatron_encoder:
            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is None:
                raise ValueError('Unable to get megatron checkpoint version.')
            else:
                checkpoint_version_dict = {'checkpoint_version': checkpoint_version}
                checkpoint_version_path = 'megatron_checkpoint_version.json'
                checkpoint_version_src = os.path.join(NEMO_NLP_TMP, checkpoint_version_path)
                with open(checkpoint_version_src, 'w') as f:
                    f.write(json.dumps(checkpoint_version_dict))
                self.register_artifact(checkpoint_version_path, checkpoint_version_src)
        else:
            raise ValueError('Registering Megatron checkpoint version but no Megatron encoder detected.')
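
register_megatron_checkpoint_version only writes the version out; when a .nemo archive is restored, the same JSON artifact has to be read back so the version can be re-applied. A hypothetical read-side helper (the function name and argument are assumptions for illustration, not part of the examples):

import json

def read_megatron_checkpoint_version(artifact_path: str) -> int:
    """Hypothetical helper: read back the value written by register_megatron_checkpoint_version."""
    with open(artifact_path, 'r') as f:
        checkpoint_version = json.load(f)['checkpoint_version']
    # In the real restore path this value would be handed to set_checkpoint_version(...).
    return checkpoint_version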
Example #7
    def register_megatron_checkpoint_version(self):
        """ Adds checkpoint version to .nemo archive """
        if self.bert_model is None:
            raise ValueError('Instantiate self.bert_model before registering megatron checkpoint version.')
        else:
            # get encoder config and create source for artifact
            if isinstance(self.bert_model, MegatronBertEncoder):
                checkpoint_version = get_checkpoint_version()
                if checkpoint_version is None:
                    raise ValueError('Unable to get megatron checkpoint version.')
                else:
                    checkpoint_version_dict = {'checkpoint_version': checkpoint_version}
                    checkpoint_version_path = 'megatron_checkpoint_version.json'
                    checkpoint_version_src = os.path.join(NEMO_NLP_TMP, checkpoint_version_path)
                    with open(checkpoint_version_src, 'w') as f:
                        f.write(json.dumps(checkpoint_version_dict))
                    self.register_artifact(checkpoint_version_path, checkpoint_version_src)
Example #8
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if hasattr(self, "bert_model") and isinstance(self.bert_model, MegatronBertEncoder):
            checkpoint['checkpoint_version'] = get_checkpoint_version()
        return None
Example #9
    def forward(self, hidden_states, attention_mask, rotary_pos_emb=None, layer_past=None,
                get_key_value=False):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        checkpoint_version = get_checkpoint_version()
        if checkpoint_version is not None:
            if checkpoint_version == 0:
                # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
            elif checkpoint_version == 1.0:
                # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
                           (self.num_attention_heads_per_partition,
                            3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

        if exists(rotary_pos_emb):
            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, rotary_pos_emb)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        if not self.sparse:
            # ===================================
            # Raw attention scores. [b, np, s, s]
            # ===================================

            # [b, np, sq, sk]
            output_size = (query_layer.size(1),
                           query_layer.size(2),
                           query_layer.size(0),
                           key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2],
                                           output_size[0] * output_size[1], -1)
            key_layer = key_layer.view(output_size[3],
                                       output_size[0] * output_size[1], -1)

            # preallocating result tensor: [b * np, sq, sk]
            matmul_result = torch.empty(
                output_size[0] * output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=torch.cuda.current_device())

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(matmul_result,
                                          query_layer.transpose(0, 1),  # [b * np, sq, hn]
                                          key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                                          beta=0.0, alpha=(1.0 / self.norm_factor))

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ==================================================
            # Update attention mask for inference. [b, np, sq, sk]
            # ==================================================

            if get_key_value:
                with torch.no_grad():
                    if layer_past is not None:
                        attention_mask = attention_mask[
                                         ...,
                                         attention_scores.size(3) - 1,
                                         :attention_scores.size(3)].unsqueeze(2)
                    else:
                        attention_mask = attention_mask[
                                         ...,
                                         :attention_scores.size(3),
                                         :attention_scores.size(3)]

            # ===========================
            # Attention probs and dropout
            # ===========================

            if exists(self.rpe):
                rpe = self.rpe(query_layer.size(0), key_layer.size(0))
                attention_scores += rpe  # [1, np, sq, sk]

            # attention scores and attention mask [b, np, sq, sk]
            attention_probs = self.scale_mask_softmax(attention_scores,
                                                      attention_mask)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1),
                           value_layer.size(2),
                           query_layer.size(0),
                           value_layer.size(3))

            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0),
                                           output_size[0] * output_size[1], -1)

            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                   output_size[2], -1)

            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
        else:
            # shape of q/k/v is [sq, b, np, hn] and needs to be transposed to [b, np, sq, hn]
            query_layer, key_layer, value_layer = map(lambda t: t.permute(1, 2, 0, 3).contiguous(),
                                                      (query_layer, key_layer,
                                                       value_layer))
            # output shape [b, np(heads), sq, hn]
            attn_mask = attention_mask.to(query_layer.dtype) * -10000
            if exists(self.rpe):
                rpe = self.rpe(query_layer.size(0), key_layer.size(0))
            else:
                rpe = None
            context_layer = self.sparse_attn(query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
                                  (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias
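
The checkpoint_version branch at the top of forward exists because older Megatron checkpoints fuse the QKV projection with a different interleaving: version 0 lays the last dimension out as (3, np, hn) and version 1.0 as (np, hn, 3), while the rest of forward expects (np, 3, hn). A standalone sketch of the version-0 reordering that _transpose_last_dim performs, using illustrative sizes (the shapes follow the comments above; this is not the actual Megatron implementation):

import torch

s, b = 4, 2     # sequence length and batch size (illustrative)
np_, hn = 3, 5  # attention heads per partition and head size (illustrative)

# Version-0 layout: the fused last dim is ordered as (3, np, hn).
mixed_v0 = torch.randn(s, b, 3 * np_ * hn)

# Reorder to the layout the rest of forward expects, (np, 3, hn):
# view -> [s, b, 3, np, hn], swap the two middle dims, flatten back.
reordered = (
    mixed_v0.view(s, b, 3, np_, hn)
    .transpose(-2, -3)
    .contiguous()
    .view(s, b, np_ * 3 * hn)
)
assert reordered.shape == mixed_v0.shape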