Code Example #1
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Punctuation and Capitalization model.
        """
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=cfg.punct_head.activation,
            log_softmax=False,
            dropout=cfg.punct_head.fc_dropout,
            num_layers=cfg.punct_head.punct_num_fc_layers,
            use_transformer_init=cfg.punct_head.use_transformer_init,
        )

        self.capit_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=cfg.capit_head.activation,
            log_softmax=False,
            dropout=cfg.capit_head.fc_dropout,
            num_layers=cfg.capit_head.capit_num_fc_layers,
            use_transformer_init=cfg.capit_head.use_transformer_init,
        )

        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
Code Example #2
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file)
            if cfg.tokenizer is not None else None,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.token_classifier.num_classes,
            num_layers=cfg.token_classifier.num_layers,
            activation=cfg.token_classifier.activation,
            log_softmax=cfg.token_classifier.log_softmax,
            dropout=cfg.token_classifier.dropout,
            use_transformer_init=cfg.token_classifier.use_transformer_init,
        )

        self.loss = SpanningLoss()
Code Example #3
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""
        # extract the str-to-int label mapping if a mapping file is provided
        if isinstance(cfg.label_ids, str):
            if os.path.exists(cfg.label_ids):
                logging.info(
                    f'Reusing label_ids file found at {cfg.label_ids}.')
                label_ids = get_labels_to_labels_id_mapping(cfg.label_ids)
                # update the config to store name to id mapping
                cfg.label_ids = OmegaConf.create(label_ids)
            else:
                raise ValueError(f'{cfg.label_ids} not found.')

        self.class_weights = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.classifier = TokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=False,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)
Code Example #4
File: qa_model.py  Project: sycomix/NeMo
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)
        self.classifier = TokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=cfg.token_classifier.num_classes,
            num_layers=cfg.token_classifier.num_layers,
            activation=cfg.token_classifier.activation,
            log_softmax=cfg.token_classifier.log_softmax,
            dropout=cfg.token_classifier.dropout,
            use_transformer_init=cfg.token_classifier.use_transformer_init,
        )

        self.loss = SpanningLoss()
Code Example #5
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        self.tokenizer = get_tokenizer(
            tokenizer_name=cfg.language_model.tokenizer,
            vocab_file=cfg.language_model.vocab_file,
            special_tokens=cfg.language_model.special_tokens,
        )

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.embedding_layer = TransformerEmbedding(
            vocab_size=vocab_size,
            hidden_size=cfg.language_model.hidden_size,
            max_sequence_length=cfg.language_model.max_seq_length,
            embedding_dropout=cfg.language_model.get("embedding_dropout", 0.0),
            learn_positional_encodings=False,
        )
        self.encoder = TransformerEncoder(
            num_layers=cfg.language_model.num_layers,
            hidden_size=cfg.language_model.hidden_size,
            mask_future=True,
            num_attention_heads=cfg.language_model.num_attn_heads,
            inner_size=cfg.language_model.inner_size,
            ffn_dropout=cfg.language_model.get("ffn_dropout", 0.0),
            hidden_act=cfg.language_model.get("inner_activation", "relu"),
            attn_score_dropout=cfg.language_model.get("attn_score_dropout", 0.0),
            attn_layer_dropout=cfg.language_model.get("attn_layer_dropout", 0.0),
        )
        self.log_softmax = TokenClassifier(
            hidden_size=cfg.language_model.hidden_size, num_classes=vocab_size, log_softmax=True,
        )

        std_init_range = 1 / math.sqrt(cfg.language_model.hidden_size)
        self.apply(lambda module: transformer_weights_init(module, std_init_range))

        # tie weights of embedding and softmax matrices
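        # sharing one matrix between the input embedding and the output projection
        # reduces parameter count and keeps input/output token representations consistent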
        self.log_softmax.mlp.layer0.weight = self.embedding_layer.token_embedding.weight

        self.training_loss = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id)
        self.validation_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id, predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)
Code Example #6
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""

        self._setup_tokenizer(cfg.tokenizer)

        self._cfg = cfg
        self.data_desc = None
        self.update_data_dir(cfg.dataset.data_dir)
        self.setup_loss(class_balancing=self._cfg.dataset.class_balancing)

        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=self._cfg.head.log_softmax,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)
        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)
Code Example #7
File: transformer_lm_model.py  Project: ggrunin/NeMo
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Instantiate the tokenizer and register it to be saved with the NeMo model archive.
        # After this call, there will be a self.tokenizer that can convert between tokens and token_ids.
        self.setup_tokenizer(
            tokenizer_name=cfg.tokenizer.get("tokenizer_name", "yttm"),
            tokenizer_model=cfg.tokenizer.get("tokenizer_model", None),
            vocab_file=cfg.tokenizer.get("vocab_file", None),
            bpe_dropout=cfg.tokenizer.get("bpe_dropout", 0.0),
            special_tokens=cfg.tokenizer.get("special_tokens", {}))

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_cfg_dict,
            encoder=True,
            pre_ln_final_layer_norm=encoder_cfg_dict.get(
                'pre_ln_final_layer_norm',
                encoder_cfg_dict.get('pre_ln', True)),
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.encoder.hidden_size,
            num_classes=vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight

        std_init_range = 1 / self.encoder.hidden_size**0.5

        # initialize weights if not using pretrained encoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        self.log_softmax.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id, label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id)
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False,
                                                 take_avg_loss=True)
        self.eval_ppl = SequencePerplexity()
Code Example #8
File: transformer_lm_model.py  Project: ggrunin/NeMo
class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Instantiate the tokenizer and register it to be saved with the NeMo model archive.
        # After this call, there will be a self.tokenizer that can convert between tokens and token_ids.
        self.setup_tokenizer(
            tokenizer_name=cfg.tokenizer.get("tokenizer_name", "yttm"),
            tokenizer_model=cfg.tokenizer.get("tokenizer_model", None),
            vocab_file=cfg.tokenizer.get("vocab_file", None),
            bpe_dropout=cfg.tokenizer.get("bpe_dropout", 0.0),
            special_tokens=cfg.tokenizer.get("special_tokens", {}))

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_cfg_dict,
            encoder=True,
            pre_ln_final_layer_norm=encoder_cfg_dict.get(
                'pre_ln_final_layer_norm',
                encoder_cfg_dict.get('pre_ln', True)),
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.encoder.hidden_size,
            num_classes=vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight

        std_init_range = 1 / self.encoder.hidden_size**0.5

        # initialize weights if not using pretrained encoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        self.log_softmax.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id, label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id)
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False,
                                                 take_avg_loss=True)
        self.eval_ppl = SequencePerplexity()

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """

        hidden_states = self.encoder(input_ids=input_ids,
                                     encoder_mask=attention_mask)
        log_probs = self.log_softmax(hidden_states=hidden_states)

        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1
                # added by DataLoader is excess.
                batch[i] = batch[i].squeeze(dim=0)
        ids, mask = batch
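        # next-token LM setup: positions 0..T-2 are the inputs, positions 1..T-1 are the labels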
        input_ids, labels = ids[:, :-1], ids[:, 1:]
        input_mask = mask[:, :-1]
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)

        tensorboard_logs = {
            "train_loss": train_loss,
            "lr": self._optimizer.param_groups[0]["lr"],
        }
        return {"loss": train_loss, "log": tensorboard_logs}

    def eval_step(self, batch, batch_idx):
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1
                # added by DataLoader is excess.
                batch[i] = batch[i].squeeze(dim=0)
        ids, mask = batch
        input_ids, labels = ids[:, :-1], ids[:, 1:]
        input_mask, output_mask = mask[:, :-1], mask[:, 1:]
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)
        eval_loss = self.eval_loss_fn(log_probs=log_probs, labels=labels)

        self.eval_loss(loss=eval_loss,
                       num_measurements=log_probs.shape[0] *
                       log_probs.shape[1])
        self.eval_ppl(log_probs=log_probs, labels=labels, mask=output_mask)
        return {}

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        return self.eval_step(batch, batch_idx)

    def eval_epoch_end(self, outputs, mode):
        eval_loss = self.eval_loss.compute()
        eval_ppl = self.eval_ppl.compute()
        self.log(f"{mode}_loss", eval_loss, sync_dist=True)
        self.log(f"{mode}_PPL", eval_ppl, sync_dist=True)
        dataset_name = "Validation" if mode == 'val' else "Test"
        logging.info(
            f"\n\n\n\n{dataset_name} PPL: {np.round(eval_ppl.item(), 2)}")

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        self.eval_epoch_end(outputs, 'val')
        self.eval_loss.reset()
        self.eval_ppl.reset()

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, 'test')

    def setup_tokenizer(self,
                        tokenizer_name=None,
                        tokenizer_model=None,
                        vocab_file=None,
                        bpe_dropout=0.0,
                        special_tokens=None):

        supported_tokenizers = ['yttm', 'huggingface', 'sentencepiece', 'word']
        if tokenizer_name not in supported_tokenizers:
            raise NotImplementedError(
                f"Currently we only support tokenizers in {supported_tokenizers}."
            )

        self.tokenizer = get_tokenizer(
            tokenizer_name=tokenizer_name,
            tokenizer_model=self.register_artifact(
                "cfg.tokenizer.tokenizer_model", tokenizer_model),
            vocab_file=vocab_file,
            bpe_dropout=bpe_dropout,
            special_tokens=dict(special_tokens or {}),
            use_fast=False,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):

        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            else:
                metadata_file = cfg.get('metadata_file')
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if cfg.get('tar_files') is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                    else:
                        raise FileNotFoundError(
                            "Could not find tarred dataset in config or metadata."
                        )
                else:
                    tar_files = cfg.get('tar_files')
                    if metadata.get('tar_files') is not None:
                        raise ValueError(
                            'Tar files specified in config and in metadata file. Tar files should only be specified once.'
                        )
            dataset = TarredSentenceDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                tokenizer=self.tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = SentenceDataset(
                tokenizer=self.tokenizer,
                dataset=cfg.file_name,
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                cache_ids=cfg.get("cache_ids", False),
            )
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
Code Example #9
    def test_bert_pretraining_export_to_onnx(self):
        for num_layers in [1, 2, 4]:
            classifier_export(TokenClassifier(hidden_size=256, num_layers=num_layers, num_classes=16))
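The export test above exercises TokenClassifier in isolation. As a rough companion sketch (not from the original sources; the import path and the random input are assumptions), a standalone forward pass over dummy hidden states of shape (batch, sequence, hidden) could look like:

    import torch
    from nemo.collections.nlp.modules.common import TokenClassifier  # assumed import path

    # dummy hidden states: 2 sequences, 8 tokens, hidden size 256
    hidden_states = torch.randn(2, 8, 256)
    classifier = TokenClassifier(hidden_size=256, num_classes=16, num_layers=2, log_softmax=False)
    logits = classifier(hidden_states=hidden_states)  # expected shape: (2, 8, 16)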
Code Example #10
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # global_rank and local_rank are set by LightningModule in Lightning 1.2.0

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # Instantiate the tokenizers and register them to be saved with the NeMo model archive.
        # After this call, there will be self.encoder_tokenizer and self.decoder_tokenizer,
        # which can convert between tokens and token_ids for the SRC and TGT languages respectively.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_name=cfg.encoder_tokenizer.tokenizer_name,
            encoder_tokenizer_model=cfg.encoder_tokenizer.tokenizer_model,
            encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0),
            decoder_tokenizer_name=cfg.decoder_tokenizer.tokenizer_name,
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0),
        )

        # After this call, the model will have self.source_processor and self.target_processor objects
        self.setup_pre_and_post_processing_utils(source_lang=self.src_language,
                                                 target_lang=self.tgt_language)

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # TODO: use get_encoder function with support for HF and Megatron
        self.encoder = TransformerEncoderNM(
            vocab_size=self.encoder_vocab_size,
            hidden_size=cfg.encoder.hidden_size,
            num_layers=cfg.encoder.num_layers,
            inner_size=cfg.encoder.inner_size,
            max_sequence_length=cfg.encoder.max_sequence_length if hasattr(
                cfg.encoder, 'max_sequence_length') else 512,
            embedding_dropout=cfg.encoder.embedding_dropout if hasattr(
                cfg.encoder, 'embedding_dropout') else 0.0,
            learn_positional_encodings=cfg.encoder.learn_positional_encodings
            if hasattr(cfg.encoder, 'learn_positional_encodings') else False,
            num_attention_heads=cfg.encoder.num_attention_heads,
            ffn_dropout=cfg.encoder.ffn_dropout,
            attn_score_dropout=cfg.encoder.attn_score_dropout,
            attn_layer_dropout=cfg.encoder.attn_layer_dropout,
            hidden_act=cfg.encoder.hidden_act,
            mask_future=cfg.encoder.mask_future,
            pre_ln=cfg.encoder.pre_ln,
        )

        # TODO: use get_decoder function with support for HF and Megatron
        self.decoder = TransformerDecoderNM(
            vocab_size=self.decoder_vocab_size,
            hidden_size=cfg.decoder.hidden_size,
            num_layers=cfg.decoder.num_layers,
            inner_size=cfg.decoder.inner_size,
            max_sequence_length=cfg.decoder.max_sequence_length if hasattr(
                cfg.decoder, 'max_sequence_length') else 512,
            embedding_dropout=cfg.decoder.embedding_dropout if hasattr(
                cfg.decoder, 'embedding_dropout') else 0.0,
            learn_positional_encodings=cfg.decoder.learn_positional_encodings
            if hasattr(cfg.decoder, 'learn_positional_encodings') else False,
            num_attention_heads=cfg.decoder.num_attention_heads,
            ffn_dropout=cfg.decoder.ffn_dropout,
            attn_score_dropout=cfg.decoder.attn_score_dropout,
            attn_layer_dropout=cfg.decoder.attn_layer_dropout,
            hidden_act=cfg.decoder.hidden_act,
            pre_ln=cfg.decoder.pre_ln,
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size**0.5
        self.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id,
            label_smoothing=cfg.label_smoothing)
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False,
                                                 take_avg_loss=True)
Code Example #11
class QAModel(NLPModel, Exportable):
    """
    BERT encoder with QA head training.
    """
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.token_classifier.num_classes,
            num_layers=cfg.token_classifier.num_layers,
            activation=cfg.token_classifier.activation,
            log_softmax=cfg.token_classifier.log_softmax,
            dropout=cfg.token_classifier.dropout,
            use_transformer_init=cfg.token_classifier.use_transformer_init,
        )

        self.loss = SpanningLoss()

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        logits = self.classifier(hidden_states=hidden_states)
        return logits

    def training_step(self, batch, batch_idx):
        input_ids, input_type_ids, input_mask, unique_ids, start_positions, end_positions = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)
        loss, _, _ = self.loss(logits=logits,
                               start_positions=start_positions,
                               end_positions=end_positions)

        lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)
        return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx):
        if self.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        input_ids, input_type_ids, input_mask, unique_ids, start_positions, end_positions = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)
        loss, start_logits, end_logits = self.loss(
            logits=logits,
            start_positions=start_positions,
            end_positions=end_positions)

        tensors = {
            'unique_ids': unique_ids,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
        self.log(f'{prefix}_loss', loss)
        return {f'{prefix}_loss': loss, f'{prefix}_tensors': tensors}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        if self.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        avg_loss = torch.stack([x[f'{prefix}_loss'] for x in outputs]).mean()

        unique_ids = torch.cat(
            [x[f'{prefix}_tensors']['unique_ids'] for x in outputs])
        start_logits = torch.cat(
            [x[f'{prefix}_tensors']['start_logits'] for x in outputs])
        end_logits = torch.cat(
            [x[f'{prefix}_tensors']['end_logits'] for x in outputs])

        all_unique_ids = []
        all_start_logits = []
        all_end_logits = []
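        # when running distributed, gather per-rank results so rank 0 can score the whole eval set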
        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            for ind in range(world_size):
                all_unique_ids.append(torch.empty_like(unique_ids))
                all_start_logits.append(torch.empty_like(start_logits))
                all_end_logits.append(torch.empty_like(end_logits))
            torch.distributed.all_gather(all_unique_ids, unique_ids)
            torch.distributed.all_gather(all_start_logits, start_logits)
            torch.distributed.all_gather(all_end_logits, end_logits)
        else:
            all_unique_ids.append(unique_ids)
            all_start_logits.append(start_logits)
            all_end_logits.append(end_logits)

        exact_match, f1, all_predictions, all_nbest = -1, -1, [], []
        if not torch.distributed.is_initialized(
        ) or torch.distributed.get_rank() == 0:

            unique_ids = []
            start_logits = []
            end_logits = []
            for u in all_unique_ids:
                unique_ids.extend(tensor2list(u))
            for u in all_start_logits:
                start_logits.extend(tensor2list(u))
            for u in all_end_logits:
                end_logits.extend(tensor2list(u))

            eval_dataset = self._test_dl.dataset if self.testing else self._validation_dl.dataset
            exact_match, f1, all_predictions, all_nbest = eval_dataset.evaluate(
                unique_ids=unique_ids,
                start_logits=start_logits,
                end_logits=end_logits,
                n_best_size=self._cfg.dataset.n_best_size,
                max_answer_length=self._cfg.dataset.max_answer_length,
                version_2_with_negative=self._cfg.dataset.
                version_2_with_negative,
                null_score_diff_threshold=self._cfg.dataset.
                null_score_diff_threshold,
                do_lower_case=self._cfg.dataset.do_lower_case,
            )

        logging.info(f"{prefix} exact match {exact_match}")
        logging.info(f"{prefix} f1 {f1}")

        self.log(f'{prefix}_loss', avg_loss)
        self.log(f'{prefix}_exact_match', exact_match)
        self.log(f'{prefix}_f1', f1)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)

    @torch.no_grad()
    def inference(
        self,
        file: str,
        batch_size: int = 1,
        num_samples: int = -1,
        output_nbest_file: Optional[str] = None,
        output_prediction_file: Optional[str] = None,
    ):
        """
        Get prediction for unlabeled inference data
        Args:
            file: inference data
            batch_size: batch size to use during inference
            num_samples: number of samples of the inference data to use. Default: -1, meaning all data is used.
            output_nbest_file: optional output file for writing out nbest list
            output_prediction_file: optional output file for writing out predictions
        Returns:
            all_predictions: model predictions
            all_nbest: model nbest list
        """
        # store predictions for all queries in a single list
        all_predictions = []
        all_nbest = []
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            logging_level = logging.get_verbosity()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {
                "batch_size": batch_size,
                "file": file,
                "shuffle": False,
                "num_samples": num_samples,
                'num_workers': 2,
                'pin_memory': False,
                'drop_last': False,
            }
            dataloader_cfg = OmegaConf.create(dataloader_cfg)
            infer_datalayer = self._setup_dataloader_from_config(
                cfg=dataloader_cfg, mode=INFERENCE_MODE)

            all_logits = []
            all_unique_ids = []
            for i, batch in enumerate(infer_datalayer):
                input_ids, token_type_ids, attention_mask, unique_ids = batch
                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=token_type_ids.to(device),
                    attention_mask=attention_mask.to(device),
                )
                all_logits.append(logits)
                all_unique_ids.append(unique_ids)
            logits = torch.cat(all_logits)
            unique_ids = tensor2list(torch.cat(all_unique_ids))
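            # the classifier emits two values per token: split the last dim into start/end logits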
            s, e = logits.split(dim=-1, split_size=1)
            start_logits = tensor2list(s.squeeze())
            end_logits = tensor2list(e.squeeze())
            (all_predictions, all_nbest,
             scores_diff) = infer_datalayer.dataset.get_predictions(
                 unique_ids=unique_ids,
                 start_logits=start_logits,
                 end_logits=end_logits,
                 n_best_size=self._cfg.dataset.n_best_size,
                 max_answer_length=self._cfg.dataset.max_answer_length,
                 version_2_with_negative=self._cfg.dataset.
                 version_2_with_negative,
                 null_score_diff_threshold=self._cfg.dataset.
                 null_score_diff_threshold,
                 do_lower_case=self._cfg.dataset.do_lower_case,
             )

            with open(file, 'r') as test_file_fp:
                test_data = json.load(test_file_fp)["data"]
                id_to_question_mapping = {}
                for title in test_data:
                    for par in title["paragraphs"]:
                        for question in par["qas"]:
                            id_to_question_mapping[
                                question["id"]] = question["question"]

            for question_id in all_predictions:
                all_predictions[question_id] = (
                    id_to_question_mapping[question_id],
                    all_predictions[question_id])

            if output_nbest_file is not None:
                with open(output_nbest_file, "w") as writer:
                    writer.write(json.dumps(all_nbest, indent=4) + "\n")
            if output_prediction_file is not None:
                with open(output_prediction_file, "w") as writer:
                    writer.write(json.dumps(all_predictions, indent=4) + "\n")

        finally:
            # set mode back to its original value
            self.train(mode=mode)
            logging.set_verbosity(logging_level)

        return all_predictions, all_nbest

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config or not train_data_config.file:
            logging.info(
                "Dataloader config or file path for training is missing, so no train data loader is created!"
            )
            self._train_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config, mode=TRAINING_MODE)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config or not val_data_config.file:
            logging.info(
                "Dataloader config or file path for validation is missing, so no validation data loader is created!"
            )
            self._validation_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config, mode=EVALUATION_MODE)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        if not test_data_config or test_data_config.file is None:
            logging.info(
                f"Dataloader config or file_path for the test is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config, mode=EVALUATION_MODE)

    def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str):
        dataset = SquadDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.file,
            doc_stride=self._cfg.dataset.doc_stride,
            max_query_length=self._cfg.dataset.max_query_length,
            max_seq_length=self._cfg.dataset.max_seq_length,
            version_2_with_negative=self._cfg.dataset.version_2_with_negative,
            num_samples=cfg.num_samples,
            mode=mode,
            use_cache=self._cfg.dataset.use_cache,
        )

        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.drop_last,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
        return dl

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="BERTBaseUncasedSQuADv1.1",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTBaseUncasedSQuADv1.1.nemo",
            description=
            "Question answering model finetuned from NeMo BERT Base Uncased on SQuAD v1.1 dataset which obtains an exact match (EM) score of 82.43% and an F1 score of 89.59%.",
        )
        result.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="BERTBaseUncasedSQuADv2.0",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTBaseUncasedSQuADv2.0.nemo",
            description=
            "Question answering model finetuned from NeMo BERT Base Uncased on SQuAD v2.0 dataset which obtains an exact match (EM) score of 73.35% and an F1 score of 76.44%.",
        )
        result.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="BERTLargeUncasedSQuADv1.1",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTLargeUncasedSQuADv1.1.nemo",
            description=
            "Question answering model finetuned from NeMo BERT Large Uncased on SQuAD v1.1 dataset which obtains an exact match (EM) score of 85.47% and an F1 score of 92.10%.",
        )
        result.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="BERTLargeUncasedSQuADv2.0",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTLargeUncasedSQuADv2.0.nemo",
            description=
            "Question answering model finetuned from NeMo BERT Large Uncased on SQuAD v2.0 dataset which obtains an exact match (EM) score of 78.8% and an F1 score of 81.85%.",
        )
        result.append(model)
        return result

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " QAModel consists of two separate models with different"
                " inputs and outputs.")

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output),
                               'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output),
                               'classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx,
                                           "QA")
        output_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(output_model, output)
        return ([output, output1,
                 output2], [output_descr, output1_descr, output2_descr])
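As a usage note for the inference helper defined in the example above, a call might look like the following sketch (assuming qa_model is an already restored QAModel instance; the file names are placeholders):

    # minimal sketch; `qa_model` and the file paths below are assumptions, not from the original source
    all_predictions, all_nbest = qa_model.inference(
        file="dev-v1.1.json",                       # placeholder: SQuAD-format JSON with unlabeled questions
        batch_size=8,
        num_samples=-1,                             # -1 means use all samples
        output_prediction_file="predictions.json",  # optional: also write predictions to disk
    )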
Code Example #12
class PunctuationCapitalizationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "punct_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
            "capit_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Punctuation and Capitalization model.
        """
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=cfg.punct_head.activation,
            log_softmax=False,
            dropout=cfg.punct_head.fc_dropout,
            num_layers=cfg.punct_head.punct_num_fc_layers,
            use_transformer_init=cfg.punct_head.use_transformer_init,
        )

        self.capit_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=cfg.capit_head.activation,
            log_softmax=False,
            dropout=cfg.capit_head.fc_dropout,
            num_layers=cfg.capit_head.capit_num_fc_layers,
            use_transformer_init=cfg.capit_head.use_transformer_init,
        )

        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        punct_logits = self.punct_classifier(hidden_states=hidden_states)
        capit_logits = self.capit_classifier(hidden_states=hidden_states)
        return punct_logits, capit_logits

    def _make_step(self, batch):
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, punct_labels, capit_labels = batch
        punct_logits, capit_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        punct_loss = self.loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
        capit_loss = self.loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
        loss = self.agg_loss(loss_1=punct_loss, loss_2=capit_loss)
        return loss, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        loss, _, _ = self._make_step(batch)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('lr', lr, prog_bar=True)
        self.log('train_loss', loss)

        return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        val_loss, punct_logits, capit_logits = self._make_step(batch)

        subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        capit_preds = torch.argmax(capit_logits, axis=-1)[subtokens_mask]
        capit_labels = capit_labels[subtokens_mask]
        self.capit_class_report.update(capit_preds, capit_labels)

        return {
            'val_loss': val_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'capit_tp': self.capit_class_report.tp,
            'capit_fn': self.capit_class_report.fn,
            'capit_fp': self.capit_class_report.fp,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        test_loss, punct_logits, capit_logits = self._make_step(batch)

        subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        capit_preds = torch.argmax(capit_logits, axis=-1)[subtokens_mask]
        capit_labels = capit_labels[subtokens_mask]
        self.capit_class_report.update(capit_preds, capit_labels)

        return {
            'test_loss': test_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'capit_tp': self.capit_class_report.tp,
            'capit_fn': self.capit_class_report.fn,
            'capit_fp': self.capit_class_report.fp,
        }

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of test to aggregate outputs.
        outputs: list of individual outputs of each test step.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Update data directory

        Args:
            data_dir: path to data directory
        """
        if os.path.exists(data_dir):
            logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
            self._cfg.dataset.data_dir = data_dir
        else:
            raise ValueError(f'{data_dir} not found')

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        """Setup training data"""
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        # for compatibility with older (pre-1.0.0.b3) configs
        if not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None:
            OmegaConf.set_struct(self._cfg, False)
            self._cfg.class_labels = {}
            self._cfg.class_labels = OmegaConf.create(
                {'punct_labels_file': 'punct_label_ids.csv', 'capit_labels_file': 'capit_label_ids.csv'}
            )

        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            self.register_artifact(
                self._cfg.class_labels.punct_labels_file, self._train_dl.dataset.punct_label_ids_file
            )
            self.register_artifact(
                self._cfg.class_labels.capit_labels_file, self._train_dl.dataset.capit_label_ids_file
            )

            # save label maps to the config
            self._cfg.punct_label_ids = OmegaConf.create(self._train_dl.dataset.punct_label_ids)
            self._cfg.capit_label_ids = OmegaConf.create(self._train_dl.dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Setup validation data

        val_data_config: validation data config
        """
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds

        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        if test_data_config is None:
            test_data_config = self._cfg.test_ds
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        # use data_dir specified in the ds_item to run evaluation on multiple datasets
        if 'ds_item' in cfg and cfg.ds_item is not None:
            data_dir = cfg.ds_item
        else:
            data_dir = self._cfg.dataset.data_dir

        text_file = os.path.join(data_dir, cfg.text_file)
        label_file = os.path.join(data_dir, cfg.labels_file)

        dataset = BertPunctuationCapitalizationDataset(
            tokenizer=self.tokenizer,
            text_file=text_file,
            label_file=label_file,
            pad_label=self._cfg.dataset.pad_label,
            punct_label_ids=self._cfg.punct_label_ids,
            capit_label_ids=self._cfg.capit_label_ids,
            max_seq_length=self._cfg.dataset.max_seq_length,
            ignore_extra_tokens=self._cfg.dataset.ignore_extra_tokens,
            ignore_start_end=self._cfg.dataset.ignore_start_end,
            use_cache=self._cfg.dataset.use_cache,
            num_samples=cfg.num_samples,
            punct_label_ids_file=self._cfg.class_labels.punct_labels_file
            if 'class_labels' in self._cfg
            else 'punct_label_ids.csv',
            capit_label_ids_file=self._cfg.class_labels.capit_labels_file
            if 'class_labels' in self._cfg
            else 'capit_label_ids.csv',
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=self._cfg.dataset.drop_last,
        )

    def _setup_infer_dataloader(self, queries: List[str], batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Setup function for an inference data loader.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader.
        """

        dataset = BertPunctuationCapitalizationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=self._cfg.dataset.max_seq_length
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=False,
        )

    def add_punctuation_capitalization(self, queries: List[str], batch_size: int = None) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for debugging and prototyping.
        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference
        Returns:
            result: text with added capitalization and punctuation
        """
        if queries is None or len(queries) == 0:
            return []
        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')

        # We will store the output here
        result = []

        # Model's mode and device
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            self = self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            # store predictions for all queries in a single list
            all_punct_preds = []
            all_capit_preds = []

            for batch in infer_datalayer:
                input_ids, input_type_ids, input_mask, subtokens_mask = batch

                punct_logits, capit_logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                subtokens_mask = subtokens_mask > 0.5
                punct_preds = tensor2list(torch.argmax(punct_logits, axis=-1)[subtokens_mask])
                capit_preds = tensor2list(torch.argmax(capit_logits, axis=-1)[subtokens_mask])
                all_punct_preds.extend(punct_preds)
                all_capit_preds.extend(capit_preds)

            queries = [q.strip().split() for q in queries]
            queries_len = [len(q) for q in queries]

            if sum(queries_len) != len(all_punct_preds) or sum(queries_len) != len(all_capit_preds):
                raise ValueError('Predictions and words must have the same length')

            punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
            capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}

            start_idx = 0
            end_idx = 0
            for query in queries:
                end_idx += len(query)

                # extract predictions for the current query from the list of all predictions
                punct_preds = all_punct_preds[start_idx:end_idx]
                capit_preds = all_capit_preds[start_idx:end_idx]
                start_idx = end_idx

                query_with_punct_and_capit = ''
                for j, word in enumerate(query):
                    punct_label = punct_ids_to_labels[punct_preds[j]]
                    capit_label = capit_ids_to_labels[capit_preds[j]]

                    if capit_label != self._cfg.dataset.pad_label:
                        word = word.capitalize()
                    query_with_punct_and_capit += word
                    if punct_label != self._cfg.dataset.pad_label:
                        query_with_punct_and_capit += punct_label
                    query_with_punct_and_capit += ' '

                result.append(query_with_punct_and_capit.strip())
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="Punctuation_Capitalization_with_BERT",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_BERT.nemo",
                description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
            )
        )
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="Punctuation_Capitalization_with_DistilBERT",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_DistilBERT.nemo",
                description="The model was trained with DiltilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
            )
        )
        return result

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Unlike other models' export(), this one creates 5 output files, not 3:
        punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
        capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)
        bert_<output> - common BERT neural net
        punct_classifier_<output> - Punctuation Classifier neural net
        capit_classifier_<output> - Capitalization Classifier neural net
        """
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " PunctuationCapitalizationModel consists of three separate models with different"
                " inputs and outputs."
            )

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output), 'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output), 'punct_classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Punctuation Classifier exported to ONNX'
        punct_classifier_onnx = self.punct_classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output3 = os.path.join(os.path.dirname(output), 'capit_classifier_' + os.path.basename(output))
        output3_descr = qual_name + ' Capitalization Classifier exported to ONNX'
        capit_classifier_onnx = self.capit_classifier.export(
            output3,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        punct_output_model = attach_onnx_to_onnx(bert_model_onnx, punct_classifier_onnx, "PTCL")
        output4 = os.path.join(os.path.dirname(output), 'punct_' + os.path.basename(output))
        output4_descr = qual_name + ' Punctuation BERT+Classifier exported to ONNX'
        onnx.save(punct_output_model, output4)
        capit_output_model = attach_onnx_to_onnx(bert_model_onnx, capit_classifier_onnx, "CPCL")
        output5 = os.path.join(os.path.dirname(output), 'capit_' + os.path.basename(output))
        output5_descr = qual_name + ' Capitalization BERT+Classifier exported to ONNX'
        onnx.save(capit_output_model, output5)
        return (
            [output1, output2, output3, output4, output5],
            [output1_descr, output2_descr, output3_descr, output4_descr, output5_descr],
        )
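
A minimal inference sketch for the model above. This is an assumption on my part rather than part of the listing: it presumes the class shown is NeMo's PunctuationCapitalizationModel, importable from nemo.collections.nlp.models, and that the NGC checkpoint names from list_available_models() can be restored via from_pretrained():

from nemo.collections.nlp.models import PunctuationCapitalizationModel

# Restore one of the pretrained checkpoints listed in list_available_models().
model = PunctuationCapitalizationModel.from_pretrained("Punctuation_Capitalization_with_BERT")

# Queries are expected to be lowercased text without punctuation.
queries = ["how are you", "great how about you"]
print(model.add_punctuation_capitalization(queries, batch_size=2))

After export(), the fused ONNX graphs can be sanity-checked with onnxruntime; the file name below is hypothetical and depends on the output argument passed to export():

import onnxruntime

sess = onnxruntime.InferenceSession("punct_model.onnx")
print([i.name for i in sess.get_inputs()])   # inspect input names instead of hard-coding them
print([o.name for o in sess.get_outputs()])  # output names of the fused punctuation graph
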
Code example #13
0
File: mt_enc_dec_model.py  Project: tiborkiss/NeMo
class MTEncDecModel(EncDecNLPModel):
    """
    Encoder-decoder machine translation model.
    """
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable.
        # global_rank and local_rank are set by LightningModule in Lightning 1.2.0.

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language = cfg.get("src_language", None)
        self.tgt_language = cfg.get("tgt_language", None)

        self.multilingual = cfg.get("multilingual", False)
        self.multilingual_ids = []

        self.encoder_tokenizer_library = cfg.encoder_tokenizer.get(
            'library', 'yttm')
        self.decoder_tokenizer_library = cfg.decoder_tokenizer.get(
            'library', 'yttm')

        # Instantiate tokenizers and register them to be saved with the NeMo model archive.
        # After this call, there will be self.encoder_tokenizer and self.decoder_tokenizer,
        # which can convert between tokens and token_ids for the SRC and TGT languages, respectively.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_library=self.encoder_tokenizer_library,
            encoder_tokenizer_model=cfg.encoder_tokenizer.get(
                'tokenizer_model'),
            encoder_bpe_dropout=cfg.encoder_tokenizer.get(
                'bpe_dropout', 0.0) if cfg.encoder_tokenizer.get(
                    'bpe_dropout', 0.0) is not None else 0.0,
            encoder_model_name=cfg.encoder.get('model_name') if hasattr(
                cfg.encoder, 'model_name') else None,
            encoder_r2l=cfg.encoder_tokenizer.get('r2l', False),
            decoder_tokenizer_library=self.decoder_tokenizer_library,
            encoder_tokenizer_vocab_file=cfg.encoder_tokenizer.get(
                'vocab_file', None),
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get(
                'bpe_dropout', 0.0) if cfg.decoder_tokenizer.get(
                    'bpe_dropout', 0.0) is not None else 0.0,
            decoder_model_name=cfg.decoder.get('model_name') if hasattr(
                cfg.decoder, 'model_name') else None,
            decoder_r2l=cfg.decoder_tokenizer.get('r2l', False),
        )

        if self.multilingual:
            if isinstance(self.src_language, ListConfig) and isinstance(
                    self.tgt_language, ListConfig):
                raise ValueError(
                    "cfg.src_language and cfg.tgt_language cannot both be lists. We only support many-to-one or one-to-many multilingual models."
                )
            elif isinstance(self.src_language, ListConfig):
                for lng in self.src_language:
                    self.multilingual_ids.append(
                        self.encoder_tokenizer.token_to_id("<" + lng + ">"))
            elif isinstance(self.tgt_language, ListConfig):
                for lng in self.tgt_language:
                    self.multilingual_ids.append(
                        self.encoder_tokenizer.token_to_id("<" + lng + ">"))
            else:
                raise ValueError(
                    "Expect either cfg.src_language or cfg.tgt_language to be a list when multilingual=True."
                )

            if isinstance(self.src_language, ListConfig):
                self.tgt_language = [self.tgt_language] * len(
                    self.src_language)
            else:
                self.src_language = [self.src_language] * len(
                    self.tgt_language)

            self.source_processor_list = []
            self.target_processor_list = []
            for src_lng, tgt_lng in zip(self.src_language, self.tgt_language):
                src_prcsr, tgt_prscr = self.setup_pre_and_post_processing_utils(
                    src_lng, tgt_lng)
                self.source_processor_list.append(src_prcsr)
                self.target_processor_list.append(tgt_prscr)

        else:
            # After this call, the model will have self.source_processor and self.target_processor objects
            self.setup_pre_and_post_processing_utils(self.src_language,
                                                     self.tgt_language)
            self.multilingual_ids = [None]

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        checkpoint_file = encoder_cfg_dict.pop('checkpoint_file', None)
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_cfg_dict,
            encoder=True,
            pre_ln_final_layer_norm=encoder_cfg_dict.get(
                'pre_ln_final_layer_norm', False),
            checkpoint_file=checkpoint_file,
        )

        # decoder from NeMo, Megatron-LM, or HuggingFace
        decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder'))
        decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size
        library = decoder_cfg_dict.pop('library', 'nemo')
        model_name = decoder_cfg_dict.pop('model_name', None)
        pretrained = decoder_cfg_dict.pop('pretrained', False)
        decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
        self.decoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=decoder_cfg_dict,
            encoder=False,
            pre_ln_final_layer_norm=decoder_cfg_dict.get(
                'pre_ln_final_layer_norm', False),
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size**0.5

        # initialize weights if not using pretrained encoder/decoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        if not self._cfg.decoder.get('pretrained', False):
            self.decoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        self.log_softmax.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id,
            label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = NLLLoss(ignore_index=self.decoder_tokenizer.pad_id)

    def filter_predicted_ids(self, ids):
        ids[ids >=
            self.decoder_tokenizer.vocab_size] = self.decoder_tokenizer.unk_id
        return ids

    @typecheck()
    def forward(self, src, src_mask, tgt, tgt_mask):
        src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
        tgt_hiddens = self.decoder(input_ids=tgt,
                                   decoder_mask=tgt_mask,
                                   encoder_embeddings=src_hiddens,
                                   encoder_mask=src_mask)
        log_probs = self.log_softmax(hidden_states=tgt_hiddens)
        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1 added by DataLoader
                # is excess.
                batch[i] = batch[i].squeeze(dim=0)
        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)
        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)
        tensorboard_logs = {
            'train_loss': train_loss,
            'lr': self._optimizer.param_groups[0]['lr'],
        }
        return {'loss': train_loss, 'log': tensorboard_logs}

    def eval_step(self, batch, batch_idx, mode, dataloader_idx=0):
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1 added by DataLoader
                # is excess.
                batch[i] = batch[i].squeeze(dim=0)

        if self.multilingual:
            self.source_processor = self.source_processor_list[dataloader_idx]
            self.target_processor = self.target_processor_list[dataloader_idx]

        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)
        eval_loss = self.eval_loss_fn(log_probs=log_probs, labels=labels)
        # this will run encoder twice -- TODO: potentially fix
        _, translations = self.batch_translate(src=src_ids, src_mask=src_mask)
        if dataloader_idx == 0:
            getattr(self, f'{mode}_loss')(loss=eval_loss,
                                          num_measurements=log_probs.shape[0] *
                                          log_probs.shape[1])
        else:
            getattr(self, f'{mode}_loss_{dataloader_idx}')(
                loss=eval_loss,
                num_measurements=log_probs.shape[0] * log_probs.shape[1])
        np_tgt = tgt_ids.detach().cpu().numpy()
        ground_truths = [
            self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt
        ]
        ground_truths = [
            self.target_processor.detokenize(tgt.split(' '))
            for tgt in ground_truths
        ]
        num_non_pad_tokens = np.not_equal(
            np_tgt, self.decoder_tokenizer.pad_id).sum().item()
        return {
            'translations': translations,
            'ground_truths': ground_truths,
            'num_non_pad_tokens': num_non_pad_tokens,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self.eval_step(batch, batch_idx, 'test', dataloader_idx)

    @rank_zero_only
    def log_param_stats(self):
        for name, p in self.named_parameters():
            if p.requires_grad:
                self.trainer.logger.experiment.add_histogram(
                    name + '_hist', p, global_step=self.global_step)
                self.trainer.logger.experiment.add_scalars(
                    name,
                    {
                        'mean': p.mean(),
                        'stddev': p.std(),
                        'max': p.max(),
                        'min': p.min()
                    },
                    global_step=self.global_step,
                )

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        return self.eval_step(batch, batch_idx, 'val', dataloader_idx)

    def eval_epoch_end(self, outputs, mode):
        # if the user specifies a single validation dataloader, PTL gives a list of dictionaries instead of a list of lists of dictionaries
        if isinstance(outputs[0], dict):
            outputs = [outputs]

        loss_list = []
        sb_score_list = []
        for dataloader_idx, output in enumerate(outputs):
            if dataloader_idx == 0:
                eval_loss = getattr(self, f'{mode}_loss').compute()
            else:
                eval_loss = getattr(self,
                                    f'{mode}_loss_{dataloader_idx}').compute()

            translations = list(
                itertools.chain(*[x['translations'] for x in output]))
            ground_truths = list(
                itertools.chain(*[x['ground_truths'] for x in output]))
            assert len(translations) == len(ground_truths)

            # Gather translations and ground truths from all workers
            tr_and_gt = [None for _ in range(self.world_size)]
            # we also need to drop pairs where ground truth is an empty string
            dist.all_gather_object(
                tr_and_gt, [(t, g)
                            for (t, g) in zip(translations, ground_truths)
                            if g.strip() != ''])
            if self.global_rank == 0:
                _translations = []
                _ground_truths = []
                for rank in range(0, self.world_size):
                    _translations += [t for (t, g) in tr_and_gt[rank]]
                    _ground_truths += [g for (t, g) in tr_and_gt[rank]]

                if self.tgt_language in ['ja']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="ja-mecab")
                elif self.tgt_language in ['zh']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="zh")
                else:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="13a")

                # because the reduction op later is average (over world_size)
                sb_score = sacre_bleu.score * self.world_size

                dataset_name = "Validation" if mode == 'val' else "Test"
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(translations)}"
                )
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Val Loss = {eval_loss}"
                )
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Sacre BLEU = {sb_score / self.world_size}"
                )
                logging.info(
                    f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:"
                )
                for i in range(0, 3):
                    ind = random.randint(0, len(translations) - 1)
                    logging.info("    " + '\u0332'.join(f"Example {i}:"))
                    logging.info(f"    Prediction:   {translations[ind]}")
                    logging.info(f"    Ground Truth: {ground_truths[ind]}")
            else:
                sb_score = 0.0

            loss_list.append(eval_loss.cpu().numpy())
            sb_score_list.append(sb_score)
            if dataloader_idx == 0:
                self.log(f"{mode}_loss", eval_loss, sync_dist=True)
                self.log(f"{mode}_sacreBLEU", sb_score, sync_dist=True)
                getattr(self, f'{mode}_loss').reset()
            else:
                self.log(f"{mode}_loss_dl_index_{dataloader_idx}",
                         eval_loss,
                         sync_dist=True)
                self.log(f"{mode}_sacreBLEU_dl_index_{dataloader_idx}",
                         sb_score,
                         sync_dist=True)
                getattr(self, f'{mode}_loss_{dataloader_idx}').reset()

        if len(loss_list) > 1:
            self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
            self.log(f"{mode}_sacreBLEU_avg",
                     np.mean(sb_score_list),
                     sync_dist=True)

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        self.eval_epoch_end(outputs, 'val')

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, 'test')

    def setup_enc_dec_tokenizers(
        self,
        encoder_tokenizer_library=None,
        encoder_tokenizer_model=None,
        encoder_bpe_dropout=0.0,
        encoder_model_name=None,
        encoder_r2l=False,
        encoder_tokenizer_vocab_file=None,
        decoder_tokenizer_library=None,
        decoder_tokenizer_model=None,
        decoder_bpe_dropout=0.0,
        decoder_model_name=None,
        decoder_r2l=False,
    ):

        supported_tokenizers = [
            'yttm', 'huggingface', 'sentencepiece', 'megatron', 'byte-level'
        ]
        if (encoder_tokenizer_library not in supported_tokenizers
                or decoder_tokenizer_library not in supported_tokenizers):
            raise NotImplementedError(
                f"Currently we only support tokenizers in {supported_tokenizers}."
            )

        self.encoder_tokenizer = get_nmt_tokenizer(
            library=encoder_tokenizer_library,
            tokenizer_model=self.register_artifact(
                "encoder_tokenizer.tokenizer_model", encoder_tokenizer_model),
            bpe_dropout=encoder_bpe_dropout,
            model_name=encoder_model_name,
            vocab_file=self.register_artifact("encoder_tokenizer.vocab_file",
                                              encoder_tokenizer_vocab_file),
            special_tokens=None,
            use_fast=False,
            r2l=encoder_r2l,
        )
        self.decoder_tokenizer = get_nmt_tokenizer(
            library=decoder_tokenizer_library,
            tokenizer_model=self.register_artifact(
                "decoder_tokenizer.tokenizer_model", decoder_tokenizer_model),
            bpe_dropout=decoder_bpe_dropout,
            model_name=decoder_model_name,
            vocab_file=None,
            special_tokens=None,
            use_fast=False,
            r2l=decoder_r2l,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_multiple_validation_data(self, val_data_config: Union[DictConfig,
                                                                    Dict]):
        self.setup_validation_data(self._cfg.get('validation_ds'))

    def setup_multiple_test_data(self, test_data_config: Union[DictConfig,
                                                               Dict]):
        self.setup_test_data(self._cfg.get('test_ds'))

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_eval_dataloader_from_config(
            cfg=val_data_config)
        # instantiate Torchmetric for each val dataloader
        if self._validation_dl is not None:
            for dataloader_idx in range(len(self._validation_dl)):
                if dataloader_idx == 0:
                    setattr(
                        self,
                        f'val_loss',
                        GlobalAverageLossMetric(dist_sync_on_step=False,
                                                take_avg_loss=True),
                    )
                else:
                    setattr(
                        self,
                        f'val_loss_{dataloader_idx}',
                        GlobalAverageLossMetric(dist_sync_on_step=False,
                                                take_avg_loss=True),
                    )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_eval_dataloader_from_config(
            cfg=test_data_config)
        # instantiate Torchmetric for each test dataloader
        if self._test_dl is not None:
            for dataloader_idx in range(len(self._test_dl)):
                if dataloader_idx == 0:
                    setattr(
                        self,
                        f'test_loss',
                        GlobalAverageLossMetric(dist_sync_on_step=False,
                                                take_avg_loss=True),
                    )
                else:
                    setattr(
                        self,
                        f'test_loss_{dataloader_idx}',
                        GlobalAverageLossMetric(dist_sync_on_step=False,
                                                take_avg_loss=True),
                    )

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            metadata_file_list = cfg.get('metadata_file')
            tar_files_list = cfg.get('tar_files', None)
            if isinstance(metadata_file_list, str):
                metadata_file_list = [metadata_file_list]
            if tar_files_list is not None and isinstance(tar_files_list, str):
                tar_files_list = [tar_files_list]
            if tar_files_list is not None and len(tar_files_list) != len(
                    metadata_file_list):
                raise ValueError(
                    'The config must have the same number of tarfile paths and metadata file paths.'
                )

            datasets = []
            for idx, metadata_file in enumerate(metadata_file_list):
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if tar_files_list is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                else:
                    tar_files = tar_files_list[idx]
                    if metadata.get('tar_files') is not None:
                        logging.info(
                            f'Tar file paths found in both cfg and metadata; using the ones in cfg by default - {tar_files}'
                        )

                dataset = TarredTranslationDataset(
                    text_tar_filepaths=tar_files,
                    metadata_path=metadata_file,
                    encoder_tokenizer=self.encoder_tokenizer,
                    decoder_tokenizer=self.decoder_tokenizer,
                    shuffle_n=cfg.get("tar_shuffle_n", 100),
                    shard_strategy=cfg.get("shard_strategy", "scatter"),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]
        else:
            src_file_list = cfg.src_file_name
            tgt_file_list = cfg.tgt_file_name
            if isinstance(src_file_list, str):
                src_file_list = [src_file_list]
            if isinstance(tgt_file_list, str):
                tgt_file_list = [tgt_file_list]

            if len(src_file_list) != len(tgt_file_list):
                raise ValueError(
                    'The same number of filepaths must be passed in for source and target.'
                )

            datasets = []
            for idx, src_file in enumerate(src_file_list):
                dataset = TranslationDataset(
                    dataset_src=str(Path(src_file).expanduser()),
                    dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                    tokens_in_batch=cfg.tokens_in_batch,
                    clean=cfg.get("clean", False),
                    max_seq_length=cfg.get("max_seq_length", 512),
                    min_seq_length=cfg.get("min_seq_length", 1),
                    max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                    max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                    cache_ids=cfg.get("cache_ids", False),
                    cache_data_per_node=cfg.get("cache_data_per_node", False),
                    use_cache=cfg.get("use_cache", False),
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                dataset.batchify(self.encoder_tokenizer,
                                 self.decoder_tokenizer)
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    shuffle=cfg.get('shuffle'),
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]

        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=None if cfg.get("use_tarred_dataset", False) else sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )

    def replace_beam_with_sampling(self, topk=500):
        self.beam_search = TopKSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.beam_search.max_seq_length,
            beam_size=topk,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
        )

    def _setup_eval_dataloader_from_config(self, cfg: DictConfig):
        src_file_name = cfg.get('src_file_name')
        tgt_file_name = cfg.get('tgt_file_name')

        if src_file_name is None or tgt_file_name is None:
            raise ValueError(
                'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.'
            )
        else:
            # convert src_file_name and tgt_file_name to list of strings
            if isinstance(src_file_name, str):
                src_file_list = [src_file_name]
            elif isinstance(src_file_name, ListConfig):
                src_file_list = src_file_name
            else:
                raise ValueError(
                    "cfg.src_file_name must be string or list of strings")
            if isinstance(tgt_file_name, str):
                tgt_file_list = [tgt_file_name]
            elif isinstance(tgt_file_name, ListConfig):
                tgt_file_list = tgt_file_name
            else:
                raise ValueError(
                    "cfg.tgt_file_name must be string or list of strings")
        if len(src_file_list) != len(tgt_file_list):
            raise ValueError(
                'The same number of filepaths must be passed in for source and target validation.'
            )

        dataloaders = []
        prepend_idx = 0
        for idx, src_file in enumerate(src_file_list):
            if self.multilingual:
                prepend_idx = idx
            dataset = TranslationDataset(
                dataset_src=str(Path(src_file).expanduser()),
                dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction",
                                               False),
                prepend_id=self.multilingual_ids[prepend_idx]
                if self.multilingual else None,
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

            if cfg.shuffle:
                sampler = pt_data.RandomSampler(dataset)
            else:
                sampler = pt_data.SequentialSampler(dataset)

            dataloader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
            dataloaders.append(dataloader)

        return dataloaders

    def setup_pre_and_post_processing_utils(self, source_lang, target_lang):
        """
        Creates source and target processor objects for input and output pre/post-processing.
        """
        self.source_processor, self.target_processor = None, None

        if self.encoder_tokenizer_library == 'byte-level':
            self.source_processor = ByteLevelProcessor()
        elif (source_lang == 'en'
              and target_lang == 'ja') or (source_lang == 'ja'
                                           and target_lang == 'en'):
            self.source_processor = EnJaProcessor(source_lang)
        elif source_lang == 'zh':
            self.source_processor = ChineseProcessor()
        elif source_lang == 'hi':
            self.source_processor = IndicProcessor(source_lang)
        elif source_lang is not None and source_lang not in ['ja', 'zh', 'hi']:
            self.source_processor = MosesProcessor(source_lang)

        if self.decoder_tokenizer_library == 'byte-level':
            self.target_processor = ByteLevelProcessor()
        elif (source_lang == 'en'
              and target_lang == 'ja') or (source_lang == 'ja'
                                           and target_lang == 'en'):
            self.target_processor = EnJaProcessor(target_lang)
        elif target_lang == 'zh':
            self.target_processor = ChineseProcessor()
        elif target_lang == 'hi':
            self.target_processor = IndicProcessor(target_lang)
        elif target_lang is not None and target_lang not in ['ja', 'zh', 'hi']:
            self.target_processor = MosesProcessor(target_lang)

        return self.source_processor, self.target_processor

    def ids_to_postprocessed_text(self,
                                  beam_ids,
                                  tokenizer,
                                  processor,
                                  filter_beam_ids=True):
        if filter_beam_ids:
            beam_ids = self.filter_predicted_ids(beam_ids)
        translations = [
            tokenizer.ids_to_text(tr) for tr in beam_ids.cpu().numpy()
        ]
        if processor is not None:
            translations = [
                processor.detokenize(translation.split(' '))
                for translation in translations
            ]
        return translations

    @torch.no_grad()
    def batch_translate(self,
                        src: torch.LongTensor,
                        src_mask: torch.LongTensor,
                        return_beam_scores: bool = False):
        """	
        Translates a minibatch of inputs from source language to target language.	
        Args:	
            src: minibatch of inputs in the src language (batch x seq_len)	
            src_mask: mask tensor indicating elements to be ignored (batch x seq_len)	
        Returns:	
            translations: a list strings containing detokenized translations	
            inputs: a list of string containing detokenized inputs	
        """
        mode = self.training
        try:
            self.eval()
            src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
            best_translations = self.beam_search(
                encoder_hidden_states=src_hiddens,
                encoder_input_mask=src_mask,
                return_beam_scores=return_beam_scores)
            if return_beam_scores:
                all_translations, scores, best_translations = best_translations
                scores = scores.view(-1)
                all_translations = self.ids_to_postprocessed_text(
                    all_translations,
                    self.decoder_tokenizer,
                    self.target_processor,
                    filter_beam_ids=True)

            best_translations = self.ids_to_postprocessed_text(
                best_translations,
                self.decoder_tokenizer,
                self.target_processor,
                filter_beam_ids=True)
            inputs = self.ids_to_postprocessed_text(src,
                                                    self.encoder_tokenizer,
                                                    self.source_processor,
                                                    filter_beam_ids=False)

        finally:
            self.train(mode=mode)
        if return_beam_scores:
            return inputs, all_translations, scores.data.cpu().numpy().tolist(
            ), best_translations

        return inputs, best_translations

    def prepare_inference_batch(self, text, prepend_ids=[], target=False):
        inputs = []
        processor = self.source_processor if not target else self.target_processor
        tokenizer = self.encoder_tokenizer if not target else self.decoder_tokenizer
        for txt in text:
            if processor is not None:
                txt = processor.normalize(txt)
                txt = processor.tokenize(txt)
            ids = tokenizer.text_to_ids(txt)
            ids = prepend_ids + [tokenizer.bos_id] + ids + [tokenizer.eos_id]
            inputs.append(ids)
        max_len = max(len(txt) for txt in inputs)
        src_ids_ = np.ones((len(inputs), max_len)) * tokenizer.pad_id
        for i, txt in enumerate(inputs):
            src_ids_[i][:len(txt)] = txt

        src_mask = torch.FloatTensor(
            (src_ids_ != tokenizer.pad_id)).to(self.device)
        src = torch.LongTensor(src_ids_).to(self.device)

        return src, src_mask

    # TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language
    @torch.no_grad()
    def translate(self,
                  text: List[str],
                  source_lang: str = None,
                  target_lang: str = None,
                  return_beam_scores: bool = False) -> List[str]:
        """
        Translates list of sentences from source language to target language.
        Inputs should be regular text; this method performs its own tokenization/de-tokenization.
        Args:
            text: list of strings to translate
            source_lang: if not None, corresponding MosesTokenizer and MosesPunctNormalizer will be run
            target_lang: if not None, corresponding MosesDetokenizer will be run
        Returns:
            list of translated strings
        """
        # __TODO__: This will reset both source and target processors even if you want to reset just one.
        if source_lang is not None or target_lang is not None:
            self.setup_pre_and_post_processing_utils(source_lang, target_lang)

        mode = self.training
        prepend_ids = []
        if self.multilingual:
            if source_lang is None or target_lang is None:
                raise ValueError(
                    "Expect source_lang and target_lang to infer for multilingual model."
                )
            src_symbol = self.encoder_tokenizer.token_to_id('<' + source_lang +
                                                            '>')
            tgt_symbol = self.encoder_tokenizer.token_to_id('<' + target_lang +
                                                            '>')
            prepend_ids = [
                src_symbol
                if src_symbol in self.multilingual_ids else tgt_symbol
            ]
        try:
            self.eval()
            src, src_mask = self.prepare_inference_batch(text, prepend_ids)
            if return_beam_scores:
                _, all_translations, scores, best_translations = self.batch_translate(
                    src, src_mask, return_beam_scores=True)
                return all_translations, scores, best_translations
            else:
                _, translations = self.batch_translate(
                    src, src_mask, return_beam_scores=False)
        finally:
            self.train(mode=mode)
        return translations

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_de_transformer12x2",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_de_transformer12x2/versions/1.0.0rc1/files/nmt_en_de_transformer12x2.nemo",
            description=
            "En->De translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_de_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_de_en_transformer12x2",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_de_en_transformer12x2/versions/1.0.0rc1/files/nmt_de_en_transformer12x2.nemo",
            description=
            "De->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_de_en_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_es_transformer12x2",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_es_transformer12x2/versions/1.0.0rc1/files/nmt_en_es_transformer12x2.nemo",
            description=
            "En->Es translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_es_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_es_en_transformer12x2",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_es_en_transformer12x2/versions/1.0.0rc1/files/nmt_es_en_transformer12x2.nemo",
            description=
            "Es->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_es_en_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_fr_transformer12x2",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_fr_transformer12x2/versions/1.0.0rc1/files/nmt_en_fr_transformer12x2.nemo",
            description=
            "En->Fr translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_fr_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_fr_en_transformer12x2",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_fr_en_transformer12x2/versions/1.0.0rc1/files/nmt_fr_en_transformer12x2.nemo",
            description=
            "Fr->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_fr_en_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_ru_transformer6x6",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_ru_transformer6x6/versions/1.0.0rc1/files/nmt_en_ru_transformer6x6.nemo",
            description=
            "En->Ru translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_ru_transformer6x6",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_ru_en_transformer6x6",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_ru_en_transformer6x6/versions/1.0.0rc1/files/nmt_ru_en_transformer6x6.nemo",
            description=
            "Ru->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_ru_en_transformer6x6",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_zh_en_transformer6x6",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_zh_en_transformer6x6/versions/1.0.0rc1/files/nmt_zh_en_transformer6x6.nemo",
            description=
            "Zh->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_zh_en_transformer6x6",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_zh_transformer6x6",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_zh_transformer6x6/versions/1.0.0rc1/files/nmt_en_zh_transformer6x6.nemo",
            description=
            "En->Zh translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_zh_transformer6x6",
        )
        result.append(model)

        return result
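Each PretrainedModelInfo entry above only points at a downloadable .nemo checkpoint. A minimal usage sketch, assuming the standard NeMo `from_pretrained`/`translate` APIs and the `nmt_en_es_transformer12x2` checkpoint listed above (the exact class path may vary between NeMo releases):

# Hedged usage sketch, not part of the listing above.
from nemo.collections.nlp.models import MTEncDecModel

nmt_model = MTEncDecModel.from_pretrained("nmt_en_es_transformer12x2")   # downloads the .nemo file from NGC
print(nmt_model.translate(["Machine translation is easy to try out."]))  # -> list with one Spanish sentence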
Code Example #14
File: mt_enc_dec_model.py Project: tiborkiss/NeMo
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language = cfg.get("src_language", None)
        self.tgt_language = cfg.get("tgt_language", None)

        self.multilingual = cfg.get("multilingual", False)
        self.multilingual_ids = []

        self.encoder_tokenizer_library = cfg.encoder_tokenizer.get(
            'library', 'yttm')
        self.decoder_tokenizer_library = cfg.decoder_tokenizer.get(
            'library', 'yttm')

        # Instantiate tokenizers and register them to be saved with the NeMo Model archive.
        # After this call, there will be self.encoder_tokenizer and self.decoder_tokenizer,
        # which can convert between tokens and token_ids for the SRC and TGT languages respectively.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_library=self.encoder_tokenizer_library,
            encoder_tokenizer_model=cfg.encoder_tokenizer.get(
                'tokenizer_model'),
            encoder_bpe_dropout=cfg.encoder_tokenizer.get(
                'bpe_dropout', 0.0) if cfg.encoder_tokenizer.get(
                    'bpe_dropout', 0.0) is not None else 0.0,
            encoder_model_name=cfg.encoder.get('model_name') if hasattr(
                cfg.encoder, 'model_name') else None,
            encoder_r2l=cfg.encoder_tokenizer.get('r2l', False),
            decoder_tokenizer_library=self.decoder_tokenizer_library,
            encoder_tokenizer_vocab_file=cfg.encoder_tokenizer.get(
                'vocab_file', None),
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get(
                'bpe_dropout', 0.0) if cfg.decoder_tokenizer.get(
                    'bpe_dropout', 0.0) is not None else 0.0,
            decoder_model_name=cfg.decoder.get('model_name') if hasattr(
                cfg.decoder, 'model_name') else None,
            decoder_r2l=cfg.decoder_tokenizer.get('r2l', False),
        )

        if self.multilingual:
            if isinstance(self.src_language, ListConfig) and isinstance(
                    self.tgt_language, ListConfig):
                raise ValueError(
                    "cfg.src_language and cfg.tgt_language cannot both be lists. We only support many-to-one or one-to-many multilingual models."
                )
            elif isinstance(self.src_language, ListConfig):
                for lng in self.src_language:
                    self.multilingual_ids.append(
                        self.encoder_tokenizer.token_to_id("<" + lng + ">"))
            elif isinstance(self.tgt_language, ListConfig):
                for lng in self.tgt_language:
                    self.multilingual_ids.append(
                        self.encoder_tokenizer.token_to_id("<" + lng + ">"))
            else:
                raise ValueError(
                    "Expect either cfg.src_language or cfg.tgt_language to be a list when multilingual=True."
                )

            if isinstance(self.src_language, ListConfig):
                self.tgt_language = [self.tgt_language] * len(
                    self.src_language)
            else:
                self.src_language = [self.src_language] * len(
                    self.tgt_language)

            self.source_processor_list = []
            self.target_processor_list = []
            for src_lng, tgt_lng in zip(self.src_language, self.tgt_language):
                src_prcsr, tgt_prscr = self.setup_pre_and_post_processing_utils(
                    src_lng, tgt_lng)
                self.source_processor_list.append(src_prcsr)
                self.target_processor_list.append(tgt_prscr)

        else:
            # After this call, the model will have self.source_processor and self.target_processor objects
            self.setup_pre_and_post_processing_utils(self.src_language,
                                                     self.tgt_language)
            self.multilingual_ids = [None]

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        checkpoint_file = encoder_cfg_dict.pop('checkpoint_file', None)
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_cfg_dict,
            encoder=True,
            pre_ln_final_layer_norm=encoder_cfg_dict.get(
                'pre_ln_final_layer_norm', False),
            checkpoint_file=checkpoint_file,
        )

        # decoder from NeMo, Megatron-LM, or HuggingFace
        decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder'))
        decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size
        library = decoder_cfg_dict.pop('library', 'nemo')
        model_name = decoder_cfg_dict.pop('model_name', None)
        pretrained = decoder_cfg_dict.pop('pretrained', False)
        decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
        self.decoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=decoder_cfg_dict,
            encoder=False,
            pre_ln_final_layer_norm=decoder_cfg_dict.get(
                'pre_ln_final_layer_norm', False),
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size**0.5

        # initialize weights if not using pretrained encoder/decoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        if not self._cfg.decoder.get('pretrained', False):
            self.decoder.apply(lambda module: transformer_weights_init(
                module, std_init_range))

        self.log_softmax.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id,
            label_smoothing=cfg.label_smoothing)
        self.eval_loss_fn = NLLLoss(ignore_index=self.decoder_tokenizer.pad_id)
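The tokenizer setup above guards bpe_dropout twice (`cfg.encoder_tokenizer.get('bpe_dropout', 0.0) if cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None else 0.0`), presumably because .get() only falls back to the default when the key is missing: an explicit `bpe_dropout: null` in the YAML comes back as None. A small standalone OmegaConf illustration of that behavior:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"bpe_dropout": None})
raw = cfg.get("bpe_dropout", 0.0)             # None: the key exists, so the default is not used
bpe_dropout = raw if raw is not None else 0.0
print(bpe_dropout)                            # 0.0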
Code Example #15
File: transformer_lm_model.py Project: zt706/NeMo
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.global_rank = 0
        self.world_size = 1
        if trainer is not None:
            self.global_rank = (trainer.node_rank *
                                trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset

        vocab_file = cfg.language_model.get("vocab_file", None)

        if vocab_file is not None:
            vocab_file = self.register_artifact("language_model.vocab_file",
                                                vocab_file)

        tokenizer_model = cfg.language_model.get("tokenizer_model", None)

        if tokenizer_model is not None:
            tokenizer_model = self.register_artifact(
                "language_model.tokenizer_model", tokenizer_model)

        if cfg.language_model.special_tokens:
            special_tokens = OmegaConf.to_container(
                cfg.language_model.special_tokens, resolve=True)
        else:
            special_tokens = None

        self.tokenizer = get_tokenizer(
            tokenizer_name=cfg.language_model.tokenizer,
            vocab_file=vocab_file,
            special_tokens=special_tokens,
            tokenizer_model=tokenizer_model,
        )

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.embedding_layer = TransformerEmbedding(
            vocab_size=vocab_size,
            hidden_size=cfg.language_model.hidden_size,
            max_sequence_length=cfg.language_model.max_seq_length,
            embedding_dropout=cfg.language_model.get("embedding_dropout", 0.0),
            learn_positional_encodings=False,
        )
        self.encoder = TransformerEncoder(
            num_layers=cfg.language_model.num_layers,
            hidden_size=cfg.language_model.hidden_size,
            mask_future=True,
            num_attention_heads=cfg.language_model.num_attn_heads,
            inner_size=cfg.language_model.inner_size,
            ffn_dropout=cfg.language_model.get("ffn_dropout", 0.0),
            hidden_act=cfg.language_model.get("inner_activation", "relu"),
            attn_score_dropout=cfg.language_model.get("attn_score_dropout",
                                                      0.0),
            attn_layer_dropout=cfg.language_model.get("attn_layer_dropout",
                                                      0.0),
        )
        self.log_softmax = TokenClassifier(
            hidden_size=cfg.language_model.hidden_size,
            num_classes=vocab_size,
            log_softmax=True,
        )

        std_init_range = 1 / math.sqrt(cfg.language_model.hidden_size)
        self.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.embedding_layer.token_embedding.weight

        if hasattr(self.tokenizer, 'pad_token'):
            pad_id = self.tokenizer.pad_id
        else:
            raise ValueError(
                "The tokenizer must support a special `pad_token`. Provide it using"
                "the `special_tokens` dictionary.")

        self.training_loss = SmoothedCrossEntropyLoss(pad_id=pad_id)
        self.validation_loss = SmoothedCrossEntropyLoss(
            pad_id=pad_id,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

        self.training_perplexity = Perplexity(dist_sync_on_step=True)
        self.validation_perplexity = Perplexity(compute_on_step=False)

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization()
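The line `vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)` above rounds the vocabulary up to the nearest multiple of 8, so the embedding and softmax matrices keep shapes that are friendly to fp16 tensor-core GEMMs. A quick worked example:

import math

for raw in (30522, 32000, 37005):
    print(raw, "->", 8 * math.ceil(raw / 8))   # 30522 -> 30528, 32000 -> 32000, 37005 -> 37008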
Code Example #16
File: mt_enc_dec_model.py Project: zt706/NeMo
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.global_rank = 0
        self.world_size = 1
        if trainer is not None:
            self.global_rank = (trainer.node_rank *
                                trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)
        self.setup_enc_dec_tokenizers(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        # TODO: use get_encoder function with support for HF and Megatron
        self.encoder = TransformerEncoderNM(
            vocab_size=self.encoder_vocab_size,
            hidden_size=cfg.encoder.hidden_size,
            num_layers=cfg.encoder.num_layers,
            inner_size=cfg.encoder.inner_size,
            max_sequence_length=cfg.encoder.max_sequence_length if hasattr(
                cfg.encoder, 'max_sequence_length') else 512,
            embedding_dropout=cfg.encoder.embedding_dropout if hasattr(
                cfg.encoder, 'embedding_dropout') else 0.0,
            learn_positional_encodings=cfg.encoder.learn_positional_encodings
            if hasattr(cfg.encoder, 'learn_positional_encodings') else False,
            num_attention_heads=cfg.encoder.num_attention_heads,
            ffn_dropout=cfg.encoder.ffn_dropout,
            attn_score_dropout=cfg.encoder.attn_score_dropout,
            attn_layer_dropout=cfg.encoder.attn_layer_dropout,
            hidden_act=cfg.encoder.hidden_act,
            mask_future=cfg.encoder.mask_future,
            pre_ln=cfg.encoder.pre_ln,
        )

        # TODO: use get_decoder function with support for HF and Megatron
        self.decoder = TransformerDecoderNM(
            vocab_size=self.decoder_vocab_size,
            hidden_size=cfg.decoder.hidden_size,
            num_layers=cfg.decoder.num_layers,
            inner_size=cfg.decoder.inner_size,
            max_sequence_length=cfg.decoder.max_sequence_length if hasattr(
                cfg.decoder, 'max_sequence_length') else 512,
            embedding_dropout=cfg.decoder.embedding_dropout if hasattr(
                cfg.decoder, 'embedding_dropout') else 0.0,
            learn_positional_encodings=cfg.decoder.learn_positional_encodings
            if hasattr(cfg.decoder, 'learn_positional_encodings') else False,
            num_attention_heads=cfg.decoder.num_attention_heads,
            ffn_dropout=cfg.decoder.ffn_dropout,
            attn_score_dropout=cfg.decoder.attn_score_dropout,
            attn_layer_dropout=cfg.decoder.attn_layer_dropout,
            hidden_act=cfg.decoder.hidden_act,
            pre_ln=cfg.decoder.pre_ln,
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size**0.5
        self.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id,
            label_smoothing=cfg.label_smoothing)
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False,
                                                 take_avg_loss=True)
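The assignment `self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight` above ties the output projection to the decoder token embedding, so both layers share a single (vocab_size, hidden_size) parameter tensor. A standalone PyTorch sketch of the same idea, using plain nn modules and hypothetical sizes rather than the NeMo classes:

import torch.nn as nn

vocab_size, hidden_size = 32000, 512
embedding = nn.Embedding(vocab_size, hidden_size)
projection = nn.Linear(hidden_size, vocab_size, bias=False)
projection.weight = embedding.weight                     # one shared parameter tensor
assert projection.weight.data_ptr() == embedding.weight.data_ptr()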
Code Example #17
class TokenClassificationModel(NLPModel, Exportable):
    """Token Classification Model with BERT, applicable for tasks such as Named Entity Recognition"""
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""
        # extract str to int labels mapping if a mapping file provided
        if isinstance(cfg.label_ids, str):
            if os.path.exists(cfg.label_ids):
                logging.info(
                    f'Reusing label_ids file found at {cfg.label_ids}.')
                label_ids = get_labels_to_labels_id_mapping(cfg.label_ids)
                # update the config to store name to id mapping
                cfg.label_ids = OmegaConf.create(label_ids)
            else:
                raise ValueError(f'{cfg.label_ids} not found.')

        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=False,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.class_weights = None
        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Update data directory and get data stats with the Data Descriptor.
        Weights are later used to set up the loss.

        Args:
            data_dir: path to data directory
        """
        self._cfg.dataset.data_dir = data_dir
        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')

    def setup_loss(self, class_balancing: str = None):
        """Setup loss
           Setup or update loss.

        Args:
            class_balancing: whether to use class weights during training
        """
        if class_balancing == 'weighted_loss' and self.class_weights:
            # you may need to increase the number of epochs for convergence when using weighted_loss
            loss = CrossEntropyLoss(logits_ndim=3, weight=self.class_weights)
        else:
            loss = CrossEntropyLoss(logits_ndim=3)
        return loss

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        logits = self.classifier(hidden_states=hidden_states)
        return logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch
        logits = self(input_ids=input_ids,
                      token_type_ids=input_type_ids,
                      attention_mask=input_mask)
        loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch
        logits = self(input_ids=input_ids,
                      token_type_ids=input_type_ids,
                      attention_mask=input_mask)
        val_loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)

        subtokens_mask = subtokens_mask > 0.5

        preds = torch.argmax(logits, axis=-1)[subtokens_mask]
        labels = labels[subtokens_mask]
        tp, fn, fp, _ = self.classification_report(preds, labels)

        return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and classification report
        precision, recall, f1, report = self.classification_report.compute()

        logging.info(report)

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('precision', precision)
        self.log('f1', f1)
        self.log('recall', recall)

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        """
        return self.validation_epoch_end(outputs)

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            vocab_file=self.register_artifact(
                config_path='tokenizer.vocab_file', src=cfg.vocab_file),
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            tokenizer_model=self.register_artifact(
                config_path='tokenizer.tokenizer_model',
                src=cfg.tokenizer_model),
        )
        self.tokenizer = tokenizer

    def setup_training_data(self,
                            train_data_config: Optional[DictConfig] = None):
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   train_data_config.labels_file)
        label_ids, label_ids_filename, self.class_weights = get_label_ids(
            label_file=labels_file,
            is_training=True,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=self._cfg.dataset.class_balancing == 'weighted_loss',
        )
        # save label maps to the config
        self._cfg.label_ids = OmegaConf.create(label_ids)
        self.register_artifact('label_ids.csv', label_ids_filename)
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self,
                              val_data_config: Optional[DictConfig] = None):
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   val_data_config.labels_file)
        get_label_ids(
            label_file=labels_file,
            is_training=False,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig] = None):
        if test_data_config is None:
            test_data_config = self._cfg.test_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   test_data_config.labels_file)
        get_label_ids(
            label_file=labels_file,
            is_training=False,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader:
        """
        Setup dataloader from config
        Args:
            cfg: config for the dataloader
        Return:
            Pytorch Dataloader
        """
        dataset_cfg = self._cfg.dataset
        data_dir = dataset_cfg.data_dir

        if not os.path.exists(data_dir):
            raise FileNotFoundError(
                f"Data directory is not found at: {data_dir}.")

        text_file = os.path.join(data_dir, cfg.text_file)
        labels_file = os.path.join(data_dir, cfg.labels_file)

        if not (os.path.exists(text_file) and os.path.exists(labels_file)):
            raise FileNotFoundError(
                f'{text_file} or {labels_file} not found. The data should be split into 2 files: text.txt and \
                labels.txt. Each line of the text.txt file contains text sequences, where words are separated with \
                spaces. The labels.txt file contains corresponding labels for each word in text.txt, the labels are \
                separated with spaces. Each line of the files should follow the format:  \
                   [WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and \
                   [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).')
        dataset = BertTokenClassificationDataset(
            text_file=text_file,
            label_file=labels_file,
            max_seq_length=dataset_cfg.max_seq_length,
            tokenizer=self.tokenizer,
            num_samples=cfg.num_samples,
            pad_label=dataset_cfg.pad_label,
            label_ids=self._cfg.label_ids,
            ignore_extra_tokens=dataset_cfg.ignore_extra_tokens,
            ignore_start_end=dataset_cfg.ignore_start_end,
            use_cache=dataset_cfg.use_cache,
        )
        return DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=dataset_cfg.drop_last,
        )

    def _setup_infer_dataloader(
            self, queries: List[str],
            batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Setup function for an inference data loader.

        Args:
            queries: text
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader.
        """

        dataset = BertTokenClassificationInferDataset(tokenizer=self.tokenizer,
                                                      queries=queries,
                                                      max_seq_length=-1)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=False,
        )

    @torch.no_grad()
    def _infer(self, queries: List[str], batch_size: int = None) -> List[int]:
        """
        Get prediction for the queries
        Args:
            queries: text sequences
            batch_size: batch size to use during inference.
        Returns:
            all_preds: model predictions
        """
        # store predictions for all queries in a single list
        all_preds = []
        mode = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            for batch in infer_datalayer:
                input_ids, input_type_ids, input_mask, subtokens_mask = batch

                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                subtokens_mask = subtokens_mask > 0.5
                preds = tensor2list(
                    torch.argmax(logits, axis=-1)[subtokens_mask])
                all_preds.extend(preds)
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return all_preds

    def add_predictions(self,
                        queries: Union[List[str], str],
                        batch_size: int = 32) -> List[str]:
        """
        Add predicted token labels to the queries. Use this method for debugging and prototyping.
        Args:
            queries: text
            batch_size: batch size to use during inference.
        Returns:
            result: text with added entities
        """
        if queries is None or len(queries) == 0:
            return []

        result = []
        all_preds = self._infer(queries, batch_size)

        queries = [q.strip().split() for q in queries]
        num_words = [len(q) for q in queries]
        if sum(num_words) != len(all_preds):
            raise ValueError('Pred and words must have the same length')

        ids_to_labels = {v: k for k, v in self._cfg.label_ids.items()}
        start_idx = 0
        end_idx = 0
        for query in queries:
            end_idx += len(query)

            # extract predictions for the current query from the list of all predictions
            preds = all_preds[start_idx:end_idx]
            start_idx = end_idx

            query_with_entities = ''
            for j, word in enumerate(query):
                # strip out the punctuation to attach the entity tag to the word not to a punctuation mark
                # that follows the word
                if word[-1].isalpha():
                    punct = ''
                else:
                    punct = word[-1]
                    word = word[:-1]

                query_with_entities += word
                label = ids_to_labels[preds[j]]

                if label != self._cfg.dataset.pad_label:
                    query_with_entities += '[' + label + ']'
                query_with_entities += punct + ' '
            result.append(query_with_entities.strip())
        return result

    def evaluate_from_file(
        self,
        output_dir: str,
        text_file: str,
        labels_file: Optional[str] = None,
        add_confusion_matrix: Optional[bool] = False,
        normalize_confusion_matrix: Optional[bool] = True,
        batch_size: int = 1,
    ) -> None:
        """
        Run inference on data from a file, plot confusion matrix and calculate classification report.
        Use this method for final evaluation.

        Args:
            output_dir: path to output directory to store model predictions and the confusion matrix plot (if add_confusion_matrix is set to True)
            text_file: path to file with text. Each line of the text.txt file contains text sequences, where words
                are separated with spaces: [WORD] [SPACE] [WORD] [SPACE] [WORD]
            labels_file (Optional): path to file with labels. Each line of the labels_file should contain
                labels corresponding to each word in the text_file, the labels are separated with spaces:
                [LABEL] [SPACE] [LABEL] [SPACE] [LABEL].
            add_confusion_matrix: whether to generate confusion matrix
            normalize_confusion_matrix: whether to normalize confusion matrix
            batch_size: batch size to use during inference.
        """
        output_dir = os.path.abspath(output_dir)

        with open(text_file, 'r') as f:
            queries = f.readlines()

        all_preds = self._infer(queries, batch_size)
        with_labels = labels_file is not None
        if with_labels:
            with open(labels_file, 'r') as f:
                all_labels_str = f.readlines()
                all_labels_str = ' '.join(
                    [labels.strip() for labels in all_labels_str])

        # write labels and predictions to a file in output_dir
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir,
                                'infer_' + os.path.basename(text_file))
        try:
            with open(filename, 'w') as f:
                if with_labels:
                    f.write('labels\t' + all_labels_str + '\n')
                    logging.info(f'Labels saved to {filename}')

                # convert labels from string label to ids
                ids_to_labels = {v: k for k, v in self._cfg.label_ids.items()}
                all_preds_str = [ids_to_labels[pred] for pred in all_preds]
                f.write('preds\t' + ' '.join(all_preds_str) + '\n')
                logging.info(f'Predictions saved to {filename}')

            if with_labels and add_confusion_matrix:
                all_labels = all_labels_str.split()
                # convert labels from string label to ids
                label_ids = self._cfg.label_ids
                all_labels = [label_ids[label] for label in all_labels]

                plot_confusion_matrix(all_labels,
                                      all_preds,
                                      output_dir,
                                      label_ids=label_ids,
                                      normalize=normalize_confusion_matrix)
                logging.info(
                    get_classification_report(all_labels, all_preds,
                                              label_ids))
        except Exception:
            logging.error(
                f'When providing a file with labels, check that all labels in {labels_file} were '
                f'seen during training.')
            raise

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="NERModel",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/NERModel.nemo",
            description=
            "The model was trained on GMB (Groningen Meaning Bank) corpus for entity recognition and achieves 74.61 F1 Macro score.",
        )
        result.append(model)
        return result

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " TokenClassificationModel consists of two separate models with different"
                " inputs and outputs.")

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output),
                               'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output),
                               'classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx,
                                           "TKCL")
        output_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(output_model, output)
        return ([output, output1,
                 output2], [output_descr, output1_descr, output2_descr])
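A minimal inference sketch for this class, assuming the pretrained `NERModel` checkpoint from `list_available_models` above and the `add_predictions` helper defined earlier (the exact class path may vary between NeMo releases):

# Hedged usage sketch, not part of the listing above.
from nemo.collections.nlp.models import TokenClassificationModel

ner_model = TokenClassificationModel.from_pretrained("NERModel")
queries = ["we bought four shirts from the nvidia gear store in santa clara"]
print(ner_model.add_predictions(queries))   # each word annotated with its predicted [LABEL]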
Code Example #18
    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # Instantiate tokenizers and register them to be saved with the NeMo Model archive.
        # After this call, there will be self.encoder_tokenizer and self.decoder_tokenizer,
        # which can convert between tokens and token_ids for the SRC and TGT languages respectively.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_library=cfg.encoder_tokenizer.get('library', 'yttm'),
            encoder_tokenizer_model=cfg.encoder_tokenizer.get('tokenizer_model'),
            encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0),
            encoder_model_name=cfg.encoder.get('model_name') if hasattr(cfg.encoder, 'model_name') else None,
            decoder_tokenizer_library=cfg.decoder_tokenizer.get('library', 'yttm'),
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0),
            decoder_model_name=cfg.decoder.get('model_name') if hasattr(cfg.decoder, 'model_name') else None,
        )

        # After this call, the model will have self.source_processor and self.target_processor objects
        self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language)

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        self.encoder = get_transformer(
            library=library, model_name=model_name, pretrained=pretrained, config_dict=encoder_cfg_dict, encoder=True,
        )

        # decoder from NeMo, Megatron-LM, or HuggingFace
        decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder'))
        decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size
        library = decoder_cfg_dict.pop('library', 'nemo')
        model_name = decoder_cfg_dict.pop('model_name', None)
        pretrained = decoder_cfg_dict.pop('pretrained', False)
        decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
        self.decoder = get_transformer(
            library=library, model_name=model_name, pretrained=pretrained, config_dict=decoder_cfg_dict, encoder=False,
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        # initialize weights if not using pretrained encoder/decoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(module, std_init_range))

        if not self._cfg.decoder.get('pretrained', False):
            self.decoder.apply(lambda module: transformer_weights_init(module, std_init_range))

        self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)
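`std_init_range = 1 / self.encoder.hidden_size ** 0.5` is the usual Transformer initialization scale (about 0.044 for hidden_size=512). A rough standalone sketch of applying such a scale to every Linear/Embedding module, shown here as a hypothetical helper rather than NeMo's actual `transformer_weights_init`:

import torch.nn as nn

def init_transformer_weights(module, std):
    # Hypothetical helper: normal init for weight matrices, zeros for Linear biases.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=std)
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)

# encoder.apply(lambda m: init_transformer_weights(m, 1 / hidden_size ** 0.5))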
Code Example #19
class MTEncDecModel(EncDecNLPModel):
    """
    Encoder-decoder machine translation model.
    """

    def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None):
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank are set by LightningModule in Lightning 1.2.0

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.maybe_update_config_version(cfg)

        self.src_language: str = cfg.get("src_language", None)
        self.tgt_language: str = cfg.get("tgt_language", None)

        # Instantiate tokenizers and register them to be saved with the NeMo Model archive.
        # After this call, there will be self.encoder_tokenizer and self.decoder_tokenizer,
        # which can convert between tokens and token_ids for the SRC and TGT languages respectively.
        self.setup_enc_dec_tokenizers(
            encoder_tokenizer_library=cfg.encoder_tokenizer.get('library', 'yttm'),
            encoder_tokenizer_model=cfg.encoder_tokenizer.get('tokenizer_model'),
            encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0),
            encoder_model_name=cfg.encoder.get('model_name') if hasattr(cfg.encoder, 'model_name') else None,
            decoder_tokenizer_library=cfg.decoder_tokenizer.get('library', 'yttm'),
            decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model,
            decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0),
            decoder_model_name=cfg.decoder.get('model_name') if hasattr(cfg.decoder, 'model_name') else None,
        )

        # After this call, the model will have self.source_processor and self.target_processor objects
        self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language)

        # TODO: Why is this base constructor call so late in the game?
        super().__init__(cfg=cfg, trainer=trainer)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        self.encoder = get_transformer(
            library=library, model_name=model_name, pretrained=pretrained, config_dict=encoder_cfg_dict, encoder=True,
        )

        # decoder from NeMo, Megatron-LM, or HuggingFace
        decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder'))
        decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size
        library = decoder_cfg_dict.pop('library', 'nemo')
        model_name = decoder_cfg_dict.pop('model_name', None)
        pretrained = decoder_cfg_dict.pop('pretrained', False)
        decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size
        self.decoder = get_transformer(
            library=library, model_name=model_name, pretrained=pretrained, config_dict=decoder_cfg_dict, encoder=False,
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.decoder.hidden_size,
            num_classes=self.decoder_vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        self.beam_search = BeamSearchSequenceGenerator(
            embedding=self.decoder.embedding,
            decoder=self.decoder.decoder,
            log_softmax=self.log_softmax,
            max_sequence_length=self.decoder.max_sequence_length,
            beam_size=cfg.beam_size,
            bos=self.decoder_tokenizer.bos_id,
            pad=self.decoder_tokenizer.pad_id,
            eos=self.decoder_tokenizer.eos_id,
            len_pen=cfg.len_pen,
            max_delta_length=cfg.max_generation_delta,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight

        # TODO: encoder and decoder with different hidden size?
        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        # initialize weights if not using pretrained encoder/decoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(module, std_init_range))

        if not self._cfg.decoder.get('pretrained', False):
            self.decoder.apply(lambda module: transformer_weights_init(module, std_init_range))

        self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)

    def filter_predicted_ids(self, ids):
        ids[ids >= self.decoder_tokenizer.vocab_size] = self.decoder_tokenizer.unk_id
        return ids

    @typecheck()
    def forward(self, src, src_mask, tgt, tgt_mask):
        src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
        tgt_hiddens = self.decoder(
            input_ids=tgt, decoder_mask=tgt_mask, encoder_embeddings=src_hiddens, encoder_mask=src_mask
        )
        log_probs = self.log_softmax(hidden_states=tgt_hiddens)
        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1 added by DataLoader
                # is excess.
                batch[i] = batch[i].squeeze(dim=0)
        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)
        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)
        tensorboard_logs = {
            'train_loss': train_loss,
            'lr': self._optimizer.param_groups[0]['lr'],
        }
        return {'loss': train_loss, 'log': tensorboard_logs}

    def eval_step(self, batch, batch_idx, mode):
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1 added by DataLoader
                # is excess.
                batch[i] = batch[i].squeeze(dim=0)
        src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
        log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask)

        # this will run encoder twice -- TODO: potentially fix
        _, translations = self.batch_translate(src=src_ids, src_mask=src_mask)
        eval_loss = self.loss_fn(log_probs=log_probs, labels=labels)
        self.eval_loss(loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1])
        np_tgt = tgt_ids.detach().cpu().numpy()
        ground_truths = [self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt]
        ground_truths = [self.target_processor.detokenize(tgt.split(' ')) for tgt in ground_truths]
        num_non_pad_tokens = np.not_equal(np_tgt, self.decoder_tokenizer.pad_id).sum().item()
        return {
            'translations': translations,
            'ground_truths': ground_truths,
            'num_non_pad_tokens': num_non_pad_tokens,
        }

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, 'test')

    @rank_zero_only
    def log_param_stats(self):
        for name, p in self.named_parameters():
            if p.requires_grad:
                self.trainer.logger.experiment.add_histogram(name + '_hist', p, global_step=self.global_step)
                self.trainer.logger.experiment.add_scalars(
                    name,
                    {'mean': p.mean(), 'stddev': p.std(), 'max': p.max(), 'min': p.min()},
                    global_step=self.global_step,
                )

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        return self.eval_step(batch, batch_idx, 'val')

    def eval_epoch_end(self, outputs, mode):
        eval_loss = self.eval_loss.compute()
        translations = list(itertools.chain(*[x['translations'] for x in outputs]))
        ground_truths = list(itertools.chain(*[x['ground_truths'] for x in outputs]))

        assert len(translations) == len(ground_truths)
        if self.tgt_language in ['ja']:
            sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="ja-mecab")
        elif self.tgt_language in ['zh']:
            sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="zh")
        else:
            sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="13a")

        dataset_name = "Validation" if mode == 'val' else "Test"
        logging.info(f"\n\n\n\n{dataset_name} set size: {len(translations)}")
        logging.info(f"{dataset_name} Sacre BLEU = {sacre_bleu.score}")
        logging.info(f"{dataset_name} TRANSLATION EXAMPLES:".upper())
        for i in range(0, 3):
            ind = random.randint(0, len(translations) - 1)
            logging.info("    " + '\u0332'.join(f"EXAMPLE {i}:"))
            logging.info(f"    Prediction:   {translations[ind]}")
            logging.info(f"    Ground Truth: {ground_truths[ind]}")

        ans = {f"{mode}_loss": eval_loss, f"{mode}_sacreBLEU": sacre_bleu.score}
        ans['log'] = dict(ans)
        return ans

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        self.log_dict(self.eval_epoch_end(outputs, 'val'), sync_dist=True)

    def test_epoch_end(self, outputs):
        return self.eval_epoch_end(outputs, 'test')

    def setup_enc_dec_tokenizers(
        self,
        encoder_tokenizer_library=None,
        encoder_tokenizer_model=None,
        encoder_bpe_dropout=0.0,
        encoder_model_name=None,
        decoder_tokenizer_library=None,
        decoder_tokenizer_model=None,
        decoder_bpe_dropout=0.0,
        decoder_model_name=None,
    ):

        supported_tokenizers = ['yttm', 'huggingface', 'sentencepiece']
        if (
            encoder_tokenizer_library not in supported_tokenizers
            or decoder_tokenizer_library not in supported_tokenizers
        ):
            raise NotImplementedError(f"Currently we only support tokenizers in {supported_tokenizers}.")

        self.encoder_tokenizer = get_nmt_tokenizer(
            library=encoder_tokenizer_library,
            tokenizer_model=self.register_artifact("cfg.encoder_tokenizer.tokenizer_model", encoder_tokenizer_model),
            bpe_dropout=encoder_bpe_dropout,
            model_name=encoder_model_name,
            vocab_file=None,
            special_tokens=None,
            use_fast=False,
        )
        self.decoder_tokenizer = get_nmt_tokenizer(
            library=decoder_tokenizer_library,
            tokenizer_model=self.register_artifact("cfg.decoder_tokenizer.tokenizer_model", decoder_tokenizer_model),
            bpe_dropout=decoder_bpe_dropout,
            model_name=decoder_model_name,
            vocab_file=None,
            special_tokens=None,
            use_fast=False,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        if cfg.get("load_from_cached_dataset", False):
            logging.info('Loading from cached dataset %s' % (cfg.src_file_name))
            if cfg.src_file_name != cfg.tgt_file_name:
                raise ValueError("src must be equal to target for cached dataset")
            dataset = pickle.load(open(cfg.src_file_name, 'rb'))
            dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False)
        elif cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError("Trying to use tarred data set but could not find metadata path in config.")
            else:
                metadata_file = cfg.get('metadata_file')
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if cfg.get('tar_files') is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(f'Loading from tarred dataset {tar_files}')
                    else:
                        raise FileNotFoundError("Could not find tarred dataset in config or metadata.")
                else:
                    tar_files = cfg.get('tar_files')
                    if metadata.get('tar_files') is not None:
                        raise ValueError(
                            'Tar files specified in config and in metadata file. Tar files should only be specified once.'
                        )
            dataset = TarredTranslationDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                encoder_tokenizer=self.encoder_tokenizer,
                decoder_tokenizer=self.decoder_tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = TranslationDataset(
                dataset_src=str(Path(cfg.src_file_name).expanduser()),
                dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction", False),
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
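    # Illustrative only (not in the original source): a minimal non-tarred data
    # config covering the fields this method reads. File paths are placeholders.
    #
    #     train_ds:
    #       src_file_name: /data/train.src
    #       tgt_file_name: /data/train.tgt
    #       tokens_in_batch: 16000
    #       clean: true
    #       max_seq_length: 512
    #       shuffle: true
    #       num_workers: 2
    #       pin_memory: false
    #       drop_last: false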

    def setup_pre_and_post_processing_utils(self, source_lang, target_lang):
        """
        Creates source and target processor objects for input and output pre/post-processing.
        """
        self.source_processor, self.target_processor = None, None
        if (source_lang == 'en' and target_lang == 'ja') or (source_lang == 'ja' and target_lang == 'en'):
            self.source_processor = EnJaProcessor(source_lang)
            self.target_processor = EnJaProcessor(target_lang)
        else:
            if source_lang == 'zh':
                self.source_processor = ChineseProcessor()
            if target_lang == 'zh':
                self.target_processor = ChineseProcessor()
            if source_lang is not None and source_lang not in ['ja', 'zh']:
                self.source_processor = MosesProcessor(source_lang)
            if target_lang is not None and target_lang not in ['ja', 'zh']:
                self.target_processor = MosesProcessor(target_lang)
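    # Illustrative only (not in the original source): how the branches above map
    # language pairs to processors.
    #
    #     setup_pre_and_post_processing_utils('en', 'ja')  -> EnJaProcessor('en') / EnJaProcessor('ja')
    #     setup_pre_and_post_processing_utils('zh', 'en')  -> ChineseProcessor() / MosesProcessor('en')
    #     setup_pre_and_post_processing_utils('en', 'de')  -> MosesProcessor('en') / MosesProcessor('de')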

    @torch.no_grad()
    def batch_translate(
        self, src: torch.LongTensor, src_mask: torch.LongTensor,
    ):
        """	
        Translates a minibatch of inputs from source language to target language.	
        Args:	
            src: minibatch of inputs in the src language (batch x seq_len)	
            src_mask: mask tensor indicating elements to be ignored (batch x seq_len)	
        Returns:	
            translations: a list strings containing detokenized translations	
            inputs: a list of string containing detokenized inputs	
        """
        mode = self.training
        try:
            self.eval()

            src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask)
            beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask)
            beam_results = self.filter_predicted_ids(beam_results)

            translations = [self.decoder_tokenizer.ids_to_text(tr) for tr in beam_results.cpu().numpy()]
            inputs = [self.encoder_tokenizer.ids_to_text(inp) for inp in src.cpu().numpy()]
            if self.target_processor is not None:
                translations = [
                    self.target_processor.detokenize(translation.split(' ')) for translation in translations
                ]

            if self.source_processor is not None:
                inputs = [self.source_processor.detokenize(item.split(' ')) for item in inputs]

        finally:
            self.train(mode=mode)
        return inputs, translations
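    # Illustrative only (not in the original source): `batch_translate` expects
    # already tokenized, padded id tensors; `padded_id_matrix` and `pad_id` below
    # are assumed to have been built the same way `translate` (further down) builds them.
    #
    #     src_ids = torch.LongTensor(padded_id_matrix).to(model.device)
    #     src_mask = torch.FloatTensor(padded_id_matrix != pad_id).to(model.device)
    #     inputs, translations = model.batch_translate(src_ids, src_mask)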

    # TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language
    @torch.no_grad()
    def translate(self, text: List[str], source_lang: str = None, target_lang: str = None) -> List[str]:
        """
        Translates list of sentences from source language to target language.
        Should be regular text, this method performs its own tokenization/de-tokenization
        Args:
            text: list of strings to translate
            source_lang: if not None, corresponding MosesTokenizer and MosesPunctNormalizer will be run
            target_lang: if not None, corresponding MosesDecokenizer will be run
        Returns:
            list of translated strings
        """
        # __TODO__: This will reset both source and target processors even if you want to reset just one.
        if source_lang is not None or target_lang is not None:
            self.setup_pre_and_post_processing_utils(source_lang, target_lang)

        mode = self.training
        try:
            self.eval()
            inputs = []
            for txt in text:
                if self.source_processor is not None:
                    txt = self.source_processor.normalize(txt)
                    txt = self.source_processor.tokenize(txt)
                ids = self.encoder_tokenizer.text_to_ids(txt)
                ids = [self.encoder_tokenizer.bos_id] + ids + [self.encoder_tokenizer.eos_id]
                inputs.append(ids)
            max_len = max(len(txt) for txt in inputs)
            src_ids_ = np.ones((len(inputs), max_len)) * self.encoder_tokenizer.pad_id
            for i, txt in enumerate(inputs):
                src_ids_[i][: len(txt)] = txt

            src_mask = torch.FloatTensor((src_ids_ != self.encoder_tokenizer.pad_id)).to(self.device)
            src = torch.LongTensor(src_ids_).to(self.device)
            _, translations = self.batch_translate(src, src_mask)
        finally:
            self.train(mode=mode)
        return translations
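    # Illustrative only (not in the original source): a minimal usage sketch for the
    # high-level API, assuming `model` is a restored instance of this class.
    #
    #     translations = model.translate(
    #         ["Machine translation is fun."],
    #         source_lang='en',
    #         target_lang='de',
    #     )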

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_de_transformer12x2",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_de_transformer12x2/versions/1.0.0rc1/files/nmt_en_de_transformer12x2.nemo",
            description="En->De translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_de_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_de_en_transformer12x2",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_de_en_transformer12x2/versions/1.0.0rc1/files/nmt_de_en_transformer12x2.nemo",
            description="De->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_de_en_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_es_transformer12x2",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_es_transformer12x2/versions/1.0.0rc1/files/nmt_en_es_transformer12x2.nemo",
            description="En->Es translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_es_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_es_en_transformer12x2",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_es_en_transformer12x2/versions/1.0.0rc1/files/nmt_es_en_transformer12x2.nemo",
            description="Es->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_es_en_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_fr_transformer12x2",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_fr_transformer12x2/versions/1.0.0rc1/files/nmt_en_fr_transformer12x2.nemo",
            description="En->Fr translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_fr_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_fr_en_transformer12x2",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_fr_en_transformer12x2/versions/1.0.0rc1/files/nmt_fr_en_transformer12x2.nemo",
            description="Fr->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_fr_en_transformer12x2",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_ru_transformer6x6",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_ru_transformer6x6/versions/1.0.0rc1/files/nmt_en_ru_transformer6x6.nemo",
            description="En->Ru translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_ru_transformer6x6",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_ru_en_transformer6x6",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_ru_en_transformer6x6/versions/1.0.0rc1/files/nmt_ru_en_transformer6x6.nemo",
            description="Ru->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_ru_en_transformer6x6",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_zh_en_transformer6x6",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_zh_en_transformer6x6/versions/1.0.0rc1/files/nmt_zh_en_transformer6x6.nemo",
            description="Zh->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_zh_en_transformer6x6",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="nmt_en_zh_transformer6x6",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_zh_transformer6x6/versions/1.0.0rc1/files/nmt_en_zh_transformer6x6.nemo",
            description="En->Zh translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_zh_transformer6x6",
        )
        result.append(model)

        return result
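    # Illustrative only (not in the original source): the names returned above can be
    # passed to `from_pretrained`, assuming this class is NeMo's MTEncDecModel and is
    # importable as shown below.
    #
    #     from nemo.collections.nlp.models import MTEncDecModel
    #     model = MTEncDecModel.from_pretrained("nmt_en_de_transformer12x2")
    #     print(model.translate(["Hello!"], source_lang='en', target_lang='de'))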