Example #1
class QAModel(NLPModel, Exportable):
    """
    BERT encoder with a question answering (QA) head for extractive QA on SQuAD-style data.
    """
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.token_classifier.num_classes,
            num_layers=cfg.token_classifier.num_layers,
            activation=cfg.token_classifier.activation,
            log_softmax=cfg.token_classifier.log_softmax,
            dropout=cfg.token_classifier.dropout,
            use_transformer_init=cfg.token_classifier.use_transformer_init,
        )

        self.loss = SpanningLoss()

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        logits = self.classifier(hidden_states=hidden_states)
        return logits

    def training_step(self, batch, batch_idx):
        input_ids, input_type_ids, input_mask, unique_ids, start_positions, end_positions = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)
        loss, _, _ = self.loss(logits=logits,
                               start_positions=start_positions,
                               end_positions=end_positions)

        lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)
        return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx):
        if self.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        input_ids, input_type_ids, input_mask, unique_ids, start_positions, end_positions = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)
        loss, start_logits, end_logits = self.loss(
            logits=logits,
            start_positions=start_positions,
            end_positions=end_positions)

        tensors = {
            'unique_ids': unique_ids,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
        self.log(f'{prefix}_loss', loss)
        return {f'{prefix}_loss': loss, f'{prefix}_tensors': tensors}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        if self.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        avg_loss = torch.stack([x[f'{prefix}_loss'] for x in outputs]).mean()

        unique_ids = torch.cat(
            [x[f'{prefix}_tensors']['unique_ids'] for x in outputs])
        start_logits = torch.cat(
            [x[f'{prefix}_tensors']['start_logits'] for x in outputs])
        end_logits = torch.cat(
            [x[f'{prefix}_tensors']['end_logits'] for x in outputs])

        all_unique_ids = []
        all_start_logits = []
        all_end_logits = []
        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            for ind in range(world_size):
                all_unique_ids.append(torch.empty_like(unique_ids))
                all_start_logits.append(torch.empty_like(start_logits))
                all_end_logits.append(torch.empty_like(end_logits))
            torch.distributed.all_gather(all_unique_ids, unique_ids)
            torch.distributed.all_gather(all_start_logits, start_logits)
            torch.distributed.all_gather(all_end_logits, end_logits)
        else:
            all_unique_ids.append(unique_ids)
            all_start_logits.append(start_logits)
            all_end_logits.append(end_logits)

        exact_match, f1, all_predictions, all_nbest = -1, -1, [], []
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:

            unique_ids = []
            start_logits = []
            end_logits = []
            for u in all_unique_ids:
                unique_ids.extend(tensor2list(u))
            for u in all_start_logits:
                start_logits.extend(tensor2list(u))
            for u in all_end_logits:
                end_logits.extend(tensor2list(u))

            eval_dataset = self._test_dl.dataset if self.testing else self._validation_dl.dataset
            exact_match, f1, all_predictions, all_nbest = eval_dataset.evaluate(
                unique_ids=unique_ids,
                start_logits=start_logits,
                end_logits=end_logits,
                n_best_size=self._cfg.dataset.n_best_size,
                max_answer_length=self._cfg.dataset.max_answer_length,
                version_2_with_negative=self._cfg.dataset.version_2_with_negative,
                null_score_diff_threshold=self._cfg.dataset.null_score_diff_threshold,
                do_lower_case=self._cfg.dataset.do_lower_case,
            )

        logging.info(f"{prefix} exact match {exact_match}")
        logging.info(f"{prefix} f1 {f1}")

        self.log(f'{prefix}_loss', avg_loss)
        self.log(f'{prefix}_exact_match', exact_match)
        self.log(f'{prefix}_f1', f1)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)

    @torch.no_grad()
    def inference(
        self,
        file: str,
        batch_size: int = 1,
        num_samples: int = -1,
        output_nbest_file: Optional[str] = None,
        output_prediction_file: Optional[str] = None,
    ):
        """
        Get prediction for unlabeled inference data
        Args:
            file: inference data
            batch_size: batch size to use during inference
            num_samples: number of samples to use from the inference data. Default: -1, which uses all data.
            output_nbest_file: optional output file for writing out nbest list
            output_prediction_file: optional output file for writing out predictions
        Returns:
            all_predictions: model predictions
            all_nbest: model nbest list
        """
        # store predictions for all queries in a single list
        all_predictions = []
        all_nbest = []
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            logging_level = logging.get_verbosity()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {
                "batch_size": batch_size,
                "file": file,
                "shuffle": False,
                "num_samples": num_samples,
                'num_workers': 2,
                'pin_memory': False,
                'drop_last': False,
            }
            dataloader_cfg = OmegaConf.create(dataloader_cfg)
            infer_datalayer = self._setup_dataloader_from_config(
                cfg=dataloader_cfg, mode=INFERENCE_MODE)

            all_logits = []
            all_unique_ids = []
            for i, batch in enumerate(infer_datalayer):
                input_ids, token_type_ids, attention_mask, unique_ids = batch
                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=token_type_ids.to(device),
                    attention_mask=attention_mask.to(device),
                )
                all_logits.append(logits)
                all_unique_ids.append(unique_ids)
            logits = torch.cat(all_logits)
            unique_ids = tensor2list(torch.cat(all_unique_ids))
            s, e = logits.split(dim=-1, split_size=1)
            start_logits = tensor2list(s.squeeze())
            end_logits = tensor2list(e.squeeze())
            all_predictions, all_nbest, scores_diff = infer_datalayer.dataset.get_predictions(
                unique_ids=unique_ids,
                start_logits=start_logits,
                end_logits=end_logits,
                n_best_size=self._cfg.dataset.n_best_size,
                max_answer_length=self._cfg.dataset.max_answer_length,
                version_2_with_negative=self._cfg.dataset.version_2_with_negative,
                null_score_diff_threshold=self._cfg.dataset.null_score_diff_threshold,
                do_lower_case=self._cfg.dataset.do_lower_case,
            )

            with open(file, 'r') as test_file_fp:
                test_data = json.load(test_file_fp)["data"]
                id_to_question_mapping = {}
                for title in test_data:
                    for par in title["paragraphs"]:
                        for question in par["qas"]:
                            id_to_question_mapping[question["id"]] = question["question"]

            for question_id in all_predictions:
                all_predictions[question_id] = (
                    id_to_question_mapping[question_id],
                    all_predictions[question_id])

            if output_nbest_file is not None:
                with open(output_nbest_file, "w") as writer:
                    writer.write(json.dumps(all_nbest, indent=4) + "\n")
            if output_prediction_file is not None:
                with open(output_prediction_file, "w") as writer:
                    writer.write(json.dumps(all_predictions, indent=4) + "\n")

        finally:
            # set mode back to its original value
            self.train(mode=mode)
            logging.set_verbosity(logging_level)

        return all_predictions, all_nbest

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config or not train_data_config.file:
            logging.info(
                "Dataloader config or file_path for the train set is missing, so no data loader for train is created!"
            )
            self._train_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config, mode=TRAINING_MODE)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config or not val_data_config.file:
            logging.info(
                "Dataloader config or file_path for the validation set is missing, so no data loader for validation is created!"
            )
            self._validation_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config, mode=EVALUATION_MODE)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        if not test_data_config or test_data_config.file is None:
            logging.info(
                "Dataloader config or file_path for the test set is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config, mode=EVALUATION_MODE)

    def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str):
        dataset = SquadDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.file,
            doc_stride=self._cfg.dataset.doc_stride,
            max_query_length=self._cfg.dataset.max_query_length,
            max_seq_length=self._cfg.dataset.max_seq_length,
            version_2_with_negative=self._cfg.dataset.version_2_with_negative,
            num_samples=cfg.num_samples,
            mode=mode,
            use_cache=self._cfg.dataset.use_cache,
        )

        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.drop_last,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
        return dl

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="BERTBaseUncasedSQuADv1.1",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTBaseUncasedSQuADv1.1.nemo",
            description="Question answering model finetuned from NeMo BERT Base Uncased on SQuAD v1.1 dataset which obtains an exact match (EM) score of 82.43% and an F1 score of 89.59%.",
        )
        result.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="BERTBaseUncasedSQuADv2.0",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTBaseUncasedSQuADv2.0.nemo",
            description="Question answering model finetuned from NeMo BERT Base Uncased on SQuAD v2.0 dataset which obtains an exact match (EM) score of 73.35% and an F1 score of 76.44%.",
        )
        result.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="BERTLargeUncasedSQuADv1.1",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTLargeUncasedSQuADv1.1.nemo",
            description="Question answering model finetuned from NeMo BERT Large Uncased on SQuAD v1.1 dataset which obtains an exact match (EM) score of 85.47% and an F1 score of 92.10%.",
        )
        result.append(model)
        model = PretrainedModelInfo(
            pretrained_model_name="BERTLargeUncasedSQuADv2.0",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/BERTLargeUncasedSQuADv2.0.nemo",
            description="Question answering model finetuned from NeMo BERT Large Uncased on SQuAD v2.0 dataset which obtains an exact match (EM) score of 78.8% and an F1 score of 81.85%.",
        )
        result.append(model)
        return result

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " QAModel consists of two separate models with different"
                " inputs and outputs.")

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output),
                               'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output),
                               'classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx,
                                           "QA")
        output_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(output_model, output)
        return ([output, output1, output2], [output_descr, output1_descr, output2_descr])
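
A minimal usage sketch for QAModel (the import path and checkpoint name follow NeMo conventions and list_available_models above; treat both as assumptions rather than a verified recipe):

# Restore a pretrained QA model and run inference on a SQuAD-style JSON file.
from nemo.collections.nlp.models import QAModel

model = QAModel.from_pretrained("BERTBaseUncasedSQuADv1.1")
all_predictions, all_nbest = model.inference('dev-v1.1.json', batch_size=8)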
Example #2
class PunctuationCapitalizationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "punct_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
            "capit_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Punctuation and Capitalization model.
        """
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=cfg.punct_head.activation,
            log_softmax=False,
            dropout=cfg.punct_head.fc_dropout,
            num_layers=cfg.punct_head.punct_num_fc_layers,
            use_transformer_init=cfg.punct_head.use_transformer_init,
        )

        self.capit_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=cfg.capit_head.activation,
            log_softmax=False,
            dropout=cfg.capit_head.fc_dropout,
            num_layers=cfg.capit_head.capit_num_fc_layers,
            use_transformer_init=cfg.capit_head.use_transformer_init,
        )

        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        punct_logits = self.punct_classifier(hidden_states=hidden_states)
        capit_logits = self.capit_classifier(hidden_states=hidden_states)
        return punct_logits, capit_logits

    def _make_step(self, batch):
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, punct_labels, capit_labels = batch
        punct_logits, capit_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        punct_loss = self.loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
        capit_loss = self.loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
        loss = self.agg_loss(loss_1=punct_loss, loss_2=capit_loss)
        return loss, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        loss, _, _ = self._make_step(batch)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('lr', lr, prog_bar=True)
        self.log('train_loss', loss)

        return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        val_loss, punct_logits, capit_logits = self._make_step(batch)

        subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        capit_preds = torch.argmax(capit_logits, axis=-1)[subtokens_mask]
        capit_labels = capit_labels[subtokens_mask]
        self.capit_class_report.update(capit_preds, capit_labels)

        return {
            'val_loss': val_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'capit_tp': self.capit_class_report.tp,
            'capit_fn': self.capit_class_report.fn,
            'capit_fp': self.capit_class_report.fp,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        test_loss, punct_logits, capit_logits = self._make_step(batch)

        subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        capit_preds = torch.argmax(capit_logits, axis=-1)[subtokens_mask]
        capit_labels = capit_labels[subtokens_mask]
        self.capit_class_report.update(capit_preds, capit_labels)

        return {
            'test_loss': test_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'capit_tp': self.capit_class_report.tp,
            'capit_fn': self.capit_class_report.fn,
            'capit_fp': self.capit_class_report.fp,
        }

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
            Called at the end of test to aggregate outputs.
            outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Update data directory

        Args:
            data_dir: path to data directory
        """
        if os.path.exists(data_dir):
            logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
            self._cfg.dataset.data_dir = data_dir
        else:
            raise ValueError(f'{data_dir} not found')

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        """Setup training data"""
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        # for compatibility with older (pre-1.0.0b3) configs
        if not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None:
            OmegaConf.set_struct(self._cfg, False)
            self._cfg.class_labels = {}
            self._cfg.class_labels = OmegaConf.create(
                {'punct_labels_file': 'punct_label_ids.csv', 'capit_labels_file': 'capit_label_ids.csv'}
            )

        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            self.register_artifact(
                self._cfg.class_labels.punct_labels_file, self._train_dl.dataset.punct_label_ids_file
            )
            self.register_artifact(
                self._cfg.class_labels.capit_labels_file, self._train_dl.dataset.capit_label_ids_file
            )

            # save label maps to the config
            self._cfg.punct_label_ids = OmegaConf.create(self._train_dl.dataset.punct_label_ids)
            self._cfg.capit_label_ids = OmegaConf.create(self._train_dl.dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Setup validation data

        val_data_config: validation data config
        """
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds

        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        if test_data_config is None:
            test_data_config = self._cfg.test_ds
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        # use data_dir specified in the ds_item to run evaluation on multiple datasets
        if 'ds_item' in cfg and cfg.ds_item is not None:
            data_dir = cfg.ds_item
        else:
            data_dir = self._cfg.dataset.data_dir

        text_file = os.path.join(data_dir, cfg.text_file)
        label_file = os.path.join(data_dir, cfg.labels_file)

        dataset = BertPunctuationCapitalizationDataset(
            tokenizer=self.tokenizer,
            text_file=text_file,
            label_file=label_file,
            pad_label=self._cfg.dataset.pad_label,
            punct_label_ids=self._cfg.punct_label_ids,
            capit_label_ids=self._cfg.capit_label_ids,
            max_seq_length=self._cfg.dataset.max_seq_length,
            ignore_extra_tokens=self._cfg.dataset.ignore_extra_tokens,
            ignore_start_end=self._cfg.dataset.ignore_start_end,
            use_cache=self._cfg.dataset.use_cache,
            num_samples=cfg.num_samples,
            punct_label_ids_file=self._cfg.class_labels.punct_labels_file
            if 'class_labels' in self._cfg
            else 'punct_label_ids.csv',
            capit_label_ids_file=self._cfg.class_labels.capit_labels_file
            if 'class_labels' in self._cfg
            else 'capit_label_ids.csv',
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=self._cfg.dataset.drop_last,
        )

    def _setup_infer_dataloader(self, queries: List[str], batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Setup function for an inference data loader.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader.
        """

        dataset = BertPunctuationCapitalizationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=self._cfg.dataset.max_seq_length
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=False,
        )

    def add_punctuation_capitalization(self, queries: List[str], batch_size: int = None) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for debugging and prototyping.
        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference
        Returns:
            result: text with added capitalization and punctuation
        """
        if queries is None or len(queries) == 0:
            return []
        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')

        # We will store the output here
        result = []

        # Model's mode and device
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            self = self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            # store predictions for all queries in a single list
            all_punct_preds = []
            all_capit_preds = []

            for batch in infer_datalayer:
                input_ids, input_type_ids, input_mask, subtokens_mask = batch

                punct_logits, capit_logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                subtokens_mask = subtokens_mask > 0.5
                punct_preds = tensor2list(torch.argmax(punct_logits, axis=-1)[subtokens_mask])
                capit_preds = tensor2list(torch.argmax(capit_logits, axis=-1)[subtokens_mask])
                all_punct_preds.extend(punct_preds)
                all_capit_preds.extend(capit_preds)

            queries = [q.strip().split() for q in queries]
            queries_len = [len(q) for q in queries]

            if sum(queries_len) != len(all_punct_preds) or sum(queries_len) != len(all_capit_preds):
                raise ValueError('Predictions and words must have the same length')

            punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
            capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}

            start_idx = 0
            end_idx = 0
            for query in queries:
                end_idx += len(query)

                # extract predictions for the current query from the list of all predictions
                punct_preds = all_punct_preds[start_idx:end_idx]
                capit_preds = all_capit_preds[start_idx:end_idx]
                start_idx = end_idx

                query_with_punct_and_capit = ''
                for j, word in enumerate(query):
                    punct_label = punct_ids_to_labels[punct_preds[j]]
                    capit_label = capit_ids_to_labels[capit_preds[j]]

                    if capit_label != self._cfg.dataset.pad_label:
                        word = word.capitalize()
                    query_with_punct_and_capit += word
                    if punct_label != self._cfg.dataset.pad_label:
                        query_with_punct_and_capit += punct_label
                    query_with_punct_and_capit += ' '

                result.append(query_with_punct_and_capit.strip())
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="Punctuation_Capitalization_with_BERT",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_BERT.nemo",
                description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
            )
        )
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="Punctuation_Capitalization_with_DistilBERT",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_DistilBERT.nemo",
                description="The model was trained with DistilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
            )
        )
        return result

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Unlike other models' export() this one creates 5 output files, not 3:
        punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
        capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)
        bert_<output> - common BERT neural net
        punct_classifier_<output> - Punctuation Classifier neural net
        capit_classifier_<output> - Capitalization Classifier neural net
        """
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " PunctuationCapitalizationModel consists of three separate models with different"
                " inputs and outputs."
            )

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output), 'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output), 'punct_classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Punctuation Classifier exported to ONNX'
        punct_classifier_onnx = self.punct_classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output3 = os.path.join(os.path.dirname(output), 'capit_classifier_' + os.path.basename(output))
        output3_descr = qual_name + ' Capitalization Classifier exported to ONNX'
        capit_classifier_onnx = self.capit_classifier.export(
            output3,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        punct_output_model = attach_onnx_to_onnx(bert_model_onnx, punct_classifier_onnx, "PTCL")
        output4 = os.path.join(os.path.dirname(output), 'punct_' + os.path.basename(output))
        output4_descr = qual_name + ' Punctuation BERT+Classifier exported to ONNX'
        onnx.save(punct_output_model, output4)
        capit_output_model = attach_onnx_to_onnx(bert_model_onnx, capit_classifier_onnx, "CPCL")
        output5 = os.path.join(os.path.dirname(output), 'capit_' + os.path.basename(output))
        output5_descr = qual_name + ' Capitalization BERT+Classifier exported to ONNX'
        onnx.save(capit_output_model, output5)
        return (
            [output1, output2, output3, output4, output5],
            [output1_descr, output2_descr, output3_descr, output4_descr, output5_descr],
        )
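
A minimal usage sketch for this model (the import path is assumed from NeMo conventions; the checkpoint name comes from list_available_models above):

# Restore a pretrained model and add punctuation and capitalization to raw text.
from nemo.collections.nlp.models import PunctuationCapitalizationModel

model = PunctuationCapitalizationModel.from_pretrained("Punctuation_Capitalization_with_BERT")
print(model.add_punctuation_capitalization(['how are you', 'great see you tomorrow']))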
Example #3
class TokenClassificationModel(NLPModel, Exportable):
    """Token Classification Model with BERT, applicable for tasks such as Named Entity Recognition"""
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""
        # extract str to int labels mapping if a mapping file provided
        if isinstance(cfg.label_ids, str):
            if os.path.exists(cfg.label_ids):
                logging.info(
                    f'Reusing label_ids file found at {cfg.label_ids}.')
                label_ids = get_labels_to_labels_id_mapping(cfg.label_ids)
                # update the config to store name to id mapping
                cfg.label_ids = OmegaConf.create(label_ids)
            else:
                raise ValueError(f'{cfg.label_ids} not found.')

        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=False,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.class_weights = None
        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Update the data directory. Class weights used by the loss are computed later, in setup_training_data.

        Args:
            data_dir: path to data directory
        """
        self._cfg.dataset.data_dir = data_dir
        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')

    def setup_loss(self, class_balancing: str = None):
        """Setup loss
           Setup or update loss.

        Args:
            class_balancing: whether to use class weights during training
        """
        if class_balancing == 'weighted_loss' and self.class_weights:
            # you may need to increase the number of epochs for convergence when using weighted_loss
            loss = CrossEntropyLoss(logits_ndim=3, weight=self.class_weights)
        else:
            loss = CrossEntropyLoss(logits_ndim=3)
        return loss

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        logits = self.classifier(hidden_states=hidden_states)
        return logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch
        logits = self(input_ids=input_ids,
                      token_type_ids=input_type_ids,
                      attention_mask=input_mask)
        loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch
        logits = self(input_ids=input_ids,
                      token_type_ids=input_type_ids,
                      attention_mask=input_mask)
        val_loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)

        subtokens_mask = subtokens_mask > 0.5

        preds = torch.argmax(logits, axis=-1)[subtokens_mask]
        labels = labels[subtokens_mask]
        tp, fn, fp, _ = self.classification_report(preds, labels)

        return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and classification report
        precision, recall, f1, report = self.classification_report.compute()

        logging.info(report)

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('precision', precision)
        self.log('f1', f1)
        self.log('recall', recall)

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        """
        return self.validation_epoch_end(outputs)

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            vocab_file=self.register_artifact(
                config_path='tokenizer.vocab_file', src=cfg.vocab_file),
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            tokenizer_model=self.register_artifact(
                config_path='tokenizer.tokenizer_model',
                src=cfg.tokenizer_model),
        )
        self.tokenizer = tokenizer

    def setup_training_data(self,
                            train_data_config: Optional[DictConfig] = None):
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   train_data_config.labels_file)
        label_ids, label_ids_filename, self.class_weights = get_label_ids(
            label_file=labels_file,
            is_training=True,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=self._cfg.dataset.class_balancing == 'weighted_loss',
        )
        # save label maps to the config
        self._cfg.label_ids = OmegaConf.create(label_ids)
        self.register_artifact('label_ids.csv', label_ids_filename)
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self,
                              val_data_config: Optional[DictConfig] = None):
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   val_data_config.labels_file)
        get_label_ids(
            label_file=labels_file,
            is_training=False,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig] = None):
        if test_data_config is None:
            test_data_config = self._cfg.test_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   test_data_config.labels_file)
        get_label_ids(
            label_file=labels_file,
            is_training=False,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader:
        """
        Setup dataloader from config
        Args:
            cfg: config for the dataloader
        Return:
            Pytorch Dataloader
        """
        dataset_cfg = self._cfg.dataset
        data_dir = dataset_cfg.data_dir

        if not os.path.exists(data_dir):
            raise FileNotFoundError(
                f"Data directory is not found at: {data_dir}.")

        text_file = os.path.join(data_dir, cfg.text_file)
        labels_file = os.path.join(data_dir, cfg.labels_file)

        if not (os.path.exists(text_file) and os.path.exists(labels_file)):
            raise FileNotFoundError(
                f'{text_file} or {labels_file} not found. The data should be split into 2 files: text.txt and '
                'labels.txt. Each line of the text.txt file contains text sequences, where words are separated with '
                'spaces. The labels.txt file contains corresponding labels for each word in text.txt; the labels are '
                'separated with spaces. Each line of the files should follow the format: '
                '[WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and '
                '[LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).')
        dataset = BertTokenClassificationDataset(
            text_file=text_file,
            label_file=labels_file,
            max_seq_length=dataset_cfg.max_seq_length,
            tokenizer=self.tokenizer,
            num_samples=cfg.num_samples,
            pad_label=dataset_cfg.pad_label,
            label_ids=self._cfg.label_ids,
            ignore_extra_tokens=dataset_cfg.ignore_extra_tokens,
            ignore_start_end=dataset_cfg.ignore_start_end,
            use_cache=dataset_cfg.use_cache,
        )
        return DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=dataset_cfg.drop_last,
        )

    def _setup_infer_dataloader(
            self, queries: List[str],
            batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Setup function for an inference data loader.

        Args:
            queries: text
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader.
        """

        dataset = BertTokenClassificationInferDataset(tokenizer=self.tokenizer,
                                                      queries=queries,
                                                      max_seq_length=-1)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=False,
        )

    @torch.no_grad()
    def _infer(self, queries: List[str], batch_size: int = None) -> List[int]:
        """
        Get prediction for the queries
        Args:
            queries: text sequences
            batch_size: batch size to use during inference.
        Returns:
            all_preds: model predictions
        """
        # store predictions for all queries in a single list
        all_preds = []
        mode = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            for batch in infer_datalayer:
                input_ids, input_type_ids, input_mask, subtokens_mask = batch

                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                subtokens_mask = subtokens_mask > 0.5
                preds = tensor2list(
                    torch.argmax(logits, axis=-1)[subtokens_mask])
                all_preds.extend(preds)
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return all_preds

    def add_predictions(self,
                        queries: Union[List[str], str],
                        batch_size: int = 32) -> List[str]:
        """
        Add predicted token labels to the queries. Use this method for debugging and prototyping.
        Args:
            queries: text
            batch_size: batch size to use during inference.
        Returns:
            result: text with added entities
        """
        if queries is None or len(queries) == 0:
            return []

        result = []
        all_preds = self._infer(queries, batch_size)

        queries = [q.strip().split() for q in queries]
        num_words = [len(q) for q in queries]
        if sum(num_words) != len(all_preds):
            raise ValueError('Predictions and words must have the same length')

        ids_to_labels = {v: k for k, v in self._cfg.label_ids.items()}
        start_idx = 0
        end_idx = 0
        for query in queries:
            end_idx += len(query)

            # extract predictions for the current query from the list of all predictions
            preds = all_preds[start_idx:end_idx]
            start_idx = end_idx

            query_with_entities = ''
            for j, word in enumerate(query):
                # strip out the punctuation to attach the entity tag to the word not to a punctuation mark
                # that follows the word
                if word[-1].isalpha():
                    punct = ''
                else:
                    punct = word[-1]
                    word = word[:-1]

                query_with_entities += word
                label = ids_to_labels[preds[j]]

                if label != self._cfg.dataset.pad_label:
                    query_with_entities += '[' + label + ']'
                query_with_entities += punct + ' '
            result.append(query_with_entities.strip())
        return result

    def evaluate_from_file(
        self,
        output_dir: str,
        text_file: str,
        labels_file: Optional[str] = None,
        add_confusion_matrix: Optional[bool] = False,
        normalize_confusion_matrix: Optional[bool] = True,
        batch_size: int = 1,
    ) -> None:
        """
        Run inference on data from a file, plot confusion matrix and calculate classification report.
        Use this method for final evaluation.

        Args:
            output_dir: path to output directory to store model predictions, confusion matrix plot (if set to True)
            text_file: path to file with text. Each line of the text.txt file contains text sequences, where words
                are separated with spaces: [WORD] [SPACE] [WORD] [SPACE] [WORD]
            labels_file (Optional): path to file with labels. Each line of the labels_file should contain
                labels corresponding to each word in the text_file, the labels are separated with spaces:
                [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).
            add_confusion_matrix: whether to generate confusion matrix
            normalize_confusion_matrix: whether to normalize confusion matrix
            batch_size: batch size to use during inference.
        """
        output_dir = os.path.abspath(output_dir)

        with open(text_file, 'r') as f:
            queries = f.readlines()

        all_preds = self._infer(queries, batch_size)
        with_labels = labels_file is not None
        if with_labels:
            with open(labels_file, 'r') as f:
                all_labels_str = f.readlines()
                all_labels_str = ' '.join(
                    [labels.strip() for labels in all_labels_str])

        # write labels and predictions to a file in output_dir
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir,
                                'infer_' + os.path.basename(text_file))
        try:
            with open(filename, 'w') as f:
                if with_labels:
                    f.write('labels\t' + all_labels_str + '\n')
                    logging.info(f'Labels saved to {filename}')

                # convert labels from string label to ids
                ids_to_labels = {v: k for k, v in self._cfg.label_ids.items()}
                all_preds_str = [ids_to_labels[pred] for pred in all_preds]
                f.write('preds\t' + ' '.join(all_preds_str) + '\n')
                logging.info(f'Predictions saved to {filename}')

            if with_labels and add_confusion_matrix:
                all_labels = all_labels_str.split()
                # convert labels from string label to ids
                label_ids = self._cfg.label_ids
                all_labels = [label_ids[label] for label in all_labels]

                plot_confusion_matrix(all_labels,
                                      all_preds,
                                      output_dir,
                                      label_ids=label_ids,
                                      normalize=normalize_confusion_matrix)
                logging.info(
                    get_classification_report(all_labels, all_preds,
                                              label_ids))
        except Exception:
            logging.error(
                f'When providing a file with labels, check that all labels in {labels_file} were '
                'seen during training.')
            raise

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="NERModel",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/NERModel.nemo",
            description="The model was trained on GMB (Groningen Meaning Bank) corpus for entity recognition and achieves 74.61 F1 Macro score.",
        )
        result.append(model)
        return result

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " TokenClassificationModel consists of two separate models with different"
                " inputs and outputs.")

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output),
                               'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output),
                               'classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx,
                                           "TKCL")
        output_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(output_model, output)
        return ([output, output1, output2], [output_descr, output1_descr, output2_descr])
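
A minimal usage sketch for this model (the import path is assumed from NeMo conventions; the checkpoint name comes from list_available_models above):

# Restore the pretrained NER model and tag entities in a query.
from nemo.collections.nlp.models import TokenClassificationModel

model = TokenClassificationModel.from_pretrained("NERModel")
print(model.add_predictions(['we visited paris last summer']))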