class IntentSlotClassificationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the model inputs, taken directly from the wrapped BERT encoder."""
        encoder = self.bert_model
        return encoder.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the model outputs, taken directly from the joint classifier head."""
        head = self.classifier
        return head.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Build the BERT joint intent/slot model: data descriptor, tokenizer, encoder, head, losses, metrics.

        Args:
            cfg: model configuration. This constructor reads ``data_dir``, ``language_model``,
                ``train_ds``, ``validation_ds``, ``tokenizer``, ``head``, ``class_balancing`` and
                ``intent_loss_weight`` from it.
            trainer: optional trainer, forwarded to the superclass constructor.
        """
        # NOTE(review): everything down to the tokenizer setup is assigned *before*
        # super().__init__(); presumably the superclass constructor relies on these attributes
        # (e.g. for dataloader setup) -- keep this ordering.
        self.data_dir = cfg.data_dir
        self.max_seq_length = cfg.language_model.max_seq_length

        split_prefixes = [cfg.train_ds.prefix, cfg.validation_ds.prefix]
        self.data_desc = IntentSlotDataDesc(data_dir=cfg.data_dir, modes=split_prefixes)

        self._setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        # BERT encoder; an inline config dict is only passed when one is present in the config
        lm_cfg = cfg.language_model
        if lm_cfg.config:
            lm_config_dict = OmegaConf.to_container(lm_cfg.config)
        else:
            lm_config_dict = None
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=lm_cfg.config_file,
            config_dict=lm_config_dict,
            checkpoint_file=lm_cfg.lm_checkpoint,
        )

        # joint head producing intent logits and per-token slot logits
        head_cfg = cfg.head
        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=self.data_desc.num_intents,
            num_slots=self.data_desc.num_slots,
            dropout=head_cfg.fc_dropout,
            num_layers=head_cfg.num_output_layers,
            log_softmax=False,
        )

        # losses; class weights from the data descriptor are used only for 'weighted_loss'
        # (you may need to increase the number of epochs for convergence in that mode)
        intent_loss_kwargs = {'logits_ndim': 2}
        slot_loss_kwargs = {'logits_ndim': 3}
        if cfg.class_balancing == 'weighted_loss':
            intent_loss_kwargs['weight'] = self.data_desc.intent_weights
            slot_loss_kwargs['weight'] = self.data_desc.slot_weights
        self.intent_loss = CrossEntropyLoss(**intent_loss_kwargs)
        self.slot_loss = CrossEntropyLoss(**slot_loss_kwargs)

        # the two aggregation weights sum to one: intent weight and its complement for slots
        intent_weight = cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(num_inputs=2, weights=[intent_weight, 1.0 - intent_weight])

        # metrics, tracked separately for intents and slots
        self.intent_classification_report = ClassificationReport(
            num_classes=self.data_desc.num_intents,
            label_ids=self.data_desc.intents_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=self.data_desc.num_slots,
            label_ids=self.data_desc.slots_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )

    def update_data_dir(self, data_dir: str) -> None:
        """
        Point the model at a different data directory.

        Only ``self.data_dir`` is changed here; nothing else is recomputed by this method.

        Args:
            data_dir: path to data directory
        """
        logging.info(f'Setting model.data_dir to {data_dir}.')
        self.data_dir = data_dir

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Encode the batch with the BERT model and run the joint classifier on the encoder output.

        Returns:
            An ``(intent_logits, slot_logits)`` tuple produced by the classifier head.
        """
        encoder_inputs = dict(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        encoded = self.bert_model(**encoder_inputs)
        intent_logits, slot_logits = self.classifier(hidden_states=encoded)
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Run one training step: forward pass, intent and slot losses, aggregated loss, logging.

        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            Dict with the aggregated ``loss`` and the current learning rate ``lr``.
        """
        # the subtokens mask is part of the batch layout but is not needed for training
        input_ids, input_type_ids, input_mask, loss_mask, _subtokens_mask, intent_labels, slot_labels = batch

        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # combined loss for intents and slots
        loss_intent = self.intent_loss(logits=intent_logits, labels=intent_labels)
        loss_slot = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        combined_loss = self.total_loss(loss_1=loss_intent, loss_2=loss_slot)

        # learning rate is reported from the optimizer's first parameter group
        first_param_group = self._optimizer.param_groups[0]
        current_lr = first_param_group['lr']

        self.log('train_loss', combined_loss)
        self.log('lr', current_lr, prog_bar=True)

        return {'loss': combined_loss, 'lr': current_lr}

    def validation_step(self, batch, batch_idx):
        """
        Run one validation step: forward pass, combined loss, and update of both classification reports.

        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Returns:
            Dict with ``val_loss`` and the current tp/fn/fp counters of the intent and slot reports.
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch

        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # combined loss for intents and slots
        loss_intent = self.intent_loss(logits=intent_logits, labels=intent_labels)
        loss_slot = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=loss_intent, loss_2=loss_slot)

        intent_report = self.intent_classification_report
        slot_report = self.slot_classification_report

        # intents: one prediction per example
        intent_preds = torch.argmax(intent_logits, dim=-1)
        intent_report.update(intent_preds, intent_labels)

        # slots: only positions selected by the subtokens mask are scored
        selected = subtokens_mask > 0.5
        slot_preds = torch.argmax(slot_logits, dim=-1)[selected]
        selected_slot_labels = slot_labels[selected]
        slot_report.update(slot_preds, selected_slot_labels)

        return {
            'val_loss': val_loss,
            'intent_tp': intent_report.tp,
            'intent_fn': intent_report.fn,
            'intent_fp': intent_report.fp,
            'slot_tp': slot_report.tp,
            'slot_fn': slot_report.fn,
            'slot_fp': slot_report.fp,
        }

    def validation_epoch_end(self, outputs):
        """
        Aggregate validation step outputs: average the loss, compute and log both classification reports.

        :param outputs: list of individual outputs of each validation step.
        :return: dict with the averaged ``val_loss`` and precision/recall/F1 for intents and slots.
        """
        step_losses = [step_output['val_loss'] for step_output in outputs]
        mean_loss = torch.stack(step_losses).mean()

        # metrics and textual report, separately for intents and slots
        intent_metric = self.intent_classification_report
        intent_precision, intent_recall, intent_f1, intent_report = intent_metric.compute()
        logging.info(f'Intent report: {intent_report}')

        slot_metric = self.slot_classification_report
        slot_precision, slot_recall, slot_f1, slot_report = slot_metric.compute()
        logging.info(f'Slot report: {slot_report}')

        epoch_metrics = {
            'val_loss': mean_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
        }
        # every returned value is also logged under the same name
        for metric_name, metric_value in epoch_metrics.items():
            self.log(metric_name, metric_value)

        return epoch_metrics

    def test_step(self, batch, batch_idx):
        """
        Evaluate one test batch; testing reuses the validation step unchanged.

        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        step_output = self.validation_step(batch, batch_idx)
        return step_output

    def test_epoch_end(self, outputs):
        """
        Aggregate test step outputs; testing reuses the validation aggregation unchanged.

        :param outputs: list of individual outputs of each test step.
        """
        epoch_output = self.validation_epoch_end(outputs)
        return epoch_output

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Build the training dataloader from ``train_data_config`` and store it in ``self._train_dl``."""
        train_loader = self._setup_dataloader_from_config(cfg=train_data_config)
        self._train_dl = train_loader

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Build the validation dataloader from ``val_data_config`` and store it in ``self._validation_dl``."""
        validation_loader = self._setup_dataloader_from_config(cfg=val_data_config)
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Build the test dataloader from ``test_data_config`` and store it in ``self._test_dl``."""
        test_loader = self._setup_dataloader_from_config(cfg=test_data_config)
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a dataloader for the dataset split described by ``cfg``.

        The split is read from two files inside ``self.data_dir``: ``<prefix>.tsv`` and
        ``<prefix>_slots.tsv``, where ``<prefix>`` is ``cfg.prefix``.

        Args:
            cfg: dataset config; ``prefix``, ``num_samples``, ``batch_size``, ``shuffle``,
                ``num_workers``, ``pin_memory`` and ``drop_last`` are read from it.

        Returns:
            A ``DataLoader`` over an ``IntentSlotClassificationDataset`` using the dataset's ``collate_fn``.

        Raises:
            FileNotFoundError: if either of the two data files does not exist.
        """
        input_file = f'{self.data_dir}/{cfg.prefix}.tsv'
        slot_file = f'{self.data_dir}/{cfg.prefix}_slots.tsv'

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            # adjacent literals are concatenated; a backslash continuation inside the string would
            # embed the source indentation into the message
            raise FileNotFoundError(
                f'{input_file} or {slot_file} not found. Please refer to the documentation for the right format '
                'of Intents and Slots files.'
            )

        dataset = IntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.data_desc.pad_label,
            ignore_extra_tokens=self._cfg.ignore_extra_tokens,
            ignore_start_end=self._cfg.ignore_start_end,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    def _setup_infer_dataloader(
            self, queries: List[str],
            batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Create the dataloader used for inference on raw text queries.

        Args:
            queries: text
            batch_size: batch size to use during inference
        Returns:
            A pytorch DataLoader that keeps the query order (no shuffling, no dropped batch);
            worker and pin-memory settings come from the model's ``test_ds`` config.
        """
        infer_dataset = IntentSlotInferenceDataset(
            tokenizer=self.tokenizer,
            queries=queries,
            max_seq_length=-1,
            do_lower_case=False,
        )

        test_cfg = self._cfg.test_ds
        loader_kwargs = dict(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=test_cfg.num_workers,
            pin_memory=test_cfg.pin_memory,
            drop_last=False,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def predict_from_examples(self,
                              queries: List[str],
                              batch_size: int = 32) -> List[List[str]]:
        """
        Get prediction for the queries (intent and slots)

        The model is switched to evaluation mode and moved to CUDA when it is available. The
        original train/eval mode is restored on exit, even if inference raises. Inference runs
        without gradient tracking.

        Args:
            queries: text sequences
            batch_size: batch size to use during inference
        Returns:
            predicted_intents, predicted_slots: model intent and slot predictions. Both are lists
            with one entry per query; each slot entry is a space-separated string of slot labels
            for the positions selected by the subtokens mask.
        """
        predicted_intents = []
        predicted_slots = []

        # nothing to predict: skip model/dataloader setup entirely
        if queries is None or len(queries) == 0:
            return predicted_intents, predicted_slots

        mode = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            # load intent and slot labels from the dictionary files (user should have them in a data directory)
            intent_labels, slot_labels = IntentSlotDataDesc.intent_slot_dicts(self.data_dir)
            num_intents = len(intent_labels)
            num_slots = len(slot_labels)

            # inference only: do not build an autograd graph
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask = batch

                    intent_logits, slot_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # intents: convert numerical outputs to labels from the dictionary
                    for intent_num in tensor2list(torch.argmax(intent_logits, axis=-1)):
                        if intent_num < num_intents:
                            predicted_intents.append(intent_labels[intent_num])
                        else:
                            # should not happen
                            predicted_intents.append("Unknown Intent")

                    # slots: convert predictions and mask to plain Python lists once per batch
                    # instead of comparing/indexing with 0-dim tensors element by element
                    slot_preds = torch.argmax(slot_logits, axis=-1).tolist()
                    mask_rows = torch.as_tensor(subtokens_mask).tolist()

                    for slot_preds_query, mask_query in zip(slot_preds, mask_rows):
                        query_slots = [
                            slot_labels[slot] if slot < num_slots else 'Unknown_slot'
                            for slot, mask in zip(slot_preds_query, mask_query)
                            if mask == 1
                        ]
                        predicted_slots.append(' '.join(query_slots).strip())

        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return predicted_intents, predicted_slots

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        List the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models (currently a single entry).
        """
        assistant_model = PretrainedModelInfo(
            pretrained_model_name="Joint_Intent_Slot_Assistant",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
            description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
        )
        return [assistant_model]

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Export the model to ONNX as three files: the BERT encoder, the classifier, and both attached together.

        The encoder and classifier are written next to ``output`` with the ``bert_`` and
        ``classifier_`` file-name prefixes; the attached model is saved to ``output`` itself.
        ``input_example`` and ``output_example`` are not used (a warning is logged if given).
        The remaining arguments are forwarded positionally to both sub-model ``export`` calls.

        Returns:
            A tuple ``(paths, descriptions)`` of two parallel lists ordered as
            combined model, BERT encoder, classifier.
        """
        if not (input_example is None and output_example is None):
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " IntentSlotClassificationModel consists of two separate models with different"
                " inputs and outputs.")

        out_dir, out_name = os.path.split(output)
        qual_name = f'{self.__module__}.{self.__class__.__qualname__}'

        # positional arguments shared by both sub-model exports; the two leading ``None`` values
        # stand for the input and output examples
        shared_args = (
            None,
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        bert_path = os.path.join(out_dir, 'bert_' + out_name)
        bert_descr = qual_name + ' BERT exported to ONNX'
        bert_onnx = self.bert_model.export(bert_path, *shared_args)

        classifier_path = os.path.join(out_dir, 'classifier_' + out_name)
        classifier_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(classifier_path, *shared_args)

        combined_model = attach_onnx_to_onnx(bert_onnx, classifier_onnx, "ISC")
        combined_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(combined_model, output)

        paths = [output, bert_path, classifier_path]
        descriptions = [combined_descr, bert_descr, classifier_descr]
        return (paths, descriptions)
Example #2
0
class PunctuationCapitalizationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the model inputs, taken directly from the wrapped BERT encoder."""
        encoder = self.bert_model
        return encoder.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the two outputs: punctuation and capitalization logits, both with axes ('B', 'T', 'C')."""
        logits_axes = ('B', 'T', 'C')
        output_names = ("punct_logits", "capit_logits")
        return {name: NeuralType(logits_axes, LogitsType()) for name in output_names}

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the BERT punctuation and capitalization model.

        Sets up the tokenizer, the BERT encoder, one token classifier per task, the losses and
        the per-task classification reports.

        Args:
            cfg: model configuration; ``tokenizer``, ``language_model``, ``punct_head`` and
                ``capit_head`` are read from it here.
            trainer: optional trainer, forwarded to the superclass constructor.
        """
        # the tokenizer is set up before the superclass constructor runs
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        # BERT encoder; an inline config dict is only passed when one is present in the config
        lm_cfg = cfg.language_model
        if lm_cfg.config:
            lm_config_dict = OmegaConf.to_container(lm_cfg.config)
        else:
            lm_config_dict = None
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=lm_cfg.config_file,
            config_dict=lm_config_dict,
            checkpoint_file=lm_cfg.lm_checkpoint,
        )

        # label maps are read from the model's stored config (self._cfg), not from the `cfg` argument
        punct_label_ids = self._cfg.punct_label_ids
        capit_label_ids = self._cfg.capit_label_ids
        hidden_size = self.bert_model.config.hidden_size

        punct_head = cfg.punct_head
        self.punct_classifier = TokenClassifier(
            hidden_size=hidden_size,
            num_classes=len(punct_label_ids),
            activation=punct_head.activation,
            log_softmax=False,
            dropout=punct_head.fc_dropout,
            num_layers=punct_head.punct_num_fc_layers,
            use_transformer_init=punct_head.use_transformer_init,
        )

        capit_head = cfg.capit_head
        self.capit_classifier = TokenClassifier(
            hidden_size=hidden_size,
            num_classes=len(capit_label_ids),
            activation=capit_head.activation,
            log_softmax=False,
            dropout=capit_head.fc_dropout,
            num_layers=capit_head.capit_num_fc_layers,
            use_transformer_init=capit_head.use_transformer_init,
        )

        # one shared token-level loss for both tasks, plus an aggregator that combines the two results
        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # metrics, tracked separately per task
        self.punct_class_report = ClassificationReport(
            num_classes=len(punct_label_ids),
            label_ids=punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(capit_label_ids),
            label_ids=capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Encode the batch with the BERT model and run both token classifiers on the encoder output.

        Returns:
            A ``(punct_logits, capit_logits)`` tuple.
        """
        encoder_inputs = dict(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        encoded = self.bert_model(**encoder_inputs)
        punct_logits = self.punct_classifier(hidden_states=encoded)
        capit_logits = self.capit_classifier(hidden_states=encoded)
        return punct_logits, capit_logits

    def _make_step(self, batch):
        """
        Shared forward-and-loss computation used by the training, validation and test steps.

        Returns:
            A ``(loss, punct_logits, capit_logits)`` tuple, where ``loss`` aggregates the
            punctuation and capitalization losses.
        """
        # the subtokens mask is part of the batch layout but is not needed for the loss
        input_ids, input_type_ids, input_mask, _subtokens_mask, loss_mask, punct_labels, capit_labels = batch

        punct_logits, capit_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        task_losses = {
            'loss_1': self.loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask),
            'loss_2': self.loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask),
        }
        total = self.agg_loss(**task_losses)
        return total, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
        """
        Run one training step and log the learning rate and the training loss.

        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            Dict with the aggregated ``loss`` and the current learning rate ``lr``.
        """
        step_loss, _punct_logits, _capit_logits = self._make_step(batch)

        # learning rate is reported from the optimizer's first parameter group
        first_param_group = self._optimizer.param_groups[0]
        current_lr = first_param_group['lr']

        self.log('lr', current_lr, prog_bar=True)
        self.log('train_loss', step_loss)

        return {'loss': step_loss, 'lr': current_lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Run one validation step: compute the loss and update both classification reports.

        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Returns:
            Dict with ``val_loss`` and the current tp/fn/fp counters of both reports.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        val_loss, punct_logits, capit_logits = self._make_step(batch)

        punct_report = self.punct_class_report
        capit_report = self.capit_class_report

        # only positions selected by the subtokens mask are scored
        selected = subtokens_mask > 0.5

        punct_preds = torch.argmax(punct_logits, dim=-1)[selected]
        punct_report.update(punct_preds, punct_labels[selected])

        capit_preds = torch.argmax(capit_logits, dim=-1)[selected]
        capit_report.update(capit_preds, capit_labels[selected])

        return {
            'val_loss': val_loss,
            'punct_tp': punct_report.tp,
            'punct_fn': punct_report.fn,
            'punct_fp': punct_report.fp,
            'capit_tp': capit_report.tp,
            'capit_fn': capit_report.fn,
            'capit_fp': capit_report.fp,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Run one test step: compute the loss and update both classification reports.

        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.

        Returns:
            Dict with ``test_loss`` and the current tp/fn/fp counters of both reports.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        test_loss, punct_logits, capit_logits = self._make_step(batch)

        punct_report = self.punct_class_report
        capit_report = self.capit_class_report

        # only positions selected by the subtokens mask are scored
        selected = subtokens_mask > 0.5

        punct_preds = torch.argmax(punct_logits, dim=-1)[selected]
        punct_report.update(punct_preds, punct_labels[selected])

        capit_preds = torch.argmax(capit_logits, dim=-1)[selected]
        capit_report.update(capit_preds, capit_labels[selected])

        return {
            'test_loss': test_loss,
            'punct_tp': punct_report.tp,
            'punct_fn': punct_report.fn,
            'punct_fp': punct_report.fp,
            'capit_tp': capit_report.tp,
            'capit_fn': capit_report.fn,
            'capit_fp': capit_report.fp,
        }

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Aggregate validation step outputs: average the loss, then compute and log both task reports.

        outputs: list of individual outputs of each validation step.
        """
        step_losses = [step_output['val_loss'] for step_output in outputs]
        mean_loss = torch.stack(step_losses).mean()

        # Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        # only the loss is shown on the progress bar
        self.log('val_loss', mean_loss, prog_bar=True)
        task_metrics = (
            ('punct_precision', punct_precision),
            ('punct_f1', punct_f1),
            ('punct_recall', punct_recall),
            ('capit_precision', capit_precision),
            ('capit_f1', capit_f1),
            ('capit_recall', capit_recall),
        )
        for metric_name, metric_value in task_metrics:
            self.log(metric_name, metric_value)

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Aggregate test step outputs: average the loss, then compute and log both task reports.

        outputs: list of individual outputs of each test step.
        """
        step_losses = [step_output['test_loss'] for step_output in outputs]
        mean_loss = torch.stack(step_losses).mean()

        # Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        # only the loss is shown on the progress bar
        self.log('test_loss', mean_loss, prog_bar=True)
        task_metrics = (
            ('punct_precision', punct_precision),
            ('punct_f1', punct_f1),
            ('punct_recall', punct_recall),
            ('capit_precision', capit_precision),
            ('capit_f1', capit_f1),
            ('capit_recall', capit_recall),
        )
        for metric_name, metric_value in task_metrics:
            self.log(metric_name, metric_value)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Store a new data directory in the model config (``dataset.data_dir``).

        Args:
            data_dir: path to data directory

        Raises:
            ValueError: if ``data_dir`` does not exist.
        """
        if not os.path.exists(data_dir):
            raise ValueError(f'{data_dir} not found')

        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
        self._cfg.dataset.data_dir = data_dir

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        """
        Build the training dataloader and record the label files and label maps.

        When ``train_data_config`` is None, ``self._cfg.train_ds`` is used. On a non-distributed
        run, or on rank 0 of a distributed run, the two label-id files of the training dataset
        are registered as artifacts and its label maps are written to the model config.
        """
        train_data_config = self._cfg.train_ds if train_data_config is None else train_data_config

        # for older(pre - 1.0.0.b3) configs compatibility: supply default label file names
        class_labels_missing = not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None
        if class_labels_missing:
            OmegaConf.set_struct(self._cfg, False)
            # empty placeholder, immediately replaced by the default file names below
            self._cfg.class_labels = {}
            default_label_files = {
                'punct_labels_file': 'punct_label_ids.csv',
                'capit_labels_file': 'capit_label_ids.csv',
            }
            self._cfg.class_labels = OmegaConf.create(default_label_files)

        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

        # everything below runs only without distributed training or on rank 0
        if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
            return

        train_dataset = self._train_dl.dataset
        self.register_artifact(self._cfg.class_labels.punct_labels_file, train_dataset.punct_label_ids_file)
        self.register_artifact(self._cfg.class_labels.capit_labels_file, train_dataset.capit_label_ids_file)

        # save label maps to the config
        self._cfg.punct_label_ids = OmegaConf.create(self._train_dl.dataset.punct_label_ids)
        self._cfg.capit_label_ids = OmegaConf.create(self._train_dl.dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Build the validation dataloader and store it in ``self._validation_dl``.

        val_data_config: validation data config; ``self._cfg.validation_ds`` is used when it is None
        """
        effective_cfg = self._cfg.validation_ds if val_data_config is None else val_data_config
        self._validation_dl = self._setup_dataloader_from_config(cfg=effective_cfg)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        """
        Build the test dataloader and store it in ``self._test_dl``.

        test_data_config: test data config; ``self._cfg.test_ds`` is used when it is None
        """
        effective_cfg = self._cfg.test_ds if test_data_config is None else test_data_config
        self._test_dl = self._setup_dataloader_from_config(cfg=effective_cfg)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a dataloader for the dataset split described by ``cfg``.

        Args:
            cfg: split config; ``text_file``, ``labels_file``, ``num_samples``, ``batch_size`` and
                ``shuffle`` are read from it, plus an optional ``ds_item`` that overrides the data
                directory. All other dataset and loader options come from ``self._cfg.dataset``.

        Returns:
            A pytorch DataLoader over a ``BertPunctuationCapitalizationDataset``.
        """
        dataset_cfg = self._cfg.dataset

        # use data_dir specified in the ds_item to run evaluation on multiple datasets
        has_ds_item = 'ds_item' in cfg and cfg.ds_item is not None
        data_dir = cfg.ds_item if has_ds_item else dataset_cfg.data_dir

        text_file = os.path.join(data_dir, cfg.text_file)
        label_file = os.path.join(data_dir, cfg.labels_file)

        # label-id file names: taken from the config when present, otherwise the defaults
        if 'class_labels' in self._cfg:
            punct_ids_file = self._cfg.class_labels.punct_labels_file
            capit_ids_file = self._cfg.class_labels.capit_labels_file
        else:
            punct_ids_file = 'punct_label_ids.csv'
            capit_ids_file = 'capit_label_ids.csv'

        dataset = BertPunctuationCapitalizationDataset(
            tokenizer=self.tokenizer,
            text_file=text_file,
            label_file=label_file,
            pad_label=dataset_cfg.pad_label,
            punct_label_ids=self._cfg.punct_label_ids,
            capit_label_ids=self._cfg.capit_label_ids,
            max_seq_length=dataset_cfg.max_seq_length,
            ignore_extra_tokens=dataset_cfg.ignore_extra_tokens,
            ignore_start_end=dataset_cfg.ignore_start_end,
            use_cache=dataset_cfg.use_cache,
            num_samples=cfg.num_samples,
            punct_label_ids_file=punct_ids_file,
            capit_label_ids_file=capit_ids_file,
        )

        loader_kwargs = dict(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=dataset_cfg.drop_last,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def _setup_infer_dataloader(self, queries: List[str], batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Create the dataloader used for inference on raw text queries.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader that keeps the query order (no shuffling, no dropped batch);
            sequence length, worker and pin-memory settings come from ``self._cfg.dataset``.
        """
        dataset_cfg = self._cfg.dataset

        infer_dataset = BertPunctuationCapitalizationInferDataset(
            tokenizer=self.tokenizer,
            queries=queries,
            max_seq_length=dataset_cfg.max_seq_length,
        )

        loader_kwargs = dict(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=False,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def add_punctuation_capitalization(self, queries: List[str], batch_size: int = None) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for debugging and prototyping.

        The model is switched to evaluation mode and moved to CUDA when available (CPU otherwise);
        the original train/eval mode is restored before returning, even if inference fails.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference; defaults to the number of queries
        Returns:
            result: text with added capitalization and punctuation, one string per query
        Raises:
            ValueError: if the number of predictions does not match the number of
                whitespace-separated words in the queries
        """
        if queries is None or len(queries) == 0:
            return []
        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')

        # We will store the output here
        result = []

        # Model's mode and device
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            # Module.to() moves the parameters in place, so `self` does not need to be rebound.
            self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            # store predictions for all queries in a single list
            all_punct_preds = []
            all_capit_preds = []

            # Only argmax predictions are needed, so skip building autograd graphs.
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, subtokens_mask = batch

                    punct_logits, capit_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # keep only the positions selected by the subtokens mask
                    subtokens_mask = subtokens_mask > 0.5
                    all_punct_preds.extend(tensor2list(torch.argmax(punct_logits, axis=-1)[subtokens_mask]))
                    all_capit_preds.extend(tensor2list(torch.argmax(capit_logits, axis=-1)[subtokens_mask]))

            tokenized_queries = [q.strip().split() for q in queries]
            total_words = sum(len(words) for words in tokenized_queries)

            if total_words != len(all_punct_preds) or total_words != len(all_capit_preds):
                raise ValueError('Pred and words must have the same length')

            punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
            capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}
            pad_label = self._cfg.dataset.pad_label

            start_idx = 0
            for words in tokenized_queries:
                end_idx = start_idx + len(words)

                # extract predictions for the current query from the list of all predictions
                punct_preds = all_punct_preds[start_idx:end_idx]
                capit_preds = all_capit_preds[start_idx:end_idx]
                start_idx = end_idx

                restored_words = []
                for word, punct_id, capit_id in zip(words, punct_preds, capit_preds):
                    punct_label = punct_ids_to_labels[punct_id]
                    capit_label = capit_ids_to_labels[capit_id]

                    # any label other than the pad label means "capitalize" / "append this mark"
                    if capit_label != pad_label:
                        word = word.capitalize()
                    if punct_label != pad_label:
                        word += punct_label
                    restored_words.append(word)

                result.append(' '.join(restored_words).strip())
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        Lists the pre-trained checkpoints of this model that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            A list with one PretrainedModelInfo entry per available pre-trained model.
        """
        # NOTE(review): the return annotation says Dict[str, str], but a list of
        # PretrainedModelInfo objects is what is actually returned.
        bert_info = PretrainedModelInfo(
            pretrained_model_name="Punctuation_Capitalization_with_BERT",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_BERT.nemo",
            description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
        )
        distilbert_info = PretrainedModelInfo(
            pretrained_model_name="Punctuation_Capitalization_with_DistilBERT",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_DistilBERT.nemo",
            description="The model was trained with DiltilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
        )
        return [bert_info, distilbert_info]

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Exports the model to ONNX. Unlike other models' export() this one creates 5 output files, not 3,
        all written into the directory of `output` with a prefix added to its file name:
        bert_<output> - common BERT neural net
        punct_classifier_<output> - Punctuation Classifier neural net
        capit_classifier_<output> - Capitalization Classifier neural net
        punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
        capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)

        Args:
            output: path whose directory and base name are used to derive the five file names
            input_example: ignored (a warning is logged when it is passed)
            output_example: ignored (a warning is logged when it is passed)
            verbose, export_params, do_constant_folding, keep_initializers_as_inputs, onnx_opset_version,
            try_script, set_eval, check_trace, use_dynamic_axes: forwarded unchanged to the export()
                of each of the three sub-modules

        Returns:
            A tuple of two parallel lists: the five output paths and their descriptions, in the order
            bert, punct_classifier, capit_classifier, punct, capit.
        """
        if not (input_example is None and output_example is None):
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " PunctuationCapitalizationModel consists of three separate models with different"
                " inputs and outputs."
            )

        class_path = self.__module__ + '.' + self.__class__.__qualname__
        out_dir, out_name = os.path.split(output)

        # Positional arguments shared by every sub-module export; the two leading Nones are the
        # input/output examples, which are computed by input_example() of each sub-module.
        shared_args = (
            None,
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        paths = []
        descriptions = []

        # Step 1: export the three stand-alone networks.
        standalone_parts = (
            ('bert_', self.bert_model, ' BERT exported to ONNX'),
            ('punct_classifier_', self.punct_classifier, ' Punctuation Classifier exported to ONNX'),
            ('capit_classifier_', self.capit_classifier, ' Capitalization Classifier exported to ONNX'),
        )
        exported = []
        for prefix, module, label in standalone_parts:
            part_path = os.path.join(out_dir, prefix + out_name)
            exported.append(module.export(part_path, *shared_args))
            paths.append(part_path)
            descriptions.append(class_path + label)
        bert_model_onnx, punct_classifier_onnx, capit_classifier_onnx = exported

        # Step 2: attach each classifier to the BERT graph and save the fused models.
        fused_parts = (
            ('punct_', punct_classifier_onnx, "PTCL", ' Punctuation BERT+Classifier exported to ONNX'),
            ('capit_', capit_classifier_onnx, "CPCL", ' Capitalization BERT+Classifier exported to ONNX'),
        )
        for prefix, classifier_onnx, tag, label in fused_parts:
            fused_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx, tag)
            fused_path = os.path.join(out_dir, prefix + out_name)
            onnx.save(fused_model, fused_path)
            paths.append(fused_path)
            descriptions.append(class_path + label)

        return (paths, descriptions)
Example #3
0
class IntentSlotClassificationModel(NLPModel):
    """
    Joint intent and slot classification model: a language model obtained from get_lm_model()
    followed by a SequenceTokenClassifier head.

    The label vocabularies, label ids, class weights and pad label are stored in cfg.data_desc,
    either computed from the data directory or filled with dummy defaults when no data directory
    is available.
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types, taken from the underlying language model."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types, taken from the classifier head."""
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """ Initializes BERT Joint Intent and Slot model.

        Args:
            cfg: model configuration; cfg.data_desc is created or updated in place
            trainer: trainer instance forwarded to the superclass
        """
        self.max_seq_length = cfg.language_model.max_seq_length

        # Setup tokenizer.
        self.setup_tokenizer(cfg.tokenizer)

        try:
            # Check the presence of data_dir.
            if not cfg.data_dir or not os.path.exists(cfg.data_dir):
                # Disable setup methods.
                IntentSlotClassificationModel._set_model_restore_state(is_being_restored=True)
                # Set default values of data_desc.
                self._set_defaults_data_desc(cfg)
            else:
                self.data_dir = cfg.data_dir
                # Update configuration of data_desc.
                self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds, cfg.validation_ds)

            # init superclass
            super().__init__(cfg=cfg, trainer=trainer)
        finally:
            # Enable setup methods. Done in `finally` so that a failed construction does not
            # leave the class-level restore flag switched on for later instances.
            IntentSlotClassificationModel._set_model_restore_state(is_being_restored=False)

        # Initialize Bert model
        self.bert_model = get_lm_model(
            pretrained_model_name=self.cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file', cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(self.cfg.language_model.config)
            if self.cfg.language_model.config
            else None,
            checkpoint_file=self.cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file),
        )

        # Initialize Classifier.
        self._reconfigure_classifier()

    def _set_defaults_data_desc(self, cfg):
        """
        Method makes sure that cfg.data_desc params are set.
        If cfg.data_desc is missing or None, sets them to "dummy" defaults
        (a single " " intent label, a single " " slot label and "O" as pad label).
        """
        # The `is None` check mirrors _set_data_desc_to_cfg(): a key that is present but empty
        # must be treated the same way as a missing key.
        if not hasattr(cfg, "data_desc") or cfg.data_desc is None:
            OmegaConf.set_struct(cfg, False)
            cfg.data_desc = {}
            # Intents.
            cfg.data_desc.intent_labels = " "
            cfg.data_desc.intent_label_ids = {" ": 0}
            cfg.data_desc.intent_weights = [1]
            # Slots.
            cfg.data_desc.slot_labels = " "
            cfg.data_desc.slot_label_ids = {" ": 0}
            cfg.data_desc.slot_weights = [1]

            cfg.data_desc.pad_label = "O"
            OmegaConf.set_struct(cfg, True)

    def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds):
        """
        Method creates IntentSlotDataDesc and copies generated values to cfg.data_desc.

        It also writes the intent and slot label files into data_dir and registers them as
        artifacts. cfg is put back into struct mode on exit, even when one of these steps fails.

        Args:
            cfg: configuration to update in place
            data_dir: path to data directory
            train_ds: training dataset config section (its `prefix` is used)
            validation_ds: validation dataset config section (its `prefix` is used)
        """
        # Save data from data desc to config - so it can be reused later, e.g. in inference.
        data_desc = IntentSlotDataDesc(data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix])
        OmegaConf.set_struct(cfg, False)
        try:
            if not hasattr(cfg, "data_desc") or cfg.data_desc is None:
                cfg.data_desc = {}
            # Intents.
            cfg.data_desc.intent_labels = list(data_desc.intents_label_ids.keys())
            cfg.data_desc.intent_label_ids = data_desc.intents_label_ids
            cfg.data_desc.intent_weights = data_desc.intent_weights
            # Slots.
            cfg.data_desc.slot_labels = list(data_desc.slots_label_ids.keys())
            cfg.data_desc.slot_label_ids = data_desc.slots_label_ids
            cfg.data_desc.slot_weights = data_desc.slot_weights

            cfg.data_desc.pad_label = data_desc.pad_label

            # for older(pre - 1.0.0.b3) configs compatibility
            if not hasattr(cfg, "class_labels") or cfg.class_labels is None:
                cfg.class_labels = OmegaConf.create(
                    {'intent_labels_file': 'intent_labels.csv', 'slot_labels_file': 'slot_labels.csv'}
                )

            slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file)
            intent_labels_file = os.path.join(data_dir, cfg.class_labels.intent_labels_file)
            self._save_label_ids(data_desc.slots_label_ids, slot_labels_file)
            self._save_label_ids(data_desc.intents_label_ids, intent_labels_file)

            self.register_artifact(cfg.class_labels.intent_labels_file, intent_labels_file)
            self.register_artifact(cfg.class_labels.slot_labels_file, slot_labels_file)
        finally:
            OmegaConf.set_struct(cfg, True)

    def _save_label_ids(self, label_ids: Dict[str, int], filename: str) -> None:
        """
        Saves label ids map to a file: one label per line, ordered by ascending id.
        An empty map produces an empty file.
        """
        with open(filename, 'w') as out:
            # A comprehension instead of `zip(*...)` unpacking, which fails on an empty map.
            labels = [label for label, _ in sorted(label_ids.items(), key=lambda x: x[1])]
            out.write('\n'.join(labels))
            logging.info(f'Labels: {label_ids}')
            logging.info(f'Labels mapping saved to : {out.name}')

    def _reconfigure_classifier(self):
        """
        Method reconfigures the classifier depending on the settings of model cfg.data_desc.

        Re-creates the classifier head, the intent/slot/aggregated losses and both
        classification reports, so any previously held instances are replaced.
        """

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=len(self.cfg.data_desc.intent_labels),
            num_slots=len(self.cfg.data_desc.slot_labels),
            dropout=self.cfg.head.fc_dropout,
            num_layers=self.cfg.head.num_output_layers,
            log_softmax=False,
        )

        # define losses
        if self.cfg.class_balancing == 'weighted_loss':
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.intent_loss = CrossEntropyLoss(logits_ndim=2, weight=self.cfg.data_desc.intent_weights)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3, weight=self.cfg.data_desc.slot_weights)
        else:
            self.intent_loss = CrossEntropyLoss(logits_ndim=2)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3)

        # intent loss gets cfg.intent_loss_weight, slot loss gets the remainder
        self.total_loss = AggregatorLoss(
            num_inputs=2, weights=[self.cfg.intent_loss_weight, 1.0 - self.cfg.intent_loss_weight]
        )

        # setup to track metrics
        self.intent_classification_report = ClassificationReport(
            num_classes=len(self.cfg.data_desc.intent_labels),
            label_ids=self.cfg.data_desc.intent_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=len(self.cfg.data_desc.slot_labels),
            label_ids=self.cfg.data_desc.slot_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )

    def update_data_dir_for_training(self, data_dir: str, train_ds, validation_ds) -> None:
        """
        Update data directory and get data stats with Data Descriptor.
        Also, reconfigures the classifier - to cope with data with e.g. different number of slots.

        Args:
            data_dir: path to data directory
            train_ds: training dataset config section (its `prefix` is used)
            validation_ds: validation dataset config section (its `prefix` is used)
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir
        # Update configuration with new data.
        self._set_data_desc_to_cfg(self.cfg, data_dir, train_ds, validation_ds)
        # Reconfigure the classifier for different settings (number of intents, slots etc.).
        self._reconfigure_classifier()

    def update_data_dir_for_testing(self, data_dir) -> None:
        """
        Update data directory. The data descriptor and the classifier are left untouched.

        Args:
            data_dir: path to data directory
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        Returns:
            intent_logits, slot_logits: outputs of the classifier head applied to the
            language model's hidden states
        """
        hidden_states = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        intent_logits, slot_logits = self.classifier(hidden_states=hidden_states)
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            dict with the combined training loss ('loss') and the current learning rate ('lr')
        """
        # forward pass
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits, labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        train_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': train_loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Updates both classification reports with this batch and returns the combined loss
        together with the tp/fn/fp counters of the reports.
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits, labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        # calculate accuracy metrics for intents and slot reporting
        # intents
        preds = torch.argmax(intent_logits, axis=-1)
        self.intent_classification_report.update(preds, intent_labels)
        # slots: only positions selected by the subtokens mask are scored
        subtokens_mask = subtokens_mask > 0.5
        preds = torch.argmax(slot_logits, axis=-1)[subtokens_mask]
        slot_labels = slot_labels[subtokens_mask]
        self.slot_classification_report.update(preds, slot_labels)

        return {
            'val_loss': val_loss,
            'intent_tp': self.intent_classification_report.tp,
            'intent_fn': self.intent_classification_report.fn,
            'intent_fp': self.intent_classification_report.fp,
            'slot_tp': self.slot_classification_report.tp,
            'slot_fn': self.slot_classification_report.fn,
            'slot_fp': self.slot_classification_report.fp,
        }

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        :return: dict with the mean validation loss and precision/recall/f1 for intents and slots.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report (separately for intents and slots)
        intent_precision, intent_recall, intent_f1, intent_report = self.intent_classification_report.compute()
        logging.info(f'Intent report: {intent_report}')

        slot_precision, slot_recall, slot_f1, slot_report = self.slot_classification_report.compute()
        logging.info(f'Slot report: {slot_report}')

        self.log('val_loss', avg_loss)
        self.log('intent_precision', intent_precision)
        self.log('intent_recall', intent_recall)
        self.log('intent_f1', intent_f1)
        self.log('slot_precision', slot_precision)
        self.log('slot_recall', slot_recall)
        self.log('slot_f1', slot_f1)

        return {
            'val_loss': avg_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
        }

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`. Reuses validation_step().
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs. Reuses validation_epoch_end().
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Creates the training data loader from `train_data_config` and stores it in self._train_dl."""
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Creates the validation data loader from `val_data_config` and stores it in self._validation_dl."""
        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Creates the test data loader from `test_data_config` and stores it in self._test_dl."""
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Builds a DataLoader over `<data_dir>/<prefix>.tsv` and `<data_dir>/<prefix>_slots.tsv`.

        Args:
            cfg: dataset config section providing prefix, num_samples, batch_size, shuffle,
                num_workers, pin_memory and drop_last

        Raises:
            FileNotFoundError: if either of the two files does not exist
        """
        input_file = f'{self.data_dir}/{cfg.prefix}.tsv'
        slot_file = f'{self.data_dir}/{cfg.prefix}_slots.tsv'

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            raise FileNotFoundError(
                f'{input_file} or {slot_file} not found. Please refer to the documentation for the right format \
                 of Intents and Slots files.'
            )

        dataset = IntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.cfg.data_desc.pad_label,
            ignore_extra_tokens=self.cfg.ignore_extra_tokens,
            ignore_start_end=self.cfg.ignore_start_end,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    def _setup_infer_dataloader(self, queries: List[str], test_ds) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a infer data loader.
        Args:
            queries: text
            test_ds: dataset config section providing batch_size, shuffle, num_workers,
                pin_memory and drop_last
        Returns:
            A pytorch DataLoader.
        """

        dataset = IntentSlotInferenceDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=-1, do_lower_case=False
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=test_ds.batch_size,
            shuffle=test_ds.shuffle,
            num_workers=test_ds.num_workers,
            pin_memory=test_ds.pin_memory,
            drop_last=test_ds.drop_last,
        )

    def predict_from_examples(self, queries: List[str], test_ds) -> List[List[str]]:
        """
        Get prediction for the queries (intent and slots).

        The model is switched to evaluation mode and moved to CUDA when available (CPU otherwise);
        the original train/eval mode is restored before returning, even if inference fails.

        Args:
            queries: text sequences
            test_ds: Dataset configuration section.
        Returns:
            predicted_intents, predicted_slots: model intent and slot predictions, returned as a
            pair of lists; each slot prediction is a space-separated string of slot labels
        """
        predicted_intents = []
        predicted_slots = []
        mode = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Retrieve intent and slot vocabularies from configuration.
            intent_labels = self.cfg.data_desc.intent_labels
            slot_labels = self.cfg.data_desc.slot_labels

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            # Dataset.
            infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

            # Only argmax predictions are needed, so skip building autograd graphs.
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask = batch

                    intent_logits, slot_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # predict intents and slots for these examples
                    # intents
                    intent_preds = tensor2list(torch.argmax(intent_logits, axis=-1))

                    # convert numerical outputs to Intent and Slot labels from the dictionaries
                    for intent_num in intent_preds:
                        if intent_num < len(intent_labels):
                            predicted_intents.append(intent_labels[int(intent_num)])
                        else:
                            # should not happen
                            predicted_intents.append("Unknown Intent")

                    # slots
                    slot_preds = torch.argmax(slot_logits, axis=-1)

                    for slot_preds_query, mask_query in zip(slot_preds, subtokens_mask):
                        query_slots = ''
                        for slot, mask in zip(slot_preds_query, mask_query):
                            # only positions whose subtokens mask equals 1 contribute a label
                            if mask == 1:
                                if slot < len(slot_labels):
                                    query_slots += slot_labels[int(slot)] + ' '
                                else:
                                    query_slots += 'Unknown_slot '
                        predicted_slots.append(query_slots.strip())

        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return predicted_intents, predicted_slots

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="Joint_Intent_Slot_Assistant",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
            description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
        )
        result.append(model)
        return result
Example #4
0
class PunctuationCapitalizationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types of the model; delegated to the wrapped ``bert_model``."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """
        Output neural types of the model: ``punct_logits`` and ``capit_logits``, both declared as
        ``LogitsType`` with axes ``('B', 'T', 'C')``.
        """
        return {
            "punct_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
            "capit_logits": NeuralType(('B', 'T', 'C'), LogitsType()),
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Builds the BERT Punctuation and Capitalization model: the language model, one token classification
        head per task, the losses and the per-task classification reports.

        Args:
            cfg: model config. The sections read here are ``tokenizer``, ``language_model``, ``punct_head``
                and ``capit_head``; label maps are read from ``self._cfg`` after the superclass constructor ran.
            trainer: trainer instance forwarded to the superclass constructor.
        """
        # The tokenizer is set up before the superclass constructor is invoked.
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        # Language model. Artifacts are registered in the same order as the keyword arguments are listed.
        lm_cfg = cfg.language_model
        lm_config_file = self.register_artifact('language_model.config_file', lm_cfg.config_file)
        lm_config_dict = OmegaConf.to_container(lm_cfg.config) if lm_cfg.config else None
        vocab_file = self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file)
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=lm_config_file,
            config_dict=lm_config_dict,
            checkpoint_file=lm_cfg.lm_checkpoint,
            vocab_file=vocab_file,
        )

        # Both heads work on the hidden states of the same encoder.
        hidden_size = self.bert_model.config.hidden_size

        punct_head = cfg.punct_head
        self.punct_classifier = TokenClassifier(
            hidden_size=hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=punct_head.activation,
            log_softmax=False,
            dropout=punct_head.fc_dropout,
            num_layers=punct_head.punct_num_fc_layers,
            use_transformer_init=punct_head.use_transformer_init,
        )

        capit_head = cfg.capit_head
        self.capit_classifier = TokenClassifier(
            hidden_size=hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=capit_head.activation,
            log_softmax=False,
            dropout=capit_head.fc_dropout,
            num_layers=capit_head.capit_num_fc_layers,
            use_transformer_init=capit_head.use_transformer_init,
        )

        # A single token-level loss is shared by both tasks; the two task losses are then aggregated.
        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Encodes the input with the BERT model and feeds the hidden states to both classification heads.

        Args:
            input_ids: forwarded to ``bert_model`` as ``input_ids``
            attention_mask: forwarded to ``bert_model`` as ``attention_mask``
            token_type_ids: forwarded to ``bert_model`` as ``token_type_ids``

        Returns:
            A tuple ``(punct_logits, capit_logits)``.
        """
        encoded = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        punct_logits = self.punct_classifier(hidden_states=encoded)
        capit_logits = self.capit_classifier(hidden_states=encoded)
        return punct_logits, capit_logits

    def _make_step(self, batch):
        """
        Runs the model on one batch and computes the aggregated punctuation + capitalization loss.

        Args:
            batch: a 7-tuple ``(input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, punct_labels,
                capit_labels)``. ``subtokens_mask`` is not used here.

        Returns:
            A tuple ``(loss, punct_logits, capit_logits)``.
        """
        input_ids, input_type_ids, input_mask, _, loss_mask, punct_labels, capit_labels = batch
        punct_logits, capit_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # The same loss module scores both tasks; the two values are combined by ``agg_loss``.
        punct_loss = self.loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
        capit_loss = self.loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
        total_loss = self.agg_loss(loss_1=punct_loss, loss_2=capit_loss)
        return total_loss, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        loss, _, _ = self._make_step(batch)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('lr', lr, prog_bar=True)
        self.log('train_loss', loss)

        return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        val_loss, punct_logits, capit_logits = self._make_step(batch)

        subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        capit_preds = torch.argmax(capit_logits, axis=-1)[subtokens_mask]
        capit_labels = capit_labels[subtokens_mask]
        self.capit_class_report.update(capit_preds, capit_labels)

        return {
            'val_loss': val_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'capit_tp': self.capit_class_report.tp,
            'capit_fn': self.capit_class_report.fn,
            'capit_fp': self.capit_class_report.fp,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        _, _, _, subtokens_mask, _, punct_labels, capit_labels = batch
        test_loss, punct_logits, capit_logits = self._make_step(batch)

        subtokens_mask = subtokens_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[subtokens_mask]
        punct_labels = punct_labels[subtokens_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        capit_preds = torch.argmax(capit_logits, axis=-1)[subtokens_mask]
        capit_labels = capit_labels[subtokens_mask]
        self.capit_class_report.update(capit_preds, capit_labels)

        return {
            'test_loss': test_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'capit_tp': self.capit_class_report.tp,
            'capit_fn': self.capit_class_report.fn,
            'capit_fp': self.capit_class_report.fp,
        }

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

        self.punct_class_report.reset()
        self.capit_class_report.reset()

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
            Called at the end of test to aggregate outputs.
            outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Update data directory

        Args:
            data_dir: path to data directory
        """
        if os.path.exists(data_dir):
            logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
            self._cfg.dataset.data_dir = data_dir
        else:
            raise ValueError(f'{data_dir} not found')

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        """
        Creates the training data loader and, on the main process, registers the label vocabulary files as
        artifacts and stores the label maps in the model config.

        Args:
            train_data_config: training data config; ``self._cfg.train_ds`` is used if ``None``
        """
        ds_config = self._cfg.train_ds if train_data_config is None else train_data_config

        # for older(pre - 1.0.0.b3) configs compatibility
        lacks_class_labels = not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None
        if lacks_class_labels:
            OmegaConf.set_struct(self._cfg, False)
            self._cfg.class_labels = {}
            self._cfg.class_labels = OmegaConf.create(
                {'punct_labels_file': 'punct_label_ids.csv', 'capit_labels_file': 'capit_label_ids.csv'}
            )

        self._train_dl = self._setup_dataloader_from_config(cfg=ds_config)

        # Only rank 0 (or a non-distributed run) registers artifacts and saves the label maps.
        on_main_process = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        if not on_main_process:
            return

        train_dataset = self._train_dl.dataset
        self.register_artifact('class_labels.punct_labels_file', train_dataset.punct_label_ids_file)
        self.register_artifact('class_labels.capit_labels_file', train_dataset.capit_label_ids_file)

        # save label maps to the config
        self._cfg.punct_label_ids = OmegaConf.create(train_dataset.punct_label_ids)
        self._cfg.capit_label_ids = OmegaConf.create(train_dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Setup validaton data

        val_data_config: validation data config
        """
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds

        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        if test_data_config is None:
            test_data_config = self._cfg.test_ds
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Builds a ``BertPunctuationCapitalizationDataset`` from ``cfg`` and the model config and wraps it into
        a data loader.

        Args:
            cfg: data config providing ``text_file``, ``labels_file``, ``num_samples``, ``batch_size``, ``shuffle``
                and, optionally, ``ds_item``

        Returns:
            A pytorch DataLoader.
        """
        model_ds_cfg = self._cfg.dataset

        # use data_dir specified in the ds_item to run evaluation on multiple datasets
        has_ds_item = 'ds_item' in cfg and cfg.ds_item is not None
        data_dir = cfg.ds_item if has_ds_item else model_ds_cfg.data_dir

        # Fall back to the default label vocabulary file names when the config has no `class_labels` section.
        if 'class_labels' in self._cfg:
            punct_ids_file = self._cfg.class_labels.punct_labels_file
            capit_ids_file = self._cfg.class_labels.capit_labels_file
        else:
            punct_ids_file = 'punct_label_ids.csv'
            capit_ids_file = 'capit_label_ids.csv'

        dataset = BertPunctuationCapitalizationDataset(
            tokenizer=self.tokenizer,
            text_file=os.path.join(data_dir, cfg.text_file),
            label_file=os.path.join(data_dir, cfg.labels_file),
            pad_label=model_ds_cfg.pad_label,
            punct_label_ids=self._cfg.punct_label_ids,
            capit_label_ids=self._cfg.capit_label_ids,
            max_seq_length=model_ds_cfg.max_seq_length,
            ignore_extra_tokens=model_ds_cfg.ignore_extra_tokens,
            ignore_start_end=model_ds_cfg.ignore_start_end,
            use_cache=model_ds_cfg.use_cache,
            num_samples=cfg.num_samples,
            punct_label_ids_file=punct_ids_file,
            capit_label_ids_file=capit_ids_file,
        )

        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=model_ds_cfg.num_workers,
            pin_memory=model_ds_cfg.pin_memory,
            drop_last=model_ds_cfg.drop_last,
        )
        return loader

    def _setup_infer_dataloader(
        self, queries: List[str], batch_size: int, max_seq_length: int, step: int, margin: int,
    ) -> torch.utils.data.DataLoader:
        """
        Creates a data loader over a ``BertPunctuationCapitalizationInferDataset`` built from ``queries``.
        Any of ``max_seq_length``, ``step``, ``margin`` passed as ``None`` is replaced by the value of the same
        name from ``self._cfg.dataset``.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference
            max_seq_length: length of segments into which queries are split. ``max_seq_length`` includes ``[CLS]`` and
                ``[SEP]`` so every segment contains at most ``max_seq_length-2`` tokens from an input query.
            step: number of tokens by which a segment is offset to a previous segment. Parameter ``step`` cannot be greater
                than ``max_seq_length-2``.
            margin: number of tokens near the edge of a segment which label probabilities are not used in final prediction
                computation.
        Returns:
            A pytorch DataLoader.
        """
        model_ds_cfg = self._cfg.dataset
        max_seq_length = model_ds_cfg.max_seq_length if max_seq_length is None else max_seq_length
        step = model_ds_cfg.step if step is None else step
        margin = model_ds_cfg.margin if margin is None else margin

        infer_dataset = BertPunctuationCapitalizationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=max_seq_length, step=step, margin=margin
        )
        # Order of segments must be preserved and no segment may be dropped during inference.
        loader = torch.utils.data.DataLoader(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=model_ds_cfg.num_workers,
            pin_memory=model_ds_cfg.pin_memory,
            drop_last=False,
        )
        return loader

    @staticmethod
    def _remove_margins(tensor, margin_size, keep_left, keep_right):
        tensor = tensor.detach().clone()
        if not keep_left:
            tensor = tensor[margin_size + 1 :]  # remove left margin and CLS token
        if not keep_right:
            tensor = tensor[: tensor.shape[0] - margin_size - 1]  # remove right margin and SEP token
        return tensor

    def _transform_logit_to_prob_and_remove_margins_and_extract_word_probs(
        self,
        punct_logits: torch.Tensor,
        capit_logits: torch.Tensor,
        subtokens_mask: torch.Tensor,
        start_word_ids: Tuple[int],
        margin: int,
        is_first: Tuple[bool],
        is_last: Tuple[bool],
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[int]]:
        """
        Applies softmax to get punctuation and capitalization probabilities, applies ``subtokens_mask`` to extract
        probabilities for words from probabilities for tokens, removes ``margin`` probabilities near edges of a segment.
        Left margin of the first segment in a query and right margin of the last segment in a query are not removed.
        Calculates new ``start_word_ids`` taking into the account the margins. If the left margin of a segment is removed
        corresponding start word index is increased by number of words (number of nonzero values in corresponding
        ``subtokens_mask``) in the margin.
        Args:
            punct_logits: a float tensor of shape ``[batch_size, segment_length, number_of_punctuation_labels]``
            capit_logits: a float tensor of shape ``[batch_size, segment_length, number_of_capitalization_labels]``
            subtokens_mask: a float tensor of shape ``[batch_size, segment_length]``
            start_word_ids: indices of segment first words in a query
            margin: number of tokens near edges of a segment which probabilities are discarded
            is_first: is segment the first segment in a query
            is_last: is segment the last segment in a query
        Returns:
            b_punct_probs: list containing ``batch_size`` numpy arrays. The numpy arrays have shapes
                ``[number_of_word_in_this_segment, number_of_punctuation_labels]``. Word punctuation probabilities for
                segments in the batch.
            b_capit_probs: list containing ``batch_size`` numpy arrays. The numpy arrays have shapes
                ``[number_of_word_in_this_segment, number_of_capitalization_labels]``. Word capitalization probabilities for
                segments in the batch.
            new_start_word_ids: indices of segment first words in a query after margin removal
        """
        new_start_word_ids = list(start_word_ids)
        subtokens_mask = subtokens_mask > 0.5
        b_punct_probs, b_capit_probs = [], []
        for i, (first, last, pl, cl, stm) in enumerate(
            zip(is_first, is_last, punct_logits, capit_logits, subtokens_mask)
        ):
            if not first:
                new_start_word_ids[i] += torch.count_nonzero(stm[: margin + 1]).numpy()  # + 1 is for [CLS] token
            stm = self._remove_margins(stm, margin, keep_left=first, keep_right=last)
            for b_probs, logits in [(b_punct_probs, pl), (b_capit_probs, cl)]:
                p = torch.nn.functional.softmax(
                    self._remove_margins(logits, margin, keep_left=first, keep_right=last)[stm], dim=-1,
                )
                b_probs.append(p.detach().cpu().numpy())
        return b_punct_probs, b_capit_probs, new_start_word_ids

    @staticmethod
    def _move_acc_probs_to_token_preds(
        pred: List[int], acc_prob: np.ndarray, number_of_probs_to_move: int
    ) -> Tuple[List[int], np.ndarray]:
        """
        ``number_of_probs_to_move`` rows in the beginning are removed from ``acc_prob``. From every remove row the label
        with the largest probability is selected and appended to ``pred``.
        Args:
            pred: list with ready label indices for a query
            acc_prob: numpy array of shape ``[number_of_words_for_which_probabilities_are_accumulated, number_of_labels]``
            number_of_probs_to_move: int
        Returns:
            pred: list with ready label indices for a query
            acc_prob: numpy array of shape
                ``[number_of_words_for_which_probabilities_are_accumulated - number_of_probs_to_move, number_of_labels]``
        """
        if number_of_probs_to_move > acc_prob.shape[0]:
            raise ValueError(
                f"Not enough accumulated probabilities. Number_of_probs_to_move={number_of_probs_to_move} "
                f"acc_prob.shape={acc_prob.shape}"
            )
        if number_of_probs_to_move > 0:
            pred = pred + list(np.argmax(acc_prob[:number_of_probs_to_move], axis=-1))
        acc_prob = acc_prob[number_of_probs_to_move:]
        return pred, acc_prob

    @staticmethod
    def _update_accumulated_probabilities(acc_prob: np.ndarray, update: np.ndarray) -> np.ndarray:
        """
        Args:
            acc_prob: numpy array of shape ``[A, L]``
            update: numpy array of shape ``[A + N, L]``
        Returns:
            numpy array of shape ``[A + N, L]``
        """
        acc_prob = np.concatenate([acc_prob * update[: acc_prob.shape[0]], update[acc_prob.shape[0] :]], axis=0)
        return acc_prob

    def apply_punct_capit_predictions(self, query: str, punct_preds: List[int], capit_preds: List[int]) -> str:
        """
        Restores punctuation and capitalization in ``query``.
        Args:
            query: a string without punctuation and capitalization
            punct_preds: ids of predicted punctuation labels
            capit_preds: ids of predicted capitalization labels
        Returns:
            a query with restored punctuation and capitalization
        """
        query = query.strip().split()
        assert len(query) == len(
            punct_preds
        ), f"len(query)={len(query)} len(punct_preds)={len(punct_preds)}, query[:30]={query[:30]}"
        assert len(query) == len(
            capit_preds
        ), f"len(query)={len(query)} len(capit_preds)={len(capit_preds)}, query[:30]={query[:30]}"
        punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
        capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}
        query_with_punct_and_capit = ''
        for j, word in enumerate(query):
            punct_label = punct_ids_to_labels[punct_preds[j]]
            capit_label = capit_ids_to_labels[capit_preds[j]]

            if capit_label != self._cfg.dataset.pad_label:
                word = word.capitalize()
            query_with_punct_and_capit += word
            if punct_label != self._cfg.dataset.pad_label:
                query_with_punct_and_capit += punct_label
            query_with_punct_and_capit += ' '
        return query_with_punct_and_capit[:-1]

    def get_labels(self, punct_preds: List[int], capit_preds: List[int]) -> str:
        """
        Returns punctuation and capitalization labels in NeMo format (see https://docs.nvidia.com/deeplearning/nemo/
        user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#nemo-data-format).
        Args:
            punct_preds: ids of predicted punctuation labels
            capit_preds: ids of predicted capitalization labels
        Returns:
            labels in NeMo format
        """
        assert len(capit_preds) == len(
            punct_preds
        ), f"len(capit_preds)={len(capit_preds)} len(punct_preds)={len(punct_preds)}"
        punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
        capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}
        result = ''
        for capit_label, punct_label in zip(capit_preds, punct_preds):
            punct_label = punct_ids_to_labels[punct_label]
            capit_label = capit_ids_to_labels[capit_label]
            result += punct_label + capit_label + ' '
        return result[:-1]

    def add_punctuation_capitalization(
        self,
        queries: List[str],
        batch_size: int = None,
        max_seq_length: int = 64,
        step: int = 8,
        margin: int = 16,
        return_labels: bool = False,
    ) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for inference.

        Parameters ``max_seq_length``, ``step``, ``margin`` are for controlling the way queries are split into segments
        which then processed by the model. Parameter ``max_seq_length`` is a length of a segment after tokenization
        including special tokens [CLS] in the beginning and [SEP] in the end of a segment. Parameter ``step`` is shift
        between consequent segments. Parameter ``margin`` is used to exclude negative effect of subtokens near
        borders of segments which have only one side context.

        If segments overlap, probabilities of overlapping predictions are multiplied and then the label with
        corresponding to the maximum probability is selected.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference
            max_seq_length: maximum sequence length of segment after tokenization.
            step: relative shift of consequent segments into which long queries are split. Long queries are split into
                segments which can overlap. Parameter ``step`` controls such overlapping. Imagine that queries are
                tokenized into characters, ``max_seq_length=5``, and ``step=2``. In such a case query "hello" is
                tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``.
            margin: number of subtokens in the beginning and the end of segments which are not used for prediction
                computation. The first segment does not have left margin and the last segment does not have right
                margin. For example, if input sequence is tokenized into characters, ``max_seq_length=5``,
                ``step=1``, and ``margin=1``, then query "hello" will be tokenized into segments
                ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'],
                ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions
                computation, margins are removed. In the next list, subtokens which logits are not used for final
                predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*],
                ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``.
            return_labels: whether to return labels in NeMo format (see https://docs.nvidia.com/deeplearning/nemo/
                user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#nemo-data-format) instead of queries
                with restored punctuation and capitalization.
        Returns:
            result: text with added capitalization and punctuation or punctuation and capitalization labels
        """
        if len(queries) == 0:
            return []
        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')
        result: List[str] = []
        mode = self.training
        try:
            self.eval()
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size, max_seq_length, step, margin)
            # Predicted labels for queries. List of labels for every query
            all_punct_preds: List[List[int]] = [[] for _ in queries]
            all_capit_preds: List[List[int]] = [[] for _ in queries]
            # Accumulated probabilities (or product of probabilities acquired from different segments) of punctuation
            # and capitalization. Probabilities for words in a query are extracted using `subtokens_mask`. Probabilities
            # for newly processed words are appended to the accumulated probabilities. If probabilities for a word are
            # already present in `acc_probs`, old probabilities are replaced with a product of old probabilities
            # and probabilities acquired from new segment. Segments are processed in an order they appear in an
            # input query. When all segments with a word are processed, a label with the highest probability
            # (or product of probabilities) is chosen and appended to an appropriate list in `all_preds`. After adding
            # prediction to `all_preds`, probabilities for a word are removed from `acc_probs`.
            acc_punct_probs: List[Optional[np.ndarray]] = [None for _ in queries]
            acc_capit_probs: List[Optional[np.ndarray]] = [None for _ in queries]
            d = self.device
            for batch_i, batch in tqdm(
                enumerate(infer_datalayer), total=ceil(len(infer_datalayer.dataset) / batch_size), unit="batch"
            ):
                inp_ids, inp_type_ids, inp_mask, subtokens_mask, start_word_ids, query_ids, is_first, is_last = batch
                punct_logits, capit_logits = self.forward(
                    input_ids=inp_ids.to(d), token_type_ids=inp_type_ids.to(d), attention_mask=inp_mask.to(d),
                )
                _res = self._transform_logit_to_prob_and_remove_margins_and_extract_word_probs(
                    punct_logits, capit_logits, subtokens_mask, start_word_ids, margin, is_first, is_last
                )
                punct_probs, capit_probs, start_word_ids = _res
                for i, (q_i, start_word_id, bpp_i, bcp_i) in enumerate(
                    zip(query_ids, start_word_ids, punct_probs, capit_probs)
                ):
                    for all_preds, acc_probs, b_probs_i in [
                        (all_punct_preds, acc_punct_probs, bpp_i),
                        (all_capit_preds, acc_capit_probs, bcp_i),
                    ]:
                        if acc_probs[q_i] is None:
                            acc_probs[q_i] = b_probs_i
                        else:
                            all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds(
                                all_preds[q_i], acc_probs[q_i], start_word_id - len(all_preds[q_i]),
                            )
                            acc_probs[q_i] = self._update_accumulated_probabilities(acc_probs[q_i], b_probs_i)
            for all_preds, acc_probs in [(all_punct_preds, acc_punct_probs), (all_capit_preds, acc_capit_probs)]:
                for q_i, (pred, prob) in enumerate(zip(all_preds, acc_probs)):
                    if prob is not None:
                        all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds(pred, prob, len(prob))
            for i, query in enumerate(queries):
                result.append(
                    self.get_labels(all_punct_preds[i], all_capit_preds[i])
                    if return_labels
                    else self.apply_punct_capit_predictions(query, all_punct_preds[i], all_capit_preds[i])
                )
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        Enumerates the pre-trained checkpoints of this model which can be instantiated directly from
        NVIDIA's NGC cloud.

        Returns:
            A list with one ``PretrainedModelInfo`` entry per available pre-trained model.
        """
        bert_info = PretrainedModelInfo(
            pretrained_model_name="punctuation_en_bert",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/punctuation_en_bert/versions/1.0.0rc1/files/punctuation_en_bert.nemo",
            description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
        )
        distilbert_info = PretrainedModelInfo(
            pretrained_model_name="punctuation_en_distilbert",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/punctuation_en_distilbert/versions/1.0.0rc1/files/punctuation_en_distilbert.nemo",
            description="The model was trained with DiltilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
        )
        return [bert_info, distilbert_info]

    @property
    def input_module(self):
        """
        Returns the BERT language model (``self.bert_model``).
        NOTE(review): presumably consumed by the ``Exportable`` mixin as the input side of the model - confirm.
        """
        return self.bert_model

    @property
    def output_module(self):
        """
        Returns the model itself.
        NOTE(review): presumably consumed by the ``Exportable`` mixin as the output side of the model - confirm.
        """
        return self
Example #5
0
class IntentSlotClassificationModel(NLPModel):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types of the model, delegated to ``self.bert_model``."""
        language_model = self.bert_model
        return language_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types of the model, delegated to ``self.classifier``."""
        head = self.classifier
        return head.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """ Initializes BERT Joint Intent and Slot model.

        Args:
            cfg: model configuration. This constructor reads ``language_model``, ``tokenizer``,
                ``data_dir``, ``train_ds`` and ``validation_ds`` from it.
            trainer: optional trainer instance, forwarded unchanged to the superclass constructor.
        """
        # The attributes below are assigned before super().__init__() is called.
        self.max_seq_length = cfg.language_model.max_seq_length

        # Setup tokenizer.
        self.setup_tokenizer(cfg.tokenizer)
        self.cfg = cfg
        # Check the presence of data_dir.
        if not cfg.data_dir or not os.path.exists(cfg.data_dir):
            # No usable data directory: flag the class as "being restored" to disable setup methods.
            IntentSlotClassificationModel._set_model_restore_state(
                is_being_restored=True)
            # Fill cfg.data_desc with placeholder values if it is missing.
            self._set_defaults_data_desc(cfg)
        else:
            # NOTE: self.data_dir is only assigned on this branch.
            self.data_dir = cfg.data_dir
            # Build the data descriptor from the data and copy its values into cfg.data_desc.
            self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds,
                                       cfg.validation_ds)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # Initialize Bert model; the config file and vocab file are registered as model artifacts.
        self.bert_model = get_lm_model(
            pretrained_model_name=self.cfg.language_model.
            pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(self.cfg.language_model.config)
            if self.cfg.language_model.config else None,
            checkpoint_file=self.cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        # Enable setup methods again (done unconditionally, whichever branch ran above).
        IntentSlotClassificationModel._set_model_restore_state(
            is_being_restored=False)

        # Initialize the classifier head, losses and metric trackers from cfg.data_desc.
        self._reconfigure_classifier()

    def _set_defaults_data_desc(self, cfg):
        """
        Guarantees that ``cfg.data_desc`` exists.

        When the config has no ``data_desc`` attribute, a "dummy" section is created with a
        single blank intent label, a single blank slot label, unit weights and "O" as pad label.
        A config that already has ``data_desc`` is left untouched.

        Args:
            cfg: model configuration, modified in place.
        """
        if hasattr(cfg, "data_desc"):
            return

        # Struct mode is switched off while the new keys are added, then switched back on.
        OmegaConf.set_struct(cfg, False)
        cfg.data_desc = {}
        # Intents first, then slots: both get the same placeholder values.
        for prefix in ("intent", "slot"):
            setattr(cfg.data_desc, f"{prefix}_labels", " ")
            setattr(cfg.data_desc, f"{prefix}_label_ids", {" ": 0})
            setattr(cfg.data_desc, f"{prefix}_weights", [1])

        cfg.data_desc.pad_label = "O"
        OmegaConf.set_struct(cfg, True)

    def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds):
        """
        Builds an IntentSlotDataDesc for ``data_dir`` and copies its values into ``cfg.data_desc``.

        The label maps are also written to the label files named in ``cfg.class_labels`` (inside
        ``data_dir``) and registered as model artifacts, so they can be reused later, e.g. in inference.

        Args:
            cfg: model configuration, modified in place.
            data_dir: directory containing the data.
            train_ds: training dataset config; only ``prefix`` is read here.
            validation_ds: validation dataset config; only ``prefix`` is read here.
        """
        prefixes = [train_ds.prefix, validation_ds.prefix]
        descriptor = IntentSlotDataDesc(data_dir=data_dir, modes=prefixes)

        # Struct mode is switched off while keys are added and restored at the end.
        OmegaConf.set_struct(cfg, False)
        if not (hasattr(cfg, "data_desc") and cfg.data_desc is not None):
            cfg.data_desc = {}
        # Intents.
        cfg.data_desc.intent_labels = list(descriptor.intents_label_ids.keys())
        cfg.data_desc.intent_label_ids = descriptor.intents_label_ids
        cfg.data_desc.intent_weights = descriptor.intent_weights
        # Slots.
        cfg.data_desc.slot_labels = list(descriptor.slots_label_ids.keys())
        cfg.data_desc.slot_label_ids = descriptor.slots_label_ids
        cfg.data_desc.slot_weights = descriptor.slot_weights

        cfg.data_desc.pad_label = descriptor.pad_label

        # Older (pre - 1.0.0.b3) configs have no class_labels section; give them the default file names.
        if not (hasattr(cfg, "class_labels") and cfg.class_labels is not None):
            cfg.class_labels = {}
            default_files = {
                'intent_labels_file': 'intent_labels.csv',
                'slot_labels_file': 'slot_labels.csv',
            }
            cfg.class_labels = OmegaConf.create(default_files)

        slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file)
        intent_labels_file = os.path.join(data_dir, cfg.class_labels.intent_labels_file)
        self._save_label_ids(descriptor.slots_label_ids, slot_labels_file)
        self._save_label_ids(descriptor.intents_label_ids, intent_labels_file)

        self.register_artifact('class_labels.intent_labels_file', intent_labels_file)
        self.register_artifact('class_labels.slot_labels_file', slot_labels_file)
        OmegaConf.set_struct(cfg, True)

    def _save_label_ids(self, label_ids: Dict[str, int],
                        filename: str) -> None:
        """ Saves label ids map to a file """
        with open(filename, 'w') as out:
            labels, _ = zip(*sorted(label_ids.items(), key=lambda x: x[1]))
            out.write('\n'.join(labels))
            logging.info(f'Labels: {label_ids}')
            logging.info(f'Labels mapping saved to : {out.name}')

    def _reconfigure_classifier(self):
        """
        (Re)creates the classifier head, the losses and the metric trackers from ``self.cfg``.

        The number of intents and slots is taken from the lengths of
        ``cfg.data_desc.intent_labels`` and ``cfg.data_desc.slot_labels``.
        """
        cfg = self.cfg
        data_desc = cfg.data_desc
        num_intents = len(data_desc.intent_labels)
        num_slots = len(data_desc.slot_labels)

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=num_intents,
            num_slots=num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # define losses
        if cfg.class_balancing != 'weighted_loss':
            self.intent_loss = CrossEntropyLoss(logits_ndim=2)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3)
        else:
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.intent_loss = CrossEntropyLoss(logits_ndim=2, weight=data_desc.intent_weights)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3, weight=data_desc.slot_weights)

        # The two aggregation weights sum to 1: intent loss weight and its complement for slots.
        intent_weight = cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(num_inputs=2, weights=[intent_weight, 1.0 - intent_weight])

        # setup to track metrics
        self.intent_classification_report = ClassificationReport(
            num_classes=num_intents,
            label_ids=data_desc.intent_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=num_slots,
            label_ids=data_desc.slot_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )

    def update_data_dir_for_training(self, data_dir: str, train_ds,
                                     validation_ds) -> None:
        """
        Update data directory and get data stats with Data Descriptor.
        Also, reconfigures the classifier - to cope with data with e.g. different number of slots.

        Args:
            data_dir: path to data directory
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir
        # Update configuration with new data.
        self._set_data_desc_to_cfg(self.cfg, data_dir, train_ds, validation_ds)
        # Reconfigure the classifier for different settings (number of intents, slots etc.).
        self._reconfigure_classifier()

    def update_data_dir_for_testing(self, data_dir) -> None:
        """
        Update data directory.

        Args:
            data_dir: path to data directory
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Runs the language model and then the joint intent/slot classifier.

        No special modification is required for Lightning; this is defined as it normally
        would be in an `nn.Module` in vanilla PyTorch.

        Returns:
            The pair (intent_logits, slot_logits) produced by the classifier.
        """
        encoded = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        intent_logits, slot_logits = self.classifier(hidden_states=encoded)
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(input_ids=input_ids,
                                          token_type_ids=input_type_ids,
                                          attention_mask=input_mask)

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits,
                                       labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits,
                                   labels=slot_labels,
                                   loss_mask=loss_mask)
        train_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': train_loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(input_ids=input_ids,
                                          token_type_ids=input_type_ids,
                                          attention_mask=input_mask)

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits,
                                       labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits,
                                   labels=slot_labels,
                                   loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        # calculate accuracy metrics for intents and slot reporting
        # intents
        intent_preds = torch.argmax(intent_logits, axis=-1)
        self.intent_classification_report.update(intent_preds, intent_labels)
        # slots

        subtokens_mask = subtokens_mask > 0.5
        slot_preds = torch.argmax(slot_logits, axis=-1)
        self.slot_classification_report.update(slot_preds[subtokens_mask],
                                               slot_labels[subtokens_mask])

        return {
            'val_loss': val_loss,
            'intent_tp': self.intent_classification_report.tp,
            'intent_fn': self.intent_classification_report.fn,
            'intent_fp': self.intent_classification_report.fp,
            'slot_tp': self.slot_classification_report.tp,
            'slot_fn': self.slot_classification_report.fn,
            'slot_fp': self.slot_classification_report.fp,
            'intent_preds': intent_preds,
            'intent_labels': intent_labels,
            'slot_preds': slot_preds,
            'slot_labels': slot_labels,
            'input': input_ids,
            'subtokens_mask': subtokens_mask,
        }

    @staticmethod
    def get_continuous_slots(slot_ids, utterance_tokens):
        """
        Extract continuous spans of slot_ids
        Args:
            Slot_ids: list of str representing slot of each word token
            For instance, 'O', 'email_address', 'email_address', 'email_address', 'O', 'O', 'O', 'O']
            Corresponds to ['enter', 'atdfd@yahoo', 'dot', 'com', 'into', 'my', 'contact', 'list']
        Returns:
            list of str where each element is a slot name-value pair
            e.g. ['email_address(atdfd@yahoo dot com)']

        """
        slot_id_stack = []
        position_stack = []
        for i, slot_id in enumerate(slot_ids):
            if not slot_id_stack or slot_id != slot_id_stack[-1]:
                slot_id_stack.append(slot_id)
                position_stack.append([])
            position_stack[-1].append(i)

        slot_id_to_start_and_exclusive_end = {
            slot_id_stack[i]:
            [position_stack[i][0], position_stack[i][-1] + 1]
            for i in range(len(position_stack)) if slot_id_stack[i] != 'O'
        }

        slot_to_words = {
            slot: ' '.join(utterance_tokens[position[0]:position[1]])
            for slot, position in slot_id_to_start_and_exclusive_end.items()
        }

        slot_name_and_values = [
            "{}({})".format(slot, value)
            for slot, value in slot_to_words.items()
        ]

        return slot_name_and_values

    def get_unified_metrics(self, outputs):
        """
        Computes slot-filling metrics over all validation/test step outputs and saves the predictions.

        The per-step predictions, labels, masks and inputs are concatenated, converted to label
        names, grouped into continuous slot spans and written to ``predictions.jsonl`` inside
        ``cfg.dataset.dialogues_example_dir``.

        Args:
            outputs: list of dicts as returned by ``validation_step``.
        Returns:
            Tuple (slot_precision, slot_recall, slot_f1, slot_joint_goal_accuracy).
        """
        keys = ('slot_preds', 'slot_labels', 'subtokens_mask', 'input', 'intent_preds', 'intent_labels')
        collected = {key: [] for key in keys}
        for output in outputs:
            for key, bucket in collected.items():
                # in-place extension: one entry per example of the step
                bucket += output[key]

        ground_truth_labels = self.convert_intent_ids_to_intent_names(collected['intent_labels'])
        generated_labels = self.convert_intent_ids_to_intent_names(collected['intent_preds'])

        predicted_slots = self.mask_unused_subword_slots(collected['slot_preds'], collected['subtokens_mask'])
        ground_truth_slots = self.mask_unused_subword_slots(collected['slot_labels'], collected['subtokens_mask'])

        model_cls = IntentSlotClassificationModel
        all_generated_slots = []
        all_ground_truth_slots = []
        all_utterances = []

        for idx, predicted in enumerate(predicted_slots):
            decoded = self.tokenizer.tokenizer.decode(collected['input'][idx], skip_special_tokens=True)
            utterance_tokens = decoded.split()
            truth_names = ground_truth_slots[idx].split()
            predicted_names = predicted.split()
            if len(utterance_tokens) != len(truth_names):
                # fix the bug that abc@xyz get tokenized to 3 tokens and @xyz to 2 tokens
                utterance_tokens = model_cls.join_tokens_containing_at_sign(utterance_tokens, truth_names)
            truth_spans = model_cls.get_continuous_slots(truth_names, utterance_tokens)
            predicted_spans = model_cls.get_continuous_slots(predicted_names, utterance_tokens)

            all_generated_slots.append(predicted_spans)
            all_ground_truth_slots.append(truth_spans)
            all_utterances.append(' '.join(utterance_tokens))

        example_dir = self.cfg.dataset.dialogues_example_dir
        os.makedirs(example_dir, exist_ok=True)
        filename = os.path.join(self.cfg.dataset.dialogues_example_dir, "predictions.jsonl")

        # The two blank-string lists fill positional arguments for which nothing is available here.
        num_examples = len(generated_labels)
        IntentSlotMetrics.save_predictions(
            filename,
            generated_labels,
            all_generated_slots,
            ground_truth_labels,
            all_ground_truth_slots,
            [''] * num_examples,
            [''] * num_examples,
            all_utterances,
        )

        (
            slot_precision,
            slot_recall,
            slot_f1,
            slot_joint_goal_accuracy,
        ) = IntentSlotMetrics.get_slot_filling_metrics(all_generated_slots, all_ground_truth_slots)

        return slot_precision, slot_recall, slot_f1, slot_joint_goal_accuracy

    @staticmethod
    def join_tokens_containing_at_sign(utterance_tokens, slot_names):
        """
        assumes utterance contains only one @ sign
        """
        target_length = len(slot_names)
        current_length = len(utterance_tokens)
        diff = current_length - target_length
        at_sign_positions = [
            index for index, token in enumerate(utterance_tokens)
            if token == "@"
        ]
        if len(at_sign_positions) > 1:
            raise ValueError(
                "Current method does not support utterances with more than 1 @ sign ({} encountered), please extend this method for utterance {} with slot names {}"
                .format(len(at_sign_positions), utterance_tokens, slot_names))
        elif diff == 1:
            new_tokens = []
            for index, token in enumerate(utterance_tokens):
                if utterance_tokens[index - 1] == "@":
                    new_tokens[-1] += token
                else:
                    new_tokens.append(token)

        elif diff == 2:
            new_tokens = []
            for index, token in enumerate(utterance_tokens[:-1]):
                if utterance_tokens[index - 1] == "@" or token == "@":
                    new_tokens[-1] += token
                else:
                    new_tokens.append(token)

        elif diff == 3:
            new_tokens = []
            for index, token in enumerate(utterance_tokens[:-1]):
                if utterance_tokens[index + 1] == "@" or utterance_tokens[
                        index - 1] == "@" or token == "@":
                    new_tokens[-1] += token
                else:
                    new_tokens.append(token)
        else:
            raise ValueError(
                "Difference of more than 3 ({}) encountered. please extend this method for utterance {} with slots {}"
                .format(diff, utterance_tokens, slot_names))

        return new_tokens

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """

        (
            unified_slot_precision,
            unified_slot_recall,
            unified_slot_f1,
            unified_slot_joint_goal_accuracy,
        ) = self.get_unified_metrics(outputs)

        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report (separately for intents and slots)
        intent_precision, intent_recall, intent_f1, intent_report = self.intent_classification_report.compute(
        )
        logging.info(f'Intent report: {intent_report}')

        slot_precision, slot_recall, slot_f1, slot_report = self.slot_classification_report.compute(
        )
        logging.info(f'Slot report: {slot_report}')

        self.log('val_loss', avg_loss)
        self.log('intent_precision', intent_precision)
        self.log('intent_recall', intent_recall)
        self.log('intent_f1', intent_f1)
        self.log('slot_precision', slot_precision)
        self.log('slot_recall', slot_recall)
        self.log('slot_f1', slot_f1)
        self.log('unified_slot_precision', unified_slot_precision)
        self.log('unified_slot_recall', unified_slot_recall)
        self.log('unified_slot_f1', unified_slot_f1)
        self.log('unified_slot_joint_goal_accuracy',
                 unified_slot_joint_goal_accuracy)

        self.intent_classification_report.reset()
        self.slot_classification_report.reset()

        return {
            'val_loss': avg_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
            'unified_slot_precision': unified_slot_precision,
            'unified_slot_recall': unified_slot_recall,
            'unified_slot_f1': unified_slot_f1,
            'unified_slot_joint_goal_accuracy':
            unified_slot_joint_goal_accuracy,
        }

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Builds the training dataloader (dataset split 'train') and stores it in ``self._train_dl``."""
        train_loader = self._setup_dataloader_from_config(cfg=train_data_config, dataset_split='train')
        self._train_dl = train_loader

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Builds the validation dataloader (dataset split 'dev') and stores it in ``self._validation_dl``."""
        validation_loader = self._setup_dataloader_from_config(cfg=val_data_config, dataset_split='dev')
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Builds the test dataloader (dataset split 'test') and stores it in ``self._test_dl``."""
        test_loader = self._setup_dataloader_from_config(cfg=test_data_config, dataset_split='test')
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig, dataset_split: str):
        """
        Creates a DataLoader over a DialogueBERTDataset for the requested split.

        Args:
            cfg: dataloader section of the config (train_ds / validation_ds / test_ds); supplies
                batch_size, shuffle, num_workers, pin_memory and drop_last.
            dataset_split: name of the split, passed to the dataset as its first argument.
        Returns:
            A DataLoader that uses the dataset's own collate_fn.
        """
        processor = DialogueAssistantDataProcessor(self.data_dir, self.tokenizer)

        # self.cfg.dataset is the model.dataset cfg, which is diff from train_ds cfg etc
        dataset_cfg = self.cfg.dataset
        dataset = DialogueBERTDataset(dataset_split, processor, self.tokenizer, dataset_cfg)

        loader_options = dict(
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
        )
        return DataLoader(dataset=dataset, collate_fn=dataset.collate_fn, **loader_options)

    def _setup_infer_dataloader(self, queries: List[str], test_ds) -> 'torch.utils.data.DataLoader':
        """
        Setup function for an inference data loader.

        Args:
            queries: text
            test_ds: dataset configuration section; supplies batch_size, shuffle, num_workers,
                pin_memory and drop_last for the loader.
        Returns:
            A pytorch DataLoader.
        """
        infer_dataset = IntentSlotInferenceDataset(
            tokenizer=self.tokenizer,
            queries=queries,
            max_seq_length=-1,
            do_lower_case=False,
        )

        loader_options = dict(
            batch_size=test_ds.batch_size,
            shuffle=test_ds.shuffle,
            num_workers=test_ds.num_workers,
            pin_memory=test_ds.pin_memory,
            drop_last=test_ds.drop_last,
        )
        return torch.utils.data.DataLoader(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            **loader_options,
        )

    def update_data_dirs(self, data_dir: str, dialogues_example_dir: str):
        """
        Update data directories

        Args:
            data_dir: path to data directory
            dialogues_example_dir: path to preprocessed dialogues example directory, if not exists will be created.
        """
        if not os.path.exists(data_dir):
            raise ValueError(f"{data_dir} is not found")
        self.cfg.dataset.data_dir = data_dir
        self.cfg.dataset.dialogues_example_dir = dialogues_example_dir
        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
        logging.info(
            f'Setting model.dataset.dialogues_example_dir to {dialogues_example_dir}.'
        )

    def predict_from_examples(self, queries: List[str],
                              test_ds) -> List[List[str]]:
        """
        Get prediction for the queries (intent and slots)
        Args:
            queries: text sequences
            test_ds: Dataset configuration section.
        Returns:
            predicted_intents, predicted_slots: model intent and slot predictions
        """

        predicted_intents = []
        predicted_slots = []
        mode = self.training

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Switch model to evaluation mode
        self.eval()
        self.to(device)

        # Dataset.
        infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

        for batch in infer_datalayer:
            input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask = batch

            intent_logits, slot_logits = self.forward(
                input_ids=input_ids.to(device),
                token_type_ids=input_type_ids.to(device),
                attention_mask=input_mask.to(device),
            )

            # predict intents
            intent_preds = tensor2list(torch.argmax(intent_logits, axis=-1))
            predicted_intents += self.convert_intent_ids_to_intent_names(
                intent_preds)

            # predict slots
            slot_preds = torch.argmax(slot_logits, axis=-1)
            predicted_slots += self.mask_unused_subword_slots(
                slot_preds, subtokens_mask)

        # set mode back to its original value
        self.train(mode=mode)

        return predicted_intents, predicted_slots

    def convert_intent_ids_to_intent_names(self, intent_preds):
        # Retrieve intent and slot vocabularies from configuration.
        intent_labels = self.cfg.data_desc.intent_labels

        predicted_intents = []

        # convert numerical outputs to Intent and Slot labels from the dictionaries
        for intent_num in intent_preds:
            # if intent_num < len(intent_labels):
            predicted_intents.append(intent_labels[int(intent_num)])
            # else:
            #     # should not happen
            #     predicted_intents.append("Unknown Intent")
        return predicted_intents

    def mask_unused_subword_slots(self, slot_preds, subtokens_mask):
        # Retrieve intent and slot vocabularies from configuration.
        slot_labels = self.cfg.data_desc.slot_labels
        predicted_slots = []
        for slot_preds_query, mask_query in zip(slot_preds, subtokens_mask):
            query_slots = ''
            for slot, mask in zip(slot_preds_query, mask_query):
                if mask == 1:
                    # if slot < len(slot_labels):
                    query_slots += slot_labels[int(slot)] + ' '
                    # else:
                    #     query_slots += 'Unknown_slot '
            predicted_slots.append(query_slots.strip())
        return predicted_slots

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        Enumerates the pre-trained checkpoints that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models (a single entry here).
            NOTE(review): the return annotation names a single PretrainedModelInfo, but a list is returned.
        """
        assistant_info = PretrainedModelInfo(
            pretrained_model_name="Joint_Intent_Slot_Assistant",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
            description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
        )
        return [assistant_info]
Example #6
0
class IntentSlotClassificationModel(NLPModel):
    """
    Joint intent classification and slot filling model.

    A language model (``self.bert_model``) produces hidden states which a
    ``SequenceTokenClassifier`` head turns into intent logits and slot logits.
    The two cross-entropy losses are combined by an ``AggregatorLoss`` weighted
    with ``cfg.intent_loss_weight``.
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types; delegated to the underlying language model."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types; delegated to the classification head."""
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """ Initializes BERT Joint Intent and Slot model.

        Args:
            cfg: model configuration (data_dir, language_model, tokenizer, head,
                class_balancing, intent_loss_weight, optim, train_ds, validation_ds, ...)
            trainer: trainer instance forwarded to the base class
        """

        # These attributes are assigned before the base constructor runs.
        # NOTE(review): presumably the base constructor may already set up dataloaders
        # that need data_dir / data_desc / tokenizer - confirm before reordering.
        self.data_dir = cfg.data_dir
        self.max_seq_length = cfg.language_model.max_seq_length

        self.data_desc = IntentSlotDataDesc(
            data_dir=cfg.data_dir,
            modes=[cfg.train_ds.prefix, cfg.validation_ds.prefix])

        self._setup_tokenizer(cfg.tokenizer)
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # initialize Bert model

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        # The head emits raw logits (log_softmax=False) for both intents and slots.
        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=self.data_desc.num_intents,
            num_slots=self.data_desc.num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # define losses
        if cfg.class_balancing == 'weighted_loss':
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.intent_loss = CrossEntropyLoss(
                logits_ndim=2, weight=self.data_desc.intent_weights)
            self.slot_loss = CrossEntropyLoss(
                logits_ndim=3, weight=self.data_desc.slot_weights)
        else:
            self.intent_loss = CrossEntropyLoss(logits_ndim=2)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3)

        # The intent loss gets weight w, the slot loss gets 1 - w.
        self.total_loss = AggregatorLoss(
            num_inputs=2,
            weights=[cfg.intent_loss_weight, 1.0 - cfg.intent_loss_weight])

        # setup to track metrics
        self.intent_classification_report = ClassificationReport(
            num_classes=self.data_desc.num_intents,
            label_ids=self.data_desc.intents_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=self.data_desc.num_slots,
            label_ids=self.data_desc.slots_label_ids,
            dist_sync_on_step=True,
            mode='micro',
        )

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        Runs the language model and feeds its hidden states to the classifier head.

        Returns:
            tuple ``(intent_logits, slot_logits)``
        """
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        intent_logits, slot_logits = self.classifier(
            hidden_states=hidden_states)
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            dict with the combined ``loss`` and the current learning rate ``lr``
        """
        # forward pass
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(input_ids=input_ids,
                                          token_type_ids=input_type_ids,
                                          attention_mask=input_mask)

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits,
                                       labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits,
                                   labels=slot_labels,
                                   loss_mask=loss_mask)
        train_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)
        # Learning rate of the first parameter group, reported alongside the loss.
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': train_loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Returns:
            dict with ``val_loss`` and the tp/fn/fp counters of the intent and slot reports
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(input_ids=input_ids,
                                          token_type_ids=input_type_ids,
                                          attention_mask=input_mask)

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits,
                                       labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits,
                                   labels=slot_labels,
                                   loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        # calculate accuracy metrics for intents and slot reporting
        # intents
        preds = torch.argmax(intent_logits, axis=-1)
        self.intent_classification_report.update(preds, intent_labels)
        # slots: only positions selected by subtokens_mask contribute to the slot report
        subtokens_mask = subtokens_mask > 0.5
        preds = torch.argmax(slot_logits, axis=-1)[subtokens_mask]
        slot_labels = slot_labels[subtokens_mask]
        self.slot_classification_report.update(preds, slot_labels)

        return {
            'val_loss': val_loss,
            'intent_tp': self.intent_classification_report.tp,
            'intent_fn': self.intent_classification_report.fn,
            'intent_fp': self.intent_classification_report.fp,
            'slot_tp': self.slot_classification_report.tp,
            'slot_fn': self.slot_classification_report.fn,
            'slot_fp': self.slot_classification_report.fp,
        }

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        :return: dict with the mean validation loss and precision/recall/F1 for intents and slots.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report (separately for intents and slots)
        intent_precision, intent_recall, intent_f1, intent_report = self.intent_classification_report.compute(
        )
        logging.info(f'Intent report: {intent_report}')

        slot_precision, slot_recall, slot_f1, slot_report = self.slot_classification_report.compute(
        )
        logging.info(f'Slot report: {slot_report}')

        self.log('val_loss', avg_loss)
        self.log('intent_precision', intent_precision)
        self.log('intent_recall', intent_recall)
        self.log('intent_f1', intent_f1)
        self.log('slot_precision', slot_precision)
        self.log('slot_recall', slot_recall)
        self.log('slot_f1', slot_f1)

        return {
            'val_loss': avg_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
        }

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.

        Testing reuses the validation logic, so the returned keys are the validation ones.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.

        Delegates to ``validation_epoch_end``, so metrics are logged under the validation names.
        """
        return self.validation_epoch_end(outputs)

    def _setup_tokenizer(self, cfg: DictConfig):
        """Build the tokenizer described by the tokenizer config and store it on ``self.tokenizer``."""
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        self.tokenizer = tokenizer

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Create the training dataloader from ``train_data_config``."""
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Create the validation dataloader from ``val_data_config``."""
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Create the test dataloader from ``test_data_config``."""
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a ``DataLoader`` over ``<data_dir>/<prefix>.tsv`` and ``<data_dir>/<prefix>_slots.tsv``.

        Args:
            cfg: dataset config section (prefix, num_samples, batch_size, shuffle,
                num_workers, pin_memory, drop_last)

        Returns:
            DataLoader over an ``IntentSlotClassificationDataset``

        Raises:
            FileNotFoundError: if either of the two data files does not exist
        """
        input_file = f'{self.data_dir}/{cfg.prefix}.tsv'
        slot_file = f'{self.data_dir}/{cfg.prefix}_slots.tsv'

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            # Adjacent literals instead of a backslash continuation, so the source
            # indentation does not end up inside the message.
            raise FileNotFoundError(
                f'{input_file} or {slot_file} not found. Please refer to the documentation for the right format '
                'of Intents and Slots files.')

        dataset = IntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.data_desc.pad_label,
            ignore_extra_tokens=self._cfg.ignore_extra_tokens,
            ignore_start_end=self._cfg.ignore_start_end,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="Joint_Intent_Slot_Assistant",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
            description=
            "This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
        )
        result.append(model)
        return result
Example #7
0
class MultiLabelIntentSlotClassificationModel(IntentSlotClassificationModel):
    """
    Multi-label variant of the joint intent/slot model.

    Intents are scored independently with a sigmoid (trained with ``BCEWithLogitsLoss``),
    so a query may receive several intents; an intent is predicted when its probability
    reaches ``self.threshold``. Slots are still predicted with an argmax per token.

    Helpers such as ``_set_defaults_data_desc``, ``_save_label_ids``, ``register_artifact``
    and ``_setup_infer_dataloader`` are not defined in this class and are expected to be
    provided by a base class.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """ 
        Initializes BERT Joint Intent and Slot model.

        Args: 
            cfg: configuration object
            trainer: trainer for Pytorch Lightning
        """
        self.max_seq_length = cfg.language_model.max_seq_length

        # Optimal Threshold: probability cut-off for predicting an intent, and the best
        # F1 score seen so far by optimize_threshold().
        self.threshold = 0.5
        self.max_f1 = 0

        # Check the presence of data_dir.
        if not cfg.data_dir or not os.path.exists(cfg.data_dir):
            # Set default values of data_desc.
            self._set_defaults_data_desc(cfg)
        else:
            self.data_dir = cfg.data_dir
            # Update configuration of data_desc.
            self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds, cfg.validation_ds)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # Initialize Classifier.
        self._reconfigure_classifier()

    def _set_data_desc_to_cfg(
        self, cfg: DictConfig, data_dir: str, train_ds: DictConfig, validation_ds: DictConfig
    ) -> None:
        """ 
        Creates MultiLabelIntentSlotDataDesc and copies generated values to Configuration object's data descriptor. 

        Also writes the intent and slot label dictionaries to files inside ``data_dir``
        and registers them as model artifacts.

        Args: 
            cfg: configuration object
            data_dir: data directory 
            train_ds: training dataset config section (its ``prefix`` is used)
            validation_ds: validation dataset config section (its ``prefix`` is used)

        Returns:
            None
        """
        # Save data from data desc to config - so it can be reused later, e.g. in inference.
        data_desc = MultiLabelIntentSlotDataDesc(data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix])
        # Struct mode is switched off while keys are added and restored at the end.
        OmegaConf.set_struct(cfg, False)
        if not hasattr(cfg, "data_desc") or cfg.data_desc is None:
            cfg.data_desc = {}
        # Intents.
        cfg.data_desc.intent_labels = list(data_desc.intents_label_ids.keys())
        cfg.data_desc.intent_label_ids = data_desc.intents_label_ids
        cfg.data_desc.intent_weights = data_desc.intent_weights
        # Slots.
        cfg.data_desc.slot_labels = list(data_desc.slots_label_ids.keys())
        cfg.data_desc.slot_label_ids = data_desc.slots_label_ids
        cfg.data_desc.slot_weights = data_desc.slot_weights

        cfg.data_desc.pad_label = data_desc.pad_label

        # for older(pre - 1.0.0.b3) configs compatibility
        if not hasattr(cfg, "class_labels") or cfg.class_labels is None:
            cfg.class_labels = OmegaConf.create(
                {"intent_labels_file": "intent_labels.csv", "slot_labels_file": "slot_labels.csv"}
            )

        slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file)
        intent_labels_file = os.path.join(data_dir, cfg.class_labels.intent_labels_file)
        self._save_label_ids(data_desc.slots_label_ids, slot_labels_file)
        self._save_label_ids(data_desc.intents_label_ids, intent_labels_file)

        self.register_artifact("class_labels.intent_labels_file", intent_labels_file)
        self.register_artifact("class_labels.slot_labels_file", slot_labels_file)
        OmegaConf.set_struct(cfg, True)

    def _reconfigure_classifier(self) -> None:
        """ Method reconfigures the classifier depending on the settings of model cfg.data_desc """

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=len(self.cfg.data_desc.intent_labels),
            num_slots=len(self.cfg.data_desc.slot_labels),
            dropout=self.cfg.head.fc_dropout,
            num_layers=self.cfg.head.num_output_layers,
            log_softmax=False,
        )

        # define losses: intents are multi-label (binary cross-entropy per intent),
        # slots remain single-label per token (cross-entropy).
        if self.cfg.class_balancing == "weighted_loss":
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.intent_loss = BCEWithLogitsLoss(logits_ndim=2, pos_weight=self.cfg.data_desc.intent_weights)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3, weight=self.cfg.data_desc.slot_weights)
        else:
            self.intent_loss = BCEWithLogitsLoss(logits_ndim=2)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3)

        self.total_loss = AggregatorLoss(
            num_inputs=2, weights=[self.cfg.intent_loss_weight, 1.0 - self.cfg.intent_loss_weight],
        )

        # setup to track metrics
        self.intent_classification_report = MultiLabelClassificationReport(
            num_classes=len(self.cfg.data_desc.intent_labels),
            label_ids=self.cfg.data_desc.intent_label_ids,
            dist_sync_on_step=True,
            mode="micro",
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=len(self.cfg.data_desc.slot_labels),
            label_ids=self.cfg.data_desc.slot_label_ids,
            dist_sync_on_step=True,
            mode="micro",
        )

    def validation_step(self, batch, batch_idx) -> dict:
        """
        Validation Loop. Pytorch Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Args:
            batch: batches of data from DataLoader
            batch_idx: batch idx from DataLoader

        Returns: 
            dict with ``val_loss`` and the tp/fn/fp counters of the intent and slot reports
        """
        (input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels,) = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask,
        )

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits, labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        # Rounding the sigmoid output gives 0/1 predictions with a fixed 0.5 cut-off.
        intent_preds = torch.round(torch.sigmoid(intent_logits))

        self.intent_classification_report.update(intent_preds, intent_labels)
        # slots: only positions selected by subtokens_mask contribute to the slot report
        subtokens_mask = subtokens_mask > 0.5
        preds = torch.argmax(slot_logits, axis=-1)[subtokens_mask]
        slot_labels = slot_labels[subtokens_mask]
        self.slot_classification_report.update(preds, slot_labels)

        return {
            "val_loss": val_loss,
            "intent_tp": self.intent_classification_report.tp,
            "intent_fn": self.intent_classification_report.fn,
            "intent_fp": self.intent_classification_report.fp,
            "slot_tp": self.slot_classification_report.tp,
            "slot_fn": self.slot_classification_report.fn,
            "slot_fp": self.slot_classification_report.fp,
        }

    def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader:
        """
        Creates the DataLoader from the configuration object

        Args:
            cfg: configuration object
        
        Returns:
            DataLoader for model's data

        Raises:
            FileNotFoundError: if the intent dictionary, the input file or the slot file is missing
        """

        input_file = f"{self.data_dir}/{cfg.prefix}.tsv"
        slot_file = f"{self.data_dir}/{cfg.prefix}_slots.tsv"
        intent_dict_file = self.data_dir + "/dict.intents.csv"

        # The number of intents is the number of non-blank lines of the intent dictionary.
        # The context manager closes the file handle deterministically.
        with open(intent_dict_file, "r") as f:
            num_intents = sum(1 for line in f if line.strip())

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            # Adjacent literals instead of a backslash continuation, so the source
            # indentation does not end up inside the message.
            raise FileNotFoundError(
                f"{input_file} or {slot_file} not found. Please refer to the documentation for the right format "
                "of Intents and Slots files."
            )

        dataset = MultiLabelIntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            num_intents=num_intents,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.cfg.data_desc.pad_label,
            ignore_extra_tokens=self.cfg.ignore_extra_tokens,
            ignore_start_end=self.cfg.ignore_start_end,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    def prediction_probabilities(self, queries: List[str], test_ds: DictConfig) -> npt.NDArray:
        """
        Get intent prediction probabilities for the queries

        Args:
            queries: text sequences
            test_ds: Dataset configuration section.

        Returns:
            numpy array of intent probabilities, one row per query; an empty array with
            one column per intent label when the inference dataloader yields no batches
        """

        probabilities = []

        mode = self.training
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            # Dataset.
            infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, _, _ = batch

                    intent_logits, _ = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # predict intents for these examples
                    probabilities.append(torch.sigmoid(intent_logits).detach().cpu().numpy())

            if probabilities:
                probabilities = np.concatenate(probabilities)
            else:
                # np.concatenate rejects an empty list; return a well-formed empty result instead.
                probabilities = np.empty((0, len(self.cfg.data_desc.intent_labels)))

        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return probabilities

    def optimize_threshold(self, test_ds: DictConfig, file_name: str) -> None:
        """
        Set the optimal threshold of the model from performance on validation set. This threshold is used to round the 
        intent probabilities to 0 or 1. 

        Thresholds from 0.5 to 0.95 (step 0.01) are scored with micro-averaged precision,
        recall and F1; ``self.threshold`` is updated only if the best F1 beats ``self.max_f1``.

        Args:
            test_ds: Dataset configuration section used to build the inference dataloader
            file_name: name (without ``.tsv``) of the file in ``data_dir`` holding the validation set;
                each line after the header is ``<sentence>\\t<comma-separated intent ids>``

        Returns:
            None
        """

        input_file = f"{self.data_dir}/{file_name}.tsv"

        with open(input_file, "r") as f:
            input_lines = f.readlines()[1:]  # Skipping headers at index 0

        num_intents = len(self.cfg.data_desc.intent_labels)
        metrics_labels, sentences = [], []

        for input_line in input_lines:
            stripped = input_line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at the end of the file).
                continue
            fields = stripped.split("\t")
            sentences.append(fields[0])
            # Multi-hot encode the intent ids listed in the second column.
            active = set(map(int, fields[1].split(",")))
            metrics_labels.append([1 if label in active else 0 for label in range(num_intents)])

        # Retrieve class probabilities for each sentence
        intent_probabilities = self.prediction_probabilities(sentences, test_ds)

        metrics_dict = {}
        # Find optimal probability rounding threshold for intents
        for candidate in np.arange(0.5, 0.96, 0.01):
            predictions = (intent_probabilities >= candidate).tolist()
            precision = precision_score(metrics_labels, predictions, average='micro')
            recall = recall_score(metrics_labels, predictions, average='micro')
            f1 = f1_score(metrics_labels, predictions, average='micro')
            # Plain float keys keep self.threshold a native Python float.
            metrics_dict[float(candidate)] = [precision, recall, f1]

        max_precision = max(metrics_dict, key=lambda x: metrics_dict[x][0])
        max_recall = max(metrics_dict, key=lambda x: metrics_dict[x][1])
        max_f1_score = max(metrics_dict, key=lambda x: metrics_dict[x][2])

        logging.info(
            f'Best Threshold for F1-Score: {max_f1_score}, [Precision, Recall, F1-Score]: {metrics_dict[max_f1_score]}'
        )
        logging.info(
            f'Best Threshold for Precision: {max_precision}, [Precision, Recall, F1-Score]: {metrics_dict[max_precision]}'
        )
        logging.info(
            f'Best Threshold for Recall: {max_recall}, [Precision, Recall, F1-Score]: {metrics_dict[max_recall]}'
        )

        if metrics_dict[max_f1_score][2] > self.max_f1:
            self.max_f1 = metrics_dict[max_f1_score][2]

            logging.info(f'Setting Threshold to: {max_f1_score}')

            self.threshold = max_f1_score

    def predict_from_examples(
        self, queries: List[str], test_ds: DictConfig, threshold: float = None
    ) -> Tuple[List[List[Tuple[str, float]]], List[str], List[List[int]]]:
        """
        Get prediction for the queries (intent and slots)


        Args:
            queries: text sequences
            test_ds: Dataset configuration section.
            threshold: Threshold applied to the intent probabilities; defaults to ``self.threshold``
        
        Returns:
            predicted_intents: model intent predictions with their probabilities
                Example:  [[('flight', 0.84)], [('airfare', 0.54), 
                            ('flight', 0.73), ('meal', 0.24)]]
            predicted_slots: model slot predictions
                Example:  ['O B-depart_date.month_name B-depart_date.day_number',
                           'O O B-flight_stop O O O']

            predicted_vector: model intent predictions for each individual query. Binary values within each list 
                indicate whether a class is predicted for the given query (1 for True, 0 for False)
                Example: [[1,0,0,0,0,0], [0,0,1,0,0,0]]
        """
        predicted_intents = []

        if threshold is None:
            threshold = self.threshold
        logging.info(f'Using threshold = {threshold}')

        predicted_slots = []
        predicted_vector = []

        mode = self.training
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Retrieve intent and slot vocabularies from configuration.
            intent_labels = self.cfg.data_desc.intent_labels
            slot_labels = self.cfg.data_desc.slot_labels

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            # Dataset.
            infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, _, subtokens_mask = batch

                    intent_logits, slot_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # predict intents and slots for these examples
                    # intents
                    intent_preds = tensor2list(torch.sigmoid(intent_logits))
                    # convert numerical outputs to Intent labels from the dictionaries
                    for intents in intent_preds:
                        intent_lst = []
                        temp_list = []
                        for intent_num, probability in enumerate(intents):
                            if probability >= threshold:
                                intent_lst.append((intent_labels[int(intent_num)], round(probability, 2)))
                                temp_list.append(1)
                            else:
                                temp_list.append(0)

                        predicted_vector.append(temp_list)
                        predicted_intents.append(intent_lst)

                    # slots
                    slot_preds = torch.argmax(slot_logits, axis=-1)

                    for slot_preds_query, mask_query in zip(slot_preds, subtokens_mask):
                        query_slots = []
                        for slot, mask in zip(slot_preds_query, mask_query):
                            # Only positions whose subtoken mask equals 1 are reported.
                            if mask == 1:
                                if slot < len(slot_labels):
                                    query_slots.append(slot_labels[int(slot)])
                                else:
                                    query_slots.append("Unknown_slot")
                        predicted_slots.append(" ".join(query_slots).strip())

        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return predicted_intents, predicted_slots, predicted_vector

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        To be added

        Returns:
            an empty list; no pre-trained models are listed for this class yet
        """
        result = []
        return result