Exemple #1
0
 def setup_test_data(self, test_data_config: Optional[DictConfig]):
     """
     Build the test dataloader(s) and attach one loss metric per dataloader.

     The metric for the first dataloader is stored as ``test_loss``; every further
     dataloader gets ``test_loss_<index>``. Nothing is attached when no dataloader
     was created.

     :param test_data_config: test data section of the config (may be None).
     """
     self._test_dl = self._setup_eval_dataloader_from_config(cfg=test_data_config)
     if self._test_dl is None:
         return
     for dataloader_idx in range(len(self._test_dl)):
         metric_name = 'test_loss' if dataloader_idx == 0 else f'test_loss_{dataloader_idx}'
         setattr(self, metric_name, GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True))
Exemple #2
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Assemble the language model: tokenizer, transformer encoder, tied output head,
        losses and evaluation metrics.

        :param cfg: model configuration; converted to a DictConfig and version-updated before use.
        :param trainer: optional trainer; only its node and GPU counts are read here.
        """
        # Total number of GPU workers for IterableDataset partitioning; 1 when there is no trainer.
        self.world_size = 1 if trainer is None else trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(model_utils.convert_model_config_to_dict_config(cfg))

        # The tokenizer is created before the base constructor runs; once this call returns,
        # self.tokenizer can convert between tokens and token ids.
        tokenizer_cfg = cfg.tokenizer
        self.setup_tokenizer(
            tokenizer_name=tokenizer_cfg.get("tokenizer_name", "yttm"),
            tokenizer_model=tokenizer_cfg.get("tokenizer_model", None),
            vocab_file=tokenizer_cfg.get("vocab_file", None),
            bpe_dropout=tokenizer_cfg.get("bpe_dropout", 0.0),
            special_tokens=tokenizer_cfg.get("special_tokens", {}),
        )

        super().__init__(cfg=cfg, trainer=trainer)

        # Round the vocabulary size up to a multiple of 8 for fast fp16 training.
        padded_vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # Encoder from NeMo, Megatron-LM, or HuggingFace. The selector keys are popped so that
        # only the remaining entries are forwarded as config_dict.
        encoder_kwargs = OmegaConf.to_container(cfg.get('encoder'))
        encoder_kwargs['vocab_size'] = padded_vocab_size
        library = encoder_kwargs.pop('library', 'nemo')
        model_name = encoder_kwargs.pop('model_name', None)
        pretrained = encoder_kwargs.pop('pretrained', False)
        # pre_ln_final_layer_norm falls back to the value of pre_ln, which itself defaults to True.
        final_layer_norm = encoder_kwargs.get('pre_ln_final_layer_norm', encoder_kwargs.get('pre_ln', True))
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_kwargs,
            encoder=True,
            pre_ln_final_layer_norm=final_layer_norm,
        )

        head_cfg = cfg.head
        self.log_softmax = TokenClassifier(
            hidden_size=self.encoder.hidden_size,
            num_classes=padded_vocab_size,
            activation=head_cfg.activation,
            log_softmax=head_cfg.log_softmax,
            dropout=head_cfg.dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        # Tie the weights of the embedding and softmax matrices.
        self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight

        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        def _init_weights(module):
            return transformer_weights_init(module, std_init_range)

        # The encoder is only re-initialized when it is not pretrained; the head always is.
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(_init_weights)
        self.log_softmax.apply(_init_weights)

        self.loss_fn = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id, label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id)
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)
        self.eval_ppl = SequencePerplexity()
Exemple #3
0
class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Assemble the language model: tokenizer, transformer encoder, tied output head,
        losses and evaluation metrics.

        :param cfg: model configuration; converted to a DictConfig and version-updated before use.
        :param trainer: optional trainer; only its node and GPU counts are read here.
        """
        # Total number of GPU workers for IterableDataset partitioning; 1 when there is no trainer.
        self.world_size = 1 if trainer is None else trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(model_utils.convert_model_config_to_dict_config(cfg))

        # The tokenizer is created before the base constructor runs; once this call returns,
        # self.tokenizer can convert between tokens and token ids.
        tokenizer_cfg = cfg.tokenizer
        self.setup_tokenizer(
            tokenizer_name=tokenizer_cfg.get("tokenizer_name", "yttm"),
            tokenizer_model=tokenizer_cfg.get("tokenizer_model", None),
            vocab_file=tokenizer_cfg.get("vocab_file", None),
            bpe_dropout=tokenizer_cfg.get("bpe_dropout", 0.0),
            special_tokens=tokenizer_cfg.get("special_tokens", {}),
        )

        super().__init__(cfg=cfg, trainer=trainer)

        # Round the vocabulary size up to a multiple of 8 for fast fp16 training.
        padded_vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # Encoder from NeMo, Megatron-LM, or HuggingFace. The selector keys are popped so that
        # only the remaining entries are forwarded as config_dict.
        encoder_kwargs = OmegaConf.to_container(cfg.get('encoder'))
        encoder_kwargs['vocab_size'] = padded_vocab_size
        library = encoder_kwargs.pop('library', 'nemo')
        model_name = encoder_kwargs.pop('model_name', None)
        pretrained = encoder_kwargs.pop('pretrained', False)
        # pre_ln_final_layer_norm falls back to the value of pre_ln, which itself defaults to True.
        final_layer_norm = encoder_kwargs.get('pre_ln_final_layer_norm', encoder_kwargs.get('pre_ln', True))
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_kwargs,
            encoder=True,
            pre_ln_final_layer_norm=final_layer_norm,
        )

        head_cfg = cfg.head
        self.log_softmax = TokenClassifier(
            hidden_size=self.encoder.hidden_size,
            num_classes=padded_vocab_size,
            activation=head_cfg.activation,
            log_softmax=head_cfg.log_softmax,
            dropout=head_cfg.dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        # Tie the weights of the embedding and softmax matrices.
        self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight

        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        def _init_weights(module):
            return transformer_weights_init(module, std_init_range)

        # The encoder is only re-initialized when it is not pretrained; the head always is.
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(_init_weights)
        self.log_softmax.apply(_init_weights)

        self.loss_fn = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id, label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id)
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)
        self.eval_ppl = SequencePerplexity()

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        Run the encoder on the input ids and feed its hidden states to the output head.

        :param input_ids: token ids passed to the encoder as ``input_ids``.
        :param attention_mask: mask passed to the encoder as ``encoder_mask``.
        :return: the output of ``self.log_softmax`` for the encoder hidden states.
        """
        encoded = self.encoder(input_ids=input_ids, encoder_mask=attention_mask)
        return self.log_softmax(hidden_states=encoded)

    def training_step(self, batch, batch_idx):
        """
        Compute the training loss for one batch from the training dataloader.

        :param batch: mutable sequence ``[ids, mask]``; 3-D entries are squeezed in place on dim 0.
        :param batch_idx: index of the batch (unused).
        :return: dict with the loss under ``"loss"`` and the values to log under ``"log"``.
        """
        # The dataset already returns batched data, so the leading dimension of size 1
        # added by the DataLoader is dropped.
        for position, tensor in enumerate(batch):
            if tensor.ndim == 3:
                batch[position] = tensor.squeeze(dim=0)

        ids, mask = batch
        # Inputs are all tokens but the last; labels are the same sequence shifted by one.
        input_ids = ids[:, :-1]
        labels = ids[:, 1:]
        log_probs = self(input_ids=input_ids, attention_mask=mask[:, :-1])

        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)

        current_lr = self._optimizer.param_groups[0]["lr"]
        return {
            "loss": train_loss,
            "log": {"train_loss": train_loss, "lr": current_lr},
        }

    def eval_step(self, batch, batch_idx):
        """
        Update the evaluation loss and perplexity metrics with one batch.

        :param batch: mutable sequence ``[ids, mask]``; 3-D entries are squeezed in place on dim 0.
        :param batch_idx: index of the batch (unused).
        :return: an empty dict; results are accumulated in ``self.eval_loss`` and ``self.eval_ppl``.
        """
        # The dataset already returns batched data, so the leading dimension of size 1
        # added by the DataLoader is dropped.
        for position, tensor in enumerate(batch):
            if tensor.ndim == 3:
                batch[position] = tensor.squeeze(dim=0)

        ids, mask = batch
        input_ids = ids[:, :-1]
        labels = ids[:, 1:]
        input_mask = mask[:, :-1]
        output_mask = mask[:, 1:]

        log_probs = self(input_ids=input_ids, attention_mask=input_mask)
        eval_loss = self.eval_loss_fn(log_probs=log_probs, labels=labels)

        # The loss is weighted by the product of the first two dimensions of log_probs.
        measurements = log_probs.shape[0] * log_probs.shape[1]
        self.eval_loss(loss=eval_loss, num_measurements=measurements)
        self.eval_ppl(log_probs=log_probs, labels=labels, mask=output_mask)
        return {}

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """
        Process one batch from the validation dataloader by delegating to ``eval_step``.
        Lightning calls this inside the validation loop.
        """
        step_output = self.eval_step(batch, batch_idx)
        return step_output

    def eval_epoch_end(self, outputs, mode):
        """
        Compute the accumulated loss and perplexity, log both, and print the perplexity.

        :param outputs: per-step outputs (unused; the values come from the metric objects).
        :param mode: prefix for the logged names; ``'val'`` is reported as "Validation",
            anything else as "Test".
        """
        loss_value = self.eval_loss.compute()
        ppl_value = self.eval_ppl.compute()
        for suffix, value in (("loss", loss_value), ("PPL", ppl_value)):
            self.log(f"{mode}_{suffix}", value, sync_dist=True)

        if mode == 'val':
            dataset_name = "Validation"
        else:
            dataset_name = "Test"
        rounded_ppl = np.round(ppl_value.item(), 2)
        logging.info(f"\n\n\n\n{dataset_name} PPL: {rounded_ppl}")

    def validation_epoch_end(self, outputs):
        """
        Report the validation metrics at the end of validation, then clear them.

        :param outputs: list of individual outputs of each validation step.
        """
        self.eval_epoch_end(outputs, 'val')
        # Clear the accumulated state so the next validation run starts from scratch.
        for metric in (self.eval_loss, self.eval_ppl):
            metric.reset()

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, 'test')

    def setup_tokenizer(self,
                        tokenizer_name=None,
                        tokenizer_model=None,
                        vocab_file=None,
                        bpe_dropout=0.0,
                        special_tokens=None):
        """
        Create the tokenizer and store it as ``self.tokenizer``.

        The tokenizer model path is registered as the ``cfg.tokenizer.tokenizer_model``
        artifact and the registered value is what gets passed to ``get_tokenizer``.

        :param tokenizer_name: one of 'yttm', 'huggingface', 'sentencepiece', 'word'.
        :param tokenizer_model: path to the tokenizer model.
        :param vocab_file: vocabulary file forwarded to ``get_tokenizer``.
        :param bpe_dropout: BPE dropout value forwarded to ``get_tokenizer``.
        :param special_tokens: mapping of special tokens; None is treated as empty.
        :raises NotImplementedError: if ``tokenizer_name`` is not supported.
        """
        allowed_names = ['yttm', 'huggingface', 'sentencepiece', 'word']
        if tokenizer_name not in allowed_names:
            raise NotImplementedError(f"Currently we only support tokenizers in {allowed_names}.")

        registered_model = self.register_artifact("cfg.tokenizer.tokenizer_model", tokenizer_model)
        # Copy into a plain dict so a config object (or None) is never passed through as-is.
        special_tokens_dict = dict(special_tokens or {})
        self.tokenizer = get_tokenizer(
            tokenizer_name=tokenizer_name,
            tokenizer_model=registered_model,
            vocab_file=vocab_file,
            bpe_dropout=bpe_dropout,
            special_tokens=special_tokens_dict,
            use_fast=False,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):

        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            else:
                metadata_file = cfg.get('metadata_file')
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if cfg.get('tar_files') is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                    else:
                        raise FileNotFoundError(
                            "Could not find tarred dataset in config or metadata."
                        )
                else:
                    tar_files = cfg.get('tar_files')
                    if metadata.get('tar_files') is not None:
                        raise ValueError(
                            'Tar files specified in config and in metadata file. Tar files should only be specified once.'
                        )
            dataset = TarredSentenceDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                tokenizer=self.tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = SentenceDataset(
                tokenizer=self.tokenizer,
                dataset=cfg.file_name,
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                cache_ids=cfg.get("cache_ids", False),
            )
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
Exemple #4
0
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        """
        Assemble the translation model: tokenizers, text processors, encoder, decoder,
        output head, beam search generator, training loss and evaluation loss metric.

        :param cfg: model configuration; converted to a DictConfig and version-updated before use.
        :param trainer: optional trainer; only its node and GPU counts are read here.
        """

        def _optional(section, key, fallback):
            # Attribute value when hasattr() reports it, otherwise the fallback.
            return getattr(section, key) if hasattr(section, key) else fallback

        cfg = model_utils.convert_model_config_to_dict_config(cfg)

        # Total number of GPU workers for IterableDataset partitioning, if applicable.
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
        self.world_size = 1 if trainer is None else trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # Create the tokenizers before the base constructor runs. After this call there are
        # self.encoder_tokenizer and self.decoder_tokenizer, which convert between tokens and
        # token ids for the SRC and TGT languages respectively.
        enc_tok_cfg = cfg.encoder_tokenizer
        dec_tok_cfg = cfg.decoder_tokenizer
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_name=enc_tok_cfg.tokenizer_name,
            encoder_tokenizer_model=enc_tok_cfg.tokenizer_model,
            encoder_bpe_dropout=enc_tok_cfg.get('bpe_dropout', 0.0),
            decoder_tokenizer_name=dec_tok_cfg.tokenizer_name,
            decoder_tokenizer_model=dec_tok_cfg.tokenizer_model,
            decoder_bpe_dropout=dec_tok_cfg.get('bpe_dropout', 0.0),
        )

        # After this call the model has self.source_processor and self.target_processor objects.
        self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language)

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # TODO: use get_encoder function with support for HF and Megatron
        enc_cfg = cfg.encoder
        self.encoder = TransformerEncoderNM(
            vocab_size=self.encoder_vocab_size,
            hidden_size=enc_cfg.hidden_size,
            num_layers=enc_cfg.num_layers,
            inner_size=enc_cfg.inner_size,
            max_sequence_length=_optional(enc_cfg, 'max_sequence_length', 512),
            embedding_dropout=_optional(enc_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=_optional(enc_cfg, 'learn_positional_encodings', False),
            num_attention_heads=enc_cfg.num_attention_heads,
            ffn_dropout=enc_cfg.ffn_dropout,
            attn_score_dropout=enc_cfg.attn_score_dropout,
            attn_layer_dropout=enc_cfg.attn_layer_dropout,
            hidden_act=enc_cfg.hidden_act,
            mask_future=enc_cfg.mask_future,
            pre_ln=enc_cfg.pre_ln,
        )

        # TODO: user get_decoder function with support for HF and Megatron
        dec_cfg = cfg.decoder
        self.decoder = TransformerDecoderNM(
            vocab_size=self.decoder_vocab_size,
            hidden_size=dec_cfg.hidden_size,
            num_layers=dec_cfg.num_layers,
            inner_size=dec_cfg.inner_size,
            max_sequence_length=_optional(dec_cfg, 'max_sequence_length', 512),
            embedding_dropout=_optional(dec_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=_optional(dec_cfg, 'learn_positional_encodings', False),
            num_attention_heads=dec_cfg.num_attention_heads,
            ffn_dropout=dec_cfg.ffn_dropout,
            attn_score_dropout=dec_cfg.attn_score_dropout,
            attn_layer_dropout=dec_cfg.attn_layer_dropout,
            hidden_act=dec_cfg.hidden_act,
            pre_ln=dec_cfg.pre_ln,
        )

        head_cfg = cfg.head
        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=head_cfg.activation,
            log_softmax=head_cfg.log_softmax,
            dropout=head_cfg.dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        target_tokenizer = self.decoder_tokenizer
        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=target_tokenizer.bos_id,
            pad=target_tokenizer.pad_id,
            eos=target_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # Tie the weights of the embedding and softmax matrices.
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        def _init_weights(module):
            return transformer_weights_init(module, std_init_range)

        self.apply(_init_weights)

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)
Exemple #5
0
class MTEncDecModel(EncDecNLPModel):
    """
    Encoder-decoder machine translation model.
    """
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        """
        Assemble the translation model: tokenizers, text processors, encoder, decoder,
        output head, beam search generator, training loss and evaluation loss metric.

        :param cfg: model configuration; converted to a DictConfig and version-updated before use.
        :param trainer: optional trainer; only its node and GPU counts are read here.
        """

        def _optional(section, key, fallback):
            # Attribute value when hasattr() reports it, otherwise the fallback.
            return getattr(section, key) if hasattr(section, key) else fallback

        cfg = model_utils.convert_model_config_to_dict_config(cfg)

        # Total number of GPU workers for IterableDataset partitioning, if applicable.
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
        self.world_size = 1 if trainer is None else trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # Create the tokenizers before the base constructor runs. After this call there are
        # self.encoder_tokenizer and self.decoder_tokenizer, which convert between tokens and
        # token ids for the SRC and TGT languages respectively.
        enc_tok_cfg = cfg.encoder_tokenizer
        dec_tok_cfg = cfg.decoder_tokenizer
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_name=enc_tok_cfg.tokenizer_name,
            encoder_tokenizer_model=enc_tok_cfg.tokenizer_model,
            encoder_bpe_dropout=enc_tok_cfg.get('bpe_dropout', 0.0),
            decoder_tokenizer_name=dec_tok_cfg.tokenizer_name,
            decoder_tokenizer_model=dec_tok_cfg.tokenizer_model,
            decoder_bpe_dropout=dec_tok_cfg.get('bpe_dropout', 0.0),
        )

        # After this call the model has self.source_processor and self.target_processor objects.
        self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language)

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # TODO: use get_encoder function with support for HF and Megatron
        enc_cfg = cfg.encoder
        self.encoder = TransformerEncoderNM(
            vocab_size=self.encoder_vocab_size,
            hidden_size=enc_cfg.hidden_size,
            num_layers=enc_cfg.num_layers,
            inner_size=enc_cfg.inner_size,
            max_sequence_length=_optional(enc_cfg, 'max_sequence_length', 512),
            embedding_dropout=_optional(enc_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=_optional(enc_cfg, 'learn_positional_encodings', False),
            num_attention_heads=enc_cfg.num_attention_heads,
            ffn_dropout=enc_cfg.ffn_dropout,
            attn_score_dropout=enc_cfg.attn_score_dropout,
            attn_layer_dropout=enc_cfg.attn_layer_dropout,
            hidden_act=enc_cfg.hidden_act,
            mask_future=enc_cfg.mask_future,
            pre_ln=enc_cfg.pre_ln,
        )

        # TODO: user get_decoder function with support for HF and Megatron
        dec_cfg = cfg.decoder
        self.decoder = TransformerDecoderNM(
            vocab_size=self.decoder_vocab_size,
            hidden_size=dec_cfg.hidden_size,
            num_layers=dec_cfg.num_layers,
            inner_size=dec_cfg.inner_size,
            max_sequence_length=_optional(dec_cfg, 'max_sequence_length', 512),
            embedding_dropout=_optional(dec_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=_optional(dec_cfg, 'learn_positional_encodings', False),
            num_attention_heads=dec_cfg.num_attention_heads,
            ffn_dropout=dec_cfg.ffn_dropout,
            attn_score_dropout=dec_cfg.attn_score_dropout,
            attn_layer_dropout=dec_cfg.attn_layer_dropout,
            hidden_act=dec_cfg.hidden_act,
            pre_ln=dec_cfg.pre_ln,
        )

        head_cfg = cfg.head
        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=head_cfg.activation,
            log_softmax=head_cfg.log_softmax,
            dropout=head_cfg.dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        target_tokenizer = self.decoder_tokenizer
        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=target_tokenizer.bos_id,
            pad=target_tokenizer.pad_id,
            eos=target_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # Tie the weights of the embedding and softmax matrices.
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        def _init_weights(module):
            return transformer_weights_init(module, std_init_range)

        self.apply(_init_weights)

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)

    def filter_predicted_ids(self, ids):
        """
        Replace, in place, every id at or above the decoder tokenizer's vocabulary size
        with the tokenizer's unknown-token id.

        :param ids: array-like of predicted ids supporting boolean-mask assignment.
        :return: the same ``ids`` object after modification.
        """
        tokenizer = self.decoder_tokenizer
        out_of_vocab = ids >= tokenizer.vocab_size
        ids[out_of_vocab] = tokenizer.unk_id
        return ids

    @typecheck()
    def forward(self, src, src_mask, tgt, tgt_mask):
        """
        Encode the source, decode the target against the encoder states, and apply the output head.

        :param src: source input passed to the encoder.
        :param src_mask: source mask, used by both the encoder and the decoder.
        :param tgt: target input passed to the decoder.
        :param tgt_mask: target mask passed to the decoder.
        :return: the output of ``self.log_softmax`` for the decoder hidden states.
        """
        encoder_states = self.encoder(src, src_mask)
        decoder_states = self.decoder(tgt, tgt_mask, encoder_states, src_mask)
        return self.log_softmax(hidden_states=decoder_states)

    def training_step(self, batch, batch_idx):
        """
        Compute the training loss for one batch from the training dataloader.

        :param batch: mutable sequence ``[src_ids, src_mask, tgt_ids, tgt_mask, labels]``;
            3-D entries are squeezed in place on dim 0.
        :param batch_idx: index of the batch (unused).
        :return: dict with the loss under ``'loss'`` and the values to log under ``'log'``.
        """
        # The dataset already returns batched data, so the leading dimension of size 1
        # added by the DataLoader is dropped.
        for position, tensor in enumerate(batch):
            if tensor.ndim == 3:
                batch[position] = tensor.squeeze(dim=0)

        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)
        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)

        current_lr = self._optimizer.param_groups[0]['lr']
        return {
            'loss': train_loss,
            'log': {'train_loss': train_loss, 'lr': current_lr},
        }

    def eval_step(self, batch, batch_idx, mode):
        """
        Evaluate one batch: update the evaluation loss metric and produce translations.

        :param batch: mutable sequence ``[src_ids, src_mask, tgt_ids, tgt_mask, labels]``;
            3-D entries are squeezed in place on dim 0.
        :param batch_idx: index of the batch (unused).
        :param mode: evaluation mode name (unused in this method).
        :return: dict with ``'translations'``, detokenized ``'ground_truths'`` and
            ``'num_non_pad_tokens'`` (count of target ids different from the pad id).
        """
        # The dataset already returns batched data, so the leading dimension of size 1
        # added by the DataLoader is dropped.
        for position, tensor in enumerate(batch):
            if tensor.ndim == 3:
                batch[position] = tensor.squeeze(dim=0)

        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)

        # this will run encoder twice -- TODO: potentially fix
        _, translations = self.batch_translate(src=src_ids, src_mask=src_mask)

        eval_loss = self.loss_fn(log_probs=log_probs, labels=labels)
        measurements = log_probs.shape[0] * log_probs.shape[1]
        self.eval_loss(loss=eval_loss, num_measurements=measurements)

        # Turn the target ids back into text, then detokenize each sentence.
        target_array = tgt_ids.cpu().numpy()
        decoded_targets = [self.decoder_tokenizer.ids_to_text(row) for row in target_array]
        ground_truths = [self.target_processor.detokenize(text.split(' ')) for text in decoded_targets]

        non_pad_count = (target_array != self.decoder_tokenizer.pad_id).sum().item()
        return {
            'translations': translations,
            'ground_truths': ground_truths,
            'num_non_pad_tokens': non_pad_count,
        }

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, 'test')

    @rank_zero_only
    def log_param_stats(self):
        """
        Log a histogram and summary scalars (mean, stddev, max, min) for every parameter
        that requires gradients, through the trainer's logger experiment at the current
        global step.
        """
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            experiment = self.trainer.logger.experiment
            experiment.add_histogram(f"{name}_hist", param, global_step=self.global_step)
            summary = {
                'mean': param.mean(),
                'stddev': param.std(),
                'max': param.max(),
                'min': param.min(),
            }
            experiment.add_scalars(name, summary, global_step=self.global_step)

    def validation_step(self, batch, batch_idx):
        """
        Process one batch from the validation dataloader by delegating to ``eval_step``
        in ``'val'`` mode. Lightning calls this inside the validation loop.
        """
        step_output = self.eval_step(batch, batch_idx, 'val')
        return step_output

    def eval_epoch_end(self, outputs, mode):
        """
        Aggregate the step outputs of an evaluation epoch into loss and SacreBLEU.

        :param outputs: list of dicts returned by ``eval_step``; their ``'translations'`` and
            ``'ground_truths'`` lists are concatenated in order.
        :param mode: prefix for the result keys; ``'val'`` is reported as "Validation",
            anything else as "Test".
        :return: dict with ``f"{mode}_loss"``, ``f"{mode}_sacreBLEU"`` and a ``'log'`` entry
            holding a copy of those two values.
        """
        eval_loss = self.eval_loss.compute()
        # compute() is called manually, so the accumulated state has to be cleared manually too;
        # otherwise every later epoch would average over all previous epochs as well.
        self.eval_loss.reset()

        translations = list(itertools.chain.from_iterable(x['translations'] for x in outputs))
        ground_truths = list(itertools.chain.from_iterable(x['ground_truths'] for x in outputs))

        assert len(translations) == len(ground_truths)

        # Japanese and Chinese targets need their own tokenizers for BLEU scoring.
        if self.tgt_language in ['ja']:
            tokenize = "ja-mecab"
        elif self.tgt_language in ['zh']:
            tokenize = "zh"
        else:
            tokenize = "13a"
        sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize=tokenize)

        dataset_name = "Validation" if mode == 'val' else "Test"
        logging.info(f"\n\n\n\n{dataset_name} set size: {len(translations)}")
        logging.info(f"{dataset_name} Sacre BLEU = {sacre_bleu.score}")
        logging.info(f"{dataset_name} TRANSLATION EXAMPLES:".upper())
        # Sampling from an empty list would make random.randint(0, -1) raise, so the
        # examples are only printed when there is something to show.
        if translations:
            for i in range(3):
                ind = random.randint(0, len(translations) - 1)
                logging.info("    " + '\u0332'.join(f"EXAMPLE {i}:"))
                logging.info(f"    Prediction:   {translations[ind]}")
                logging.info(f"    Ground Truth: {ground_truths[ind]}")

        ans = {
            f"{mode}_loss": eval_loss,
            f"{mode}_sacreBLEU": sacre_bleu.score,
        }
        ans['log'] = dict(ans)
        return ans

    def validation_epoch_end(self, outputs):
        """
        Aggregate the validation outputs and log the resulting values.

        :param outputs: list of individual outputs of each validation step.
        """
        epoch_results = self.eval_epoch_end(outputs, 'val')
        self.log_dict(epoch_results)

    def test_epoch_end(self, outputs):
        return self.eval_epoch_end(outputs, 'test')

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """
        Build the validation dataloader from ``val_data_config`` and store it in ``self._validation_dl``.
        """
        validation_loader = self._setup_dataloader_from_config(cfg=val_data_config)
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """
        Build the test dataloader from ``test_data_config`` and store it in ``self._test_dl``.
        """
        test_loader = self._setup_dataloader_from_config(cfg=test_data_config)
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a ``torch.utils.data.DataLoader`` for the dataset described by ``cfg``.

        The dataset source is selected in this order:

        1. ``load_from_cached_dataset``: a pickled dataset is loaded from ``cfg.src_file_name``
           (which must be equal to ``cfg.tgt_file_name``).
        2. ``use_tarred_dataset``: a ``TarredTranslationDataset`` is created from ``cfg.tar_files``
           and ``cfg.metadata_file``; its loader is returned immediately, without a sampler.
        3. otherwise: a ``TranslationDataset`` is created from the source/target files and
           batchified with the encoder and decoder tokenizers.

        For cases 1 and 3 a random sampler is used when ``cfg.shuffle`` is true, a sequential one
        otherwise. Every loader is created with ``batch_size=1``
        (presumably because the datasets already yield whole batches -- TODO confirm).

        Args:
            cfg: dataset section of the model configuration.
        Returns:
            the configured ``torch.utils.data.DataLoader``.
        Raises:
            ValueError: cached dataset requested but ``src_file_name != tgt_file_name``.
            FileNotFoundError: tarred dataset requested but ``tar_files`` or ``metadata_file`` is missing.
        """
        if cfg.get("load_from_cached_dataset", False):
            logging.info('Loading from cached dataset %s' % (cfg.src_file_name))
            if cfg.src_file_name != cfg.tgt_file_name:
                raise ValueError("src must be equal to target for cached dataset")
            # NOTE: unpickling can execute arbitrary code; only load cache files from a trusted source.
            # The context manager guarantees the file handle is closed, even if unpickling fails.
            with open(cfg.src_file_name, 'rb') as cached_file:
                dataset = pickle.load(cached_file)
            dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
        elif cfg.get("use_tarred_dataset", False):
            if cfg.get('tar_files') is None:
                raise FileNotFoundError("Could not find tarred dataset.")
            logging.info(f'Loading from tarred dataset {cfg.get("tar_files")}')
            if cfg.get("metadata_file", None) is None:
                raise FileNotFoundError("Could not find metadata path in config")
            dataset = TarredTranslationDataset(
                text_tar_filepaths=cfg.tar_files,
                metadata_path=cfg.metadata_file,
                encoder_tokenizer=self.encoder_tokenizer,
                decoder_tokenizer=self.decoder_tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            # The tarred dataset gets no sampler; its loader is returned right away.
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = TranslationDataset(
                dataset_src=str(Path(cfg.src_file_name).expanduser()),
                dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )

    def setup_pre_and_post_processing_utils(self, source_lang, target_lang):
        """
        Creates source and target processor objects for input and output pre/post-processing.

        Both processors are first reset to None. For the en<->ja pairs an ``EnJaProcessor`` is
        created for each side. Otherwise each side is handled independently: ``'zh'`` gets a
        ``ChineseProcessor``, any other language except None and ``'ja'`` gets a
        ``MosesProcessor``, and the remaining cases keep None.

        Args:
            source_lang: language code of the input text, or None
            target_lang: language code of the output text, or None
        """
        self.source_processor = None
        self.target_processor = None

        if (source_lang, target_lang) in (('en', 'ja'), ('ja', 'en')):
            self.source_processor = EnJaProcessor(source_lang)
            self.target_processor = EnJaProcessor(target_lang)
            return

        if source_lang == 'zh':
            self.source_processor = ChineseProcessor()
        if target_lang == 'zh':
            self.target_processor = ChineseProcessor()
        if source_lang not in (None, 'ja', 'zh'):
            self.source_processor = MosesProcessor(source_lang)
        if target_lang not in (None, 'ja', 'zh'):
            self.target_processor = MosesProcessor(target_lang)

    @torch.no_grad()
    def batch_translate(
        self,
        src: torch.LongTensor,
        src_mask: torch.LongTensor,
    ):
        """
        Translates a minibatch of inputs from source language to target language.

        The model is put in eval mode for the duration of the call; its previous
        training/eval mode is restored afterwards, even when an exception is raised.

        Args:
            src: minibatch of token ids in the source language (batch x seq_len)
            src_mask: mask tensor that accompanies ``src`` (batch x seq_len)
        Returns:
            a tuple ``(inputs, translations)``, in this order:
                inputs: list of strings with the detokenized inputs
                translations: list of strings with the detokenized translations
        """
        was_training = self.training
        try:
            self.eval()

            encoder_states = self.encoder(input_ids=src, encoder_mask=src_mask)
            predicted_ids = self.beam_search(encoder_hidden_states=encoder_states, encoder_input_mask=src_mask)
            predicted_ids = self.filter_predicted_ids(predicted_ids)

            # Convert id sequences back to text, one string per sentence.
            translations = []
            for id_row in predicted_ids.cpu().numpy():
                translations.append(self.decoder_tokenizer.ids_to_text(id_row))
            inputs = []
            for id_row in src.cpu().numpy():
                inputs.append(self.encoder_tokenizer.ids_to_text(id_row))

            # Optional language-specific detokenization of both sides.
            if self.target_processor is not None:
                translations = [self.target_processor.detokenize(sentence.split(' ')) for sentence in translations]

            if self.source_processor is not None:
                inputs = [self.source_processor.detokenize(sentence.split(' ')) for sentence in inputs]

        finally:
            self.train(mode=was_training)
        return inputs, translations

    # TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language
    @torch.no_grad()
    def translate(self,
                  text: List[str],
                  source_lang: str = None,
                  target_lang: str = None) -> List[str]:
        """
        Translates list of sentences from source language to target language.
        Should be regular text, this method performs its own tokenization/de-tokenization.

        If ``source_lang`` or ``target_lang`` is given, BOTH processors are rebuilt for that pair
        through ``setup_pre_and_post_processing_utils``; otherwise the processors already stored
        on the model are used. The model is put in eval mode while translating and its previous
        mode is restored afterwards.

        Args:
            text: list of strings to translate; an empty list yields an empty list
            source_lang: language of ``text``, or None to keep the current source processor
            target_lang: language to translate into, or None to keep the current target processor
        Returns:
            list of translated strings, one per input string
        """
        # __TODO__: This will reset both source and target processors even if you want to reset just one.
        if source_lang is not None or target_lang is not None:
            self.setup_pre_and_post_processing_utils(source_lang, target_lang)

        # Nothing to translate: without this guard max() below raises ValueError on an empty batch.
        if not text:
            return []

        mode = self.training
        try:
            self.eval()
            inputs = []
            for txt in text:
                if self.source_processor is not None:
                    txt = self.source_processor.normalize(txt)
                    txt = self.source_processor.tokenize(txt)
                ids = self.encoder_tokenizer.text_to_ids(txt)
                ids = [self.encoder_tokenizer.bos_id] + ids + [self.encoder_tokenizer.eos_id]
                inputs.append(ids)

            # Right-pad every id sequence with pad_id up to the longest sequence in the batch.
            max_len = max(len(ids) for ids in inputs)
            src_ids_ = np.ones((len(inputs), max_len)) * self.encoder_tokenizer.pad_id
            for i, ids in enumerate(inputs):
                src_ids_[i][:len(ids)] = ids

            # The mask is 1 for real tokens and 0 for padding positions.
            src_mask = torch.FloatTensor((src_ids_ != self.encoder_tokenizer.pad_id)).to(self.device)
            src = torch.LongTensor(src_ids_).to(self.device)
            _, translations = self.batch_translate(src, src_mask)
        finally:
            self.train(mode=mode)
        return translations

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """No models are listed by this class: always returns None."""
        return None
Exemple #6
0
def _loss_class_test(
    rank: int,
    worldsize: int,
    loss_sum_or_avg: Optional[torch.Tensor],
    num_measurements: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    take_avg_loss: bool,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """ Utility function comparing ``GlobalAverageLossMetric`` against ``reference_loss_func``.

        The metric is pickled and restored before use, fed with every ``worldsize``-th batch
        starting at ``rank``, checked per batch, and finally checked on all batches.

        Args:
            rank: rank of current process
            worldsize: number of processes
            loss_sum_or_avg: a one dimensional float torch tensor with loss sums or means.
            num_measurements: a one dimensional integer torch tensor with number of values on which sums or means
                from ``loss_sum_or_avg`` were computed.
            dist_sync_on_step: bool, forwarded to the metric constructor; if true the metric synchronizes its
                state across processes at each call of the method :meth:`forward()`
            take_avg_loss: bool, forwarded both to the metric constructor and to ``reference_loss_func``
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
            atol: absolute tolerance used by ``np.allclose`` in every comparison
    """
    # Instantiate the metric and verify it still works after a pickle round-trip.
    metric = pickle.loads(
        pickle.dumps(
            GlobalAverageLossMetric(
                compute_on_step=True, dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss
            )
        )
    )

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = metric(loss_sum_or_avg[i], num_measurements[i])
        if metric.dist_sync_on_step:
            # With per-step sync only rank 0 compares against the reference.
            if rank != 0:
                continue
            synced_losses = torch.stack([loss_sum_or_avg[i + r] for r in range(worldsize)])
            synced_counts = torch.stack([num_measurements[i + r] for r in range(worldsize)])
            sk_batch_result = reference_loss_func(synced_losses, synced_counts, take_avg_loss)
            comparison_enabled = check_dist_sync_on_step
        else:
            # Without sync the reference is computed on this process' single batch.
            local_losses = loss_sum_or_avg[i : i + 1]
            local_counts = num_measurements[i : i + 1]
            sk_batch_result = reference_loss_func(local_losses, local_counts, take_avg_loss)
            comparison_enabled = check_batch

        if not comparison_enabled:
            continue
        if sk_batch_result.isnan():
            assert batch_result.isnan()
        else:
            assert np.allclose(
                batch_result.numpy(), sk_batch_result, atol=atol
            ), f"batch_result = {batch_result.numpy()}, sk_batch_result = {sk_batch_result}, i = {i}"

    # check on all batches on all ranks
    result = metric.compute()
    assert isinstance(result, torch.Tensor)
    sk_result = reference_loss_func(loss_sum_or_avg, num_measurements, take_avg_loss)

    # assert after aggregation
    if sk_result.isnan():
        assert result.isnan()
    else:
        assert np.allclose(
            result.numpy(), sk_result, atol=atol
        ), f"result = {result.numpy()}, sk_result = {sk_result}"
Exemple #7
0
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        """
        Build the translation model (tokenizers, encoder, decoder, output head, beam search and
        losses) from ``cfg``.

        Args:
            cfg: model configuration; it is converted to a DictConfig and passed through
                ``model_utils.maybe_update_config_version`` before use.
            trainer: optional trainer. When given, its node/GPU counts and ranks determine
                ``self.global_rank`` and ``self.world_size``; otherwise they are 0 and 1.
        """
        cfg = model_utils.convert_model_config_to_dict_config(cfg)

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        if trainer is None:
            self.global_rank, self.world_size = 0, 1
        else:
            self.global_rank = trainer.node_rank * trainer.num_gpus + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)
        self.setup_enc_dec_tokenizers(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        # TODO: use get_encoder function with support for HF and Megatron
        enc_cfg = cfg.encoder
        self.encoder = TransformerEncoderNM(
            vocab_size=self.encoder_vocab_size,
            hidden_size=enc_cfg.hidden_size,
            num_layers=enc_cfg.num_layers,
            inner_size=enc_cfg.inner_size,
            # Optional settings fall back to a default when absent from the config.
            max_sequence_length=getattr(enc_cfg, 'max_sequence_length', 512),
            embedding_dropout=getattr(enc_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=getattr(enc_cfg, 'learn_positional_encodings', False),
            num_attention_heads=enc_cfg.num_attention_heads,
            ffn_dropout=enc_cfg.ffn_dropout,
            attn_score_dropout=enc_cfg.attn_score_dropout,
            attn_layer_dropout=enc_cfg.attn_layer_dropout,
            hidden_act=enc_cfg.hidden_act,
            mask_future=enc_cfg.mask_future,
            pre_ln=enc_cfg.pre_ln,
        )

        # TODO: user get_decoder function with support for HF and Megatron
        dec_cfg = cfg.decoder
        self.decoder = TransformerDecoderNM(
            vocab_size=self.decoder_vocab_size,
            hidden_size=dec_cfg.hidden_size,
            num_layers=dec_cfg.num_layers,
            inner_size=dec_cfg.inner_size,
            max_sequence_length=getattr(dec_cfg, 'max_sequence_length', 512),
            embedding_dropout=getattr(dec_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=getattr(dec_cfg, 'learn_positional_encodings', False),
            num_attention_heads=dec_cfg.num_attention_heads,
            ffn_dropout=dec_cfg.ffn_dropout,
            attn_score_dropout=dec_cfg.attn_score_dropout,
            attn_layer_dropout=dec_cfg.attn_layer_dropout,
            hidden_act=dec_cfg.hidden_act,
            pre_ln=dec_cfg.pre_ln,
        )

        head_cfg = cfg.head
        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=head_cfg.activation,
            log_softmax=head_cfg.log_softmax,
            dropout=head_cfg.dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        target_tokenizer = self.decoder_tokenizer
        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=target_tokenizer.bos_id,
            pad=target_tokenizer.pad_id,
            eos=target_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5
        self.apply(lambda submodule: transformer_weights_init(submodule, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)
Exemple #8
0
class MTEncDecModel(EncDecNLPModel):
    """
    Encoder-decoder machine translation model.
    """

    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        """
        Build the translation model (tokenizers, encoder, decoder, output head, beam search and
        losses) from ``cfg``.

        Args:
            cfg: model configuration; it is converted to a DictConfig and passed through
                ``model_utils.maybe_update_config_version`` before use. The optional keys
                ``src_language`` and ``tgt_language`` are stored on the model (None when absent).
            trainer: optional trainer. When given, its node/GPU counts and ranks determine
                ``self.global_rank`` and ``self.world_size``; otherwise they are 0 and 1.
        """
        cfg = model_utils.convert_model_config_to_dict_config(cfg)

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        if trainer is None:
            self.global_rank, self.world_size = 0, 1
        else:
            self.global_rank = trainer.node_rank * trainer.num_gpus + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)
        self.setup_enc_dec_tokenizers(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # TODO: use get_encoder function with support for HF and Megatron
        enc_cfg = cfg.encoder
        self.encoder = TransformerEncoderNM(
            vocab_size=self.encoder_vocab_size,
            hidden_size=enc_cfg.hidden_size,
            num_layers=enc_cfg.num_layers,
            inner_size=enc_cfg.inner_size,
            # Optional settings fall back to a default when absent from the config.
            max_sequence_length=getattr(enc_cfg, 'max_sequence_length', 512),
            embedding_dropout=getattr(enc_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=getattr(enc_cfg, 'learn_positional_encodings', False),
            num_attention_heads=enc_cfg.num_attention_heads,
            ffn_dropout=enc_cfg.ffn_dropout,
            attn_score_dropout=enc_cfg.attn_score_dropout,
            attn_layer_dropout=enc_cfg.attn_layer_dropout,
            hidden_act=enc_cfg.hidden_act,
            mask_future=enc_cfg.mask_future,
            pre_ln=enc_cfg.pre_ln,
        )

        # TODO: user get_decoder function with support for HF and Megatron
        dec_cfg = cfg.decoder
        self.decoder = TransformerDecoderNM(
            vocab_size=self.decoder_vocab_size,
            hidden_size=dec_cfg.hidden_size,
            num_layers=dec_cfg.num_layers,
            inner_size=dec_cfg.inner_size,
            max_sequence_length=getattr(dec_cfg, 'max_sequence_length', 512),
            embedding_dropout=getattr(dec_cfg, 'embedding_dropout', 0.0),
            learn_positional_encodings=getattr(dec_cfg, 'learn_positional_encodings', False),
            num_attention_heads=dec_cfg.num_attention_heads,
            ffn_dropout=dec_cfg.ffn_dropout,
            attn_score_dropout=dec_cfg.attn_score_dropout,
            attn_layer_dropout=dec_cfg.attn_layer_dropout,
            hidden_act=dec_cfg.hidden_act,
            pre_ln=dec_cfg.pre_ln,
        )

        head_cfg = cfg.head
        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=head_cfg.activation,
            log_softmax=head_cfg.log_softmax,
            dropout=head_cfg.dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        target_tokenizer = self.decoder_tokenizer
        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=target_tokenizer.bos_id,
            pad=target_tokenizer.pad_id,
            eos=target_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5
        self.apply(lambda submodule: transformer_weights_init(submodule, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)

    def filter_predicted_ids(self, ids):
        """
        Replace, in place, every id that is greater than or equal to the decoder tokenizer's
        vocabulary size with the tokenizer's unknown-token id.

        Args:
            ids: array/tensor of predicted token ids; it is modified in place.
        Returns:
            the same ``ids`` object that was passed in.
        """
        tokenizer = self.decoder_tokenizer
        out_of_vocab = ids >= tokenizer.vocab_size
        ids[out_of_vocab] = tokenizer.unk_id
        return ids

    @typecheck()
    def forward(self, src, src_mask, tgt, tgt_mask):
        """
        Run the encoder on the source, the decoder on the target conditioned on the encoder
        output, and return the output of the ``log_softmax`` head on the decoder states.
        """
        encoded_source = self.encoder(src, src_mask)
        decoded_target = self.decoder(tgt, tgt_mask, encoded_source, src_mask)
        return self.log_softmax(hidden_states=decoded_target)

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            dict with the training loss under ``'loss'`` and, under ``'log'``, the values to log
            (``train_loss`` and the learning rate of the first optimizer parameter group).
        """
        # Dataset returns already batched data and the first dimension of size 1 added by DataLoader
        # is excess, so 3-dimensional items are squeezed in place.
        for position, item in enumerate(batch):
            if item.ndim == 3:
                batch[position] = item.squeeze(dim=0)

        # forward pass
        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)
        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)

        current_lr = self._optimizer.param_groups[0]['lr']
        return {
            'loss': train_loss,
            'log': {'train_loss': train_loss, 'lr': current_lr},
        }

    def eval_step(self, batch, batch_idx, mode):
        """
        Shared validation/test step.

        Computes log-probabilities with a forward pass, decodes translations with beam search,
        feeds the loss of the batch into ``self.eval_loss`` and converts predictions and targets
        back to text. ``batch_idx`` and ``mode`` are not used by this implementation.

        Returns:
            dict with the keys ``'translations'`` and ``'ground_truths'`` (lists of strings) and
            ``'num_non_pad_tokens'`` (number of target ids different from the decoder pad id).
        """
        # Dataset returns already batched data and the first dimension of size 1 added by DataLoader
        # is excess, so 3-dimensional items are squeezed in place.
        for position, item in enumerate(batch):
            if item.ndim == 3:
                batch[position] = item.squeeze(dim=0)
        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)

        encoder_states = self.encoder(src_ids, src_mask)
        predicted_ids = self.beam_search(encoder_hidden_states=encoder_states, encoder_input_mask=src_mask)
        predicted_ids = self.filter_predicted_ids(predicted_ids)

        batch_loss = self.loss_fn(log_probs=log_probs, labels=labels)
        measurement_count = log_probs.shape[0] * log_probs.shape[1]
        self.eval_loss(loss=batch_loss, num_measurements=measurement_count)

        to_text = self.decoder_tokenizer.ids_to_text
        translations = [to_text(id_row) for id_row in predicted_ids.cpu().numpy()]
        target_array = tgt_ids.cpu().numpy()
        ground_truths = [to_text(id_row) for id_row in target_array]
        non_pad_count = np.not_equal(target_array, self.decoder_tokenizer.pad_id).sum().item()
        return {
            'translations': translations,
            'ground_truths': ground_truths,
            'num_non_pad_tokens': non_pad_count,
        }

    def test_step(self, batch, batch_idx):
        """
        Run the shared evaluation step on one test batch and return its result.
        """
        step_output = self.eval_step(batch, batch_idx, 'test')
        return step_output

    @rank_zero_only
    def log_param_stats(self):
        """
        For every parameter that requires gradients, write a histogram (tag ``<name>_hist``) and
        the scalars mean, stddev, max and min (tag ``<name>``) to the trainer's logger experiment
        at the current global step.
        """
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            self.trainer.logger.experiment.add_histogram(name + '_hist', param, global_step=self.global_step)
            summary = {'mean': param.mean(), 'stddev': param.std(), 'max': param.max(), 'min': param.min()}
            self.trainer.logger.experiment.add_scalars(name, summary, global_step=self.global_step)

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`. The work is delegated to ``eval_step`` with mode ``'val'``.
        """
        step_output = self.eval_step(batch, batch_idx, 'val')
        return step_output

    def eval_epoch_end(self, outputs, mode):
        """
        Aggregate the step outputs of a validation/test epoch.

        Collects all translations and ground truths, detokenizes them, computes corpus-level
        Sacre BLEU, logs the scores together with three randomly picked examples, and returns
        the metrics.

        Args:
            outputs: list of dicts returned by the evaluation steps; each must contain the keys
                ``'translations'`` and ``'ground_truths'``.
            mode: ``'val'`` selects the "Validation" label, anything else the "Test" label; it is
                also the prefix of the returned metric names.
        Returns:
            dict with ``{mode}_loss``, ``{mode}_sacreBLEU`` and ``'log'`` (a copy of those two).
        """
        eval_loss = self.eval_loss.compute()
        translations = list(itertools.chain.from_iterable(step['translations'] for step in outputs))
        ground_truths = list(itertools.chain.from_iterable(step['ground_truths'] for step in outputs))

        # TODO: add target language so detokenizer can be lang specific.
        moses_detokenizer = MosesDetokenizer(lang=self.tgt_language)
        translations = [moses_detokenizer.detokenize(sentence.split()) for sentence in translations]
        ground_truths = [moses_detokenizer.detokenize(sentence.split()) for sentence in ground_truths]

        # Japanese gets an extra SentencePiece detokenization pass and its own BLEU tokenizer.
        target_is_japanese = self.tgt_language in ['ja']
        if target_is_japanese:
            sp_detokenizer = SentencePieceDetokenizer()
            translations = [sp_detokenizer.detokenize(sentence.split()) for sentence in translations]
            ground_truths = [sp_detokenizer.detokenize(sentence.split()) for sentence in ground_truths]

        assert len(translations) == len(ground_truths)
        bleu_tokenizer = "ja-mecab" if target_is_japanese else "13a"
        sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize=bleu_tokenizer)

        dataset_name = "Test"
        if mode == 'val':
            dataset_name = "Validation"
        logging.info(f"\n\n\n\n{dataset_name} set size: {len(translations)}")
        logging.info(f"{dataset_name} Sacre BLEU = {sacre_bleu.score}")
        logging.info(f"{dataset_name} TRANSLATION EXAMPLES:".upper())
        for i in range(3):
            ind = random.randint(0, len(translations) - 1)
            # Joining with U+0332 (combining low line) underlines the example header.
            logging.info("    " + '\u0332'.join(f"EXAMPLE {i}:"))
            logging.info(f"    Prediction:   {translations[ind]}")
            logging.info(f"    Ground Truth: {ground_truths[ind]}")

        metrics = {f"{mode}_loss": eval_loss, f"{mode}_sacreBLEU": sacre_bleu.score}
        return {**metrics, 'log': dict(metrics)}

    def validation_epoch_end(self, outputs):
        """
        Aggregate the outputs of every validation step and log the resulting metrics.

        Args:
            outputs: list with the individual output of each validation step.
        """
        val_metrics = self.eval_epoch_end(outputs, 'val')
        self.log_dict(val_metrics)

    def test_epoch_end(self, outputs):
        """
        Aggregate the outputs of every test step and return the resulting metrics.

        Args:
            outputs: list with the individual output of each test step.
        Returns:
            the dict produced by ``eval_epoch_end`` for the ``'test'`` mode.
        """
        test_metrics = self.eval_epoch_end(outputs, 'test')
        return test_metrics

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """
        Build the training dataloader from ``train_data_config`` and store it in ``self._train_dl``.
        """
        train_loader = self._setup_dataloader_from_config(cfg=train_data_config)
        self._train_dl = train_loader

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """
        Build the validation dataloader from ``val_data_config`` and store it in ``self._validation_dl``.
        """
        validation_loader = self._setup_dataloader_from_config(cfg=val_data_config)
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """
        Build the test dataloader from ``test_data_config`` and store it in ``self._test_dl``.
        """
        test_loader = self._setup_dataloader_from_config(cfg=test_data_config)
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a ``torch.utils.data.DataLoader`` for the dataset described by ``cfg``.

        The dataset source is selected in this order:

        1. ``load_from_cached_dataset``: a pickled dataset is loaded from ``cfg.src_file_name``
           (which must be equal to ``cfg.tgt_file_name``).
        2. ``load_from_tarred_dataset``: a ``TarredTranslationDataset`` is created from
           ``cfg.src_file_name`` (again required to equal ``cfg.tgt_file_name``) and
           ``cfg.metadata_path``; its loader is returned immediately, without a sampler.
        3. otherwise: a ``TranslationDataset`` is created from the source/target files and
           batchified with the encoder and decoder tokenizers.

        For cases 1 and 3 a random sampler is used when ``cfg.shuffle`` is true, a sequential one
        otherwise. Every loader is created with ``batch_size=1``
        (presumably because the datasets already yield whole batches -- TODO confirm).

        Args:
            cfg: dataset section of the model configuration.
        Returns:
            the configured ``torch.utils.data.DataLoader``.
        Raises:
            ValueError: cached or tarred dataset requested but ``src_file_name != tgt_file_name``.
            FileNotFoundError: tarred dataset requested but ``metadata_path`` is missing.
        """
        if cfg.get("load_from_cached_dataset", False):
            logging.info('Loading from cached dataset %s' % (cfg.src_file_name))
            if cfg.src_file_name != cfg.tgt_file_name:
                raise ValueError("src must be equal to target for cached dataset")
            # NOTE: unpickling can execute arbitrary code; only load cache files from a trusted source.
            # The context manager guarantees the file handle is closed, even if unpickling fails.
            with open(cfg.src_file_name, 'rb') as cached_file:
                dataset = pickle.load(cached_file)
            dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
        elif cfg.get("load_from_tarred_dataset", False):
            logging.info('Loading from tarred dataset %s' % (cfg.src_file_name))
            if cfg.src_file_name != cfg.tgt_file_name:
                raise ValueError("src must be equal to target for tarred dataset")
            if cfg.get("metadata_path", None) is None:
                raise FileNotFoundError("Could not find metadata path in config")
            dataset = TarredTranslationDataset(
                text_tar_filepaths=cfg.src_file_name,
                metadata_path=cfg.metadata_path,
                encoder_tokenizer=self.encoder_tokenizer,
                decoder_tokenizer=self.decoder_tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            # The tarred dataset gets no sampler; its loader is returned right away.
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = TranslationDataset(
                dataset_src=str(Path(cfg.src_file_name).expanduser()),
                dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )

    @torch.no_grad()
    def translate(self, text: List[str], source_lang: str = None, target_lang: str = None) -> List[str]:
        """
        Translates list of sentences from source language to target language.
        Should be regular text, this method performs its own tokenization/de-tokenization.

        Sentences are translated one at a time. The model is put in eval mode while translating
        and its previous training/eval mode is restored afterwards.

        Args:
            text: list of strings to translate
            source_lang: language used for the MosesTokenizer and MosesPunctNormalizer; when None,
                ``self.src_language`` is used. Normalization/tokenization is skipped only when the
                resulting value is the string ``"None"``.
            target_lang: language used for the MosesDetokenizer; when None, ``self.tgt_language``
                is used. For ``"ja"`` a SentencePiece detokenization pass is applied as well.
        Returns:
            list of translated strings
        """
        if source_lang is None:
            source_lang = self.src_language
        if target_lang is None:
            target_lang = self.tgt_language

        mode = self.training
        tokenizer = MosesTokenizer(lang=source_lang)
        normalizer = MosesPunctNormalizer(lang=source_lang)
        detokenizer = MosesDetokenizer(lang=target_lang)

        try:
            self.eval()
            res = []
            for txt in text:
                # NOTE(review): this compares against the *string* "None", so a source_lang that is
                # still the None object is normalized/tokenized as well -- confirm this is intended.
                if source_lang != "None":
                    txt = normalizer.normalize(txt)
                    txt = tokenizer.tokenize(txt, escape=False, return_str=True)
                ids = self.encoder_tokenizer.text_to_ids(txt)
                ids = [self.encoder_tokenizer.bos_id] + ids + [self.encoder_tokenizer.eos_id]
                # Build the ids directly as int64: going through the float32 ``torch.Tensor``
                # constructor cannot represent integers above 2**24 exactly.
                src = torch.tensor(ids, dtype=torch.long).to(self._device).unsqueeze(0)
                # A single unpadded sentence: every position is a real token.
                src_mask = torch.ones_like(src)
                src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
                beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask)
                beam_results = self.filter_predicted_ids(beam_results)
                translation_ids = beam_results.cpu()[0].numpy()
                translation = self.decoder_tokenizer.ids_to_text(translation_ids)
                translation = detokenizer.detokenize(translation.split())
                if target_lang in ["ja"]:
                    sp_detokenizer = SentencePieceDetokenizer()
                    translation = sp_detokenizer.detokenize(translation.split())

                res.append(translation)
        finally:
            self.train(mode=mode)
        return res

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        Returns:
            None; no models are listed by this implementation
        """
        return None

    def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
        """
        Wraps the model in LightningDistributedDataParallel, taking find_unused_parameters
        from the model config.
        Args:
            model: module to wrap
            device_ids: device ids forwarded to the wrapper
        Returns:
            the wrapped model
        """
        logging.info(f'overriding ddp to set find_unused_parameters to {self._cfg.find_unused_parameters}')
        ddp_kwargs = {
            'device_ids': device_ids,
            'find_unused_parameters': self._cfg.find_unused_parameters,
        }
        return LightningDistributedDataParallel(model, **ddp_kwargs)

    def setup(self, stage):
        if stage == "fit":
            # Update PTL trainer to use our configure_ddp
            self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp
Exemple #9
0
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        """
        Builds the translation model described by ``cfg``: tokenizers, pre/post-processing utilities,
        encoder, decoder, output classifier, beam search generator and losses.
        Args:
            cfg: model configuration; passed through model_utils.convert_model_config_to_dict_config
                and model_utils.maybe_update_config_version before use
            trainer: optional Trainer; when given, self.world_size is set to num_nodes * num_gpus,
                otherwise it stays 1
        """
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        # Both languages are optional in the config and default to None.
        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # Instantiates tokenizers and register to be saved with NeMo Model archive
        # After this call, there will be self.encoder_tokenizer and self.decoder_tokenizer
        # Which can convert between tokens and token_ids for SRC and TGT languages correspondingly.
        # NOTE(review): the decoder tokenizer_model is read by attribute access while the encoder one
        # uses .get(); presumably the decoder key is required in the config - confirm.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_library=cfg.encoder_tokenizer.get('library', 'yttm'),
            encoder_tokenizer_model=cfg.encoder_tokenizer.get('tokenizer_model'),
            encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0),
            encoder_model_name=cfg.encoder.get('model_name') if hasattr(cfg.encoder, 'model_name') else None,
            decoder_tokenizer_library=cfg.decoder_tokenizer.get('library', 'yttm'),
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0),
            decoder_model_name=cfg.decoder.get('model_name') if hasattr(cfg.decoder, 'model_name') else None,
        )

        # After this call, the model will have  self.source_processor and self.target_processor objects
        self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language)

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        # 'library', 'model_name' and 'pretrained' are popped so that only the remaining keys
        # are forwarded as config_dict.
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        self.encoder = get_transformer(
            library=library, model_name=model_name, pretrained=pretrained, config_dict=encoder_cfg_dict, encoder=True,
        )

        # decoder from NeMo, Megatron-LM, or HuggingFace
        decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder'))
        decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size
        library = decoder_cfg_dict.pop('library', 'nemo')
        model_name = decoder_cfg_dict.pop('model_name', None)
        pretrained = decoder_cfg_dict.pop('pretrained', False)
        # The decoder hidden size is overwritten with the encoder's, whatever the config says.
        decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
        self.decoder = get_transformer(
            library=library, model_name=model_name, pretrained=pretrained, config_dict=decoder_cfg_dict, encoder=False,
        )

        # Output head over the decoder vocabulary, configured from cfg.head.
        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        # Generator used at inference time; it reuses the decoder's embedding and decoder modules,
        # the head above, and the decoder tokenizer's special token ids.
        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        # initialize weights if not using pretrained encoder/decoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(module, std_init_range))

        if not self._cfg.decoder.get('pretrained', False):
            self.decoder.apply(lambda module: transformer_weights_init(module, std_init_range))

        # NOTE(review): the head is always initialized, and its layer0 weight is tied to the decoder
        # token embedding above - this may also re-initialize that embedding for a pretrained
        # decoder; confirm against transformer_weights_init.
        self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

        # Training loss: pad positions are identified by the decoder tokenizer's pad_id.
        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)