Пример #1
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up the text classifier: classification head, loss and metrics.

        Args:
            cfg: model configuration. The ``dataset`` and ``classifier_head``
                sections are read here; ``class_labels`` is optional.
            trainer: optional Trainer forwarded to the parent constructor.
        """
        # Both attributes are assigned ahead of the parent constructor.
        self.dataset_cfg = cfg.dataset
        self.class_weights = None

        super().__init__(cfg=cfg, trainer=trainer)

        encoder_width = self.hidden_size
        class_count = cfg.dataset.num_classes
        head_cfg = cfg.classifier_head

        self.classifier = SequenceClassifier(
            hidden_size=encoder_width,
            num_classes=class_count,
            num_layers=head_cfg.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=head_cfg.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        self.create_loss_module()

        # Metric tracker used for evaluation.
        self.classification_report = ClassificationReport(
            num_classes=cfg.dataset.num_classes,
            mode='micro',
            dist_sync_on_step=True,
        )

        # Register the labels file as an artifact so it is stored in the
        # '.nemo' file later; skipped when it is not configured or is empty.
        labels_file = None
        if 'class_labels' in cfg and 'class_labels_file' in cfg.class_labels:
            labels_file = cfg.class_labels.class_labels_file
        if labels_file:
            self.register_artifact('class_labels.class_labels_file', labels_file)
Пример #2
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up a Hugging Face token-classification model with loss and metrics.

        Args:
            cfg: model configuration. ``tokenizer`` and ``transformer`` are
                required; ``mode``, ``max_sequence_len`` and ``lang`` are
                optional and read with defaults.
            trainer: optional Trainer forwarded to the parent constructor.
        """
        # The tokenizer is created before the parent constructor runs.
        self._tokenizer = AutoTokenizer.from_pretrained(
            cfg.tokenizer, add_prefix_space=True
        )
        super().__init__(cfg=cfg, trainer=trainer)

        self.num_labels = len(constants.ALL_TAG_LABELS)
        self.mode = cfg.get('mode', 'joint')

        # Token classification backbone sized to the tag label set.
        self.model = AutoModelForTokenClassification.from_pretrained(
            cfg.transformer,
            num_labels=self.num_labels,
        )
        self.transformer_name = cfg.transformer

        # Fall back to the tokenizer's own limit when none is configured.
        tokenizer_limit = self._tokenizer.model_max_length
        self.max_sequence_len = cfg.get('max_sequence_len', tokenizer_limit)

        # Loss: positions labelled with the pad id are ignored.
        pad_label_id = constants.LABEL_PAD_TOKEN_ID
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=pad_label_id)

        # Metric tracker used for evaluation.
        self.classification_report = ClassificationReport(
            self.num_labels,
            constants.LABEL_IDS,
            mode='micro',
            dist_sync_on_step=True,
        )

        # Optional language setting; None when absent from the config.
        self.lang = cfg.get('lang', None)
Пример #3
0
    def __init__(
        self,
        cfg: DictConfig,
        trainer: Trainer = None,
    ):
        """Set up tokenizer, language model, label mapping and metrics.

        Args:
            cfg: model configuration. Reads ``tokenizer``, ``library``,
                ``language_model`` and ``eval_mode``; for the ``megatron``
                library also ``new_prompt_tags``, ``new_prompt_init_methods``
                and ``new_prompt_init_text``.
            trainer: optional Trainer, forwarded to the parent constructor and
                to ``MegatronGPTModel.restore_from``.

        Raises:
            ValueError: if a configured prompt init method is neither
                ``"text"`` nor ``"random"``.
        """

        self.cfg = cfg
        self.data_prepared = False

        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        if self.cfg.library == "huggingface":
            self.language_model = AutoModelWithLMHead.from_pretrained(
                cfg.language_model.pretrained_model_name)
            # Match the embedding matrix to the tokenizer's vocabulary size.
            self.language_model.resize_token_embeddings(
                len(self.tokenizer.tokenizer))
        elif self.cfg.library == "megatron":
            self.language_model = MegatronGPTModel.restore_from(
                cfg.language_model.lm_checkpoint, trainer=trainer)
            # Use one consistent check for the prompt table (previously a mix
            # of `in dir(...)` and `hasattr(...)`).
            has_prompt_table = hasattr(self.language_model, 'prompt_table')
            # 1 corresponds to intent slot; 0 corresponds to squad
            self.prompt_tags = [1, 0] if has_prompt_table else []
            if has_prompt_table:
                self.language_model.prompt_tuning_param_freeze_and_optimizer_setup()

            # Init all new prompts
            for idx, tag in enumerate(cfg.new_prompt_tags):
                self.prompt_tags.append(tag)
                init_method = cfg.new_prompt_init_methods[idx]
                if init_method == "text":
                    init_text = cfg.new_prompt_init_text[idx]
                    self.language_model.init_prompt_from_text(tag, init_text)
                elif init_method == 'random':
                    self.language_model.init_prompt_from_random(tag)
                else:
                    raise ValueError(
                        f'\n Soft prompt init method {init_method} is not recognized, please use text or random'
                    )

        # Collect every label seen by the train/validation/test datasets.
        label_sources = (
            self._train_dl.dataset.all_possible_labels,
            self._validation_dl.dataset.all_possible_labels,
            self._test_dl.dataset.all_possible_labels,
        )
        # Sort before assigning ids: iterating a set of strings depends on
        # per-process hash randomization, which would give each process (and
        # each run) a different label -> id mapping while the metric below is
        # synchronized across processes. `key=str` keeps mixed-type labels
        # sortable.
        all_labels = sorted(set().union(*label_sources), key=str)

        # Unknown labels resolve to 0 through the defaultdict default.
        self.label_to_ids = collections.defaultdict(int)
        self.label_to_ids.update(
            (label, label_id) for label_id, label in enumerate(all_labels))

        self.all_existing_labels = set(self.label_to_ids.keys())

        self.token_to_words = {}
        self.classification_report = ClassificationReport(
            num_classes=len(self.label_to_ids) + 1,
            mode='micro',
            label_ids=self.label_to_ids,
            dist_sync_on_step=True)
        self.eval_mode = cfg.eval_mode
        # NOTE(review): cfg is assigned a second time here, as in the original;
        # confirm whether the parent constructor replaces it before removing.
        self.cfg = cfg
Пример #4
0
    def test_classification_report(self):
        """Check ClassificationReport precision/recall/F1 against scikit-learn.

        The comparison is repeated for the 'macro', 'micro' and 'weighted'
        averaging modes on one small fixed prediction/label pair.
        """
        report = ClassificationReport(
            num_classes=self.num_classes,
            label_ids=self.label_ids,
        )

        preds = torch.Tensor([0, 1, 1, 1, 2, 2, 0])
        labels = torch.Tensor([1, 0, 0, 1, 2, 1, 0])

        tp, fp, fn = report(preds, labels)

        def scaled_scalar(score):
            # Scale the scikit-learn score by 100, round it and return it as a
            # scalar tensor so it can be compared with the rounded NeMo value.
            scaled = round(score * 100)
            return torch.Tensor([scaled])[0]

        for mode in ['macro', 'micro', 'weighted']:
            precision, recall, f1 = report.get_precision_recall_f1(
                tp, fn, fp, mode
            )
            sk_precision, sk_recall, sk_f1, _ = precision_recall_fscore_support(
                labels, preds, average=mode
            )

            self.assertEqual(
                torch.round(precision),
                scaled_scalar(sk_precision),
                f'wrong precision for {mode}',
            )
            self.assertEqual(
                torch.round(recall),
                scaled_scalar(sk_recall),
                f'wrong recall for {mode}',
            )
            self.assertEqual(
                torch.round(f1),
                scaled_scalar(sk_f1),
                f'wrong f1 for {mode}',
            )
Пример #5
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up the token classification model: head, loss and metrics.

        Args:
            cfg: model configuration. ``label_ids`` may be a path to a mapping
                file, in which case it is replaced in ``cfg`` by the loaded
                name-to-id mapping.
            trainer: optional Trainer forwarded to the parent constructor.

        Raises:
            ValueError: if ``cfg.label_ids`` is a string that is not an
                existing path.
        """
        # Resolve a label mapping file, when one is given, into the config.
        if isinstance(cfg.label_ids, str):
            if not os.path.exists(cfg.label_ids):
                raise ValueError(f'{cfg.label_ids} not found.')
            logging.info(f'Reusing label_ids file found at {cfg.label_ids}.')
            mapping = get_labels_to_labels_id_mapping(cfg.label_ids)
            # Store the name-to-id mapping in place of the file path.
            cfg.label_ids = OmegaConf.create(mapping)

        self.class_weights = None
        super().__init__(cfg=cfg, trainer=trainer)

        head_cfg = self._cfg.head
        self.classifier = TokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=head_cfg.num_fc_layers,
            activation=head_cfg.activation,
            log_softmax=False,
            dropout=head_cfg.fc_dropout,
            use_transformer_init=head_cfg.use_transformer_init,
        )

        balancing = self._cfg.dataset.class_balancing
        self.loss = self.setup_loss(class_balancing=balancing)

        # Metric tracker used for evaluation.
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True,
        )
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up the BERT joint intent and slot model.

        Args:
            cfg: model configuration. Reads ``data_dir``, ``language_model``,
                ``train_ds``, ``validation_ds``, ``tokenizer``, ``head``,
                ``class_balancing``, ``intent_loss_weight`` and ``optim``.
            trainer: optional Trainer forwarded to the parent constructor.
        """
        self.data_dir = cfg.data_dir
        self.max_seq_length = cfg.language_model.max_seq_length

        # Data description built over the train and validation prefixes.
        prefixes = [cfg.train_ds.prefix, cfg.validation_ds.prefix]
        self.data_desc = IntentSlotDataDesc(data_dir=cfg.data_dir, modes=prefixes)

        self._setup_tokenizer(cfg.tokenizer)
        # Parent constructor runs after the tokenizer and data description exist.
        super().__init__(cfg=cfg, trainer=trainer)

        # Language model encoder.
        lm_cfg = cfg.language_model
        raw_lm_config = lm_cfg.config
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=lm_cfg.config_file,
            config_dict=OmegaConf.to_container(raw_lm_config) if raw_lm_config else None,
            checkpoint_file=lm_cfg.lm_checkpoint,
        )

        # Joint head with one output per intent and per slot.
        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=self.data_desc.num_intents,
            num_slots=self.data_desc.num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # Losses. With 'weighted_loss' the class weights come from the data
        # description; more epochs may be needed for convergence in that case.
        intent_loss_args = {'logits_ndim': 2}
        slot_loss_args = {'logits_ndim': 3}
        if cfg.class_balancing == 'weighted_loss':
            intent_loss_args['weight'] = self.data_desc.intent_weights
            slot_loss_args['weight'] = self.data_desc.slot_weights
        self.intent_loss = CrossEntropyLoss(**intent_loss_args)
        self.slot_loss = CrossEntropyLoss(**slot_loss_args)

        # The two losses are combined with complementary weights.
        intent_share = cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(
            num_inputs=2,
            weights=[intent_share, 1.0 - intent_share],
        )

        # Metric trackers, one for intents and one for slots.
        self.intent_classification_report = ClassificationReport(
            self.data_desc.num_intents,
            self.data_desc.intents_label_ids,
        )
        self.slot_classification_report = ClassificationReport(
            self.data_desc.num_slots,
            self.data_desc.slots_label_ids,
        )

        # Optimizer setup needs to happen after all model weights are ready.
        self.setup_optimization(cfg.optim)
Пример #7
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up the BERT punctuation and capitalization model.

        Args:
            cfg: model configuration. Reads ``tokenizer``, ``language_model``,
                ``punct_head`` and ``capit_head``; label id mappings are taken
                from ``self._cfg`` after the parent constructor has run.
            trainer: optional Trainer forwarded to the parent constructor.
        """
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        lm_cfg = cfg.language_model
        # The config and vocab files are registered as artifacts; the values
        # returned by register_artifact are what get_lm_model receives.
        config_file = self.register_artifact(
            'language_model.config_file', lm_cfg.config_file
        )
        config_dict = OmegaConf.to_container(lm_cfg.config) if lm_cfg.config else None
        vocab_file = self.register_artifact(
            'tokenizer.vocab_file', cfg.tokenizer.vocab_file
        )
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=config_file,
            config_dict=config_dict,
            checkpoint_file=lm_cfg.lm_checkpoint,
            vocab_file=vocab_file,
        )

        encoder_width = self.bert_model.config.hidden_size
        punct_cfg = cfg.punct_head
        capit_cfg = cfg.capit_head

        # Punctuation head.
        self.punct_classifier = TokenClassifier(
            hidden_size=encoder_width,
            num_classes=len(self._cfg.punct_label_ids),
            activation=punct_cfg.activation,
            log_softmax=False,
            dropout=punct_cfg.fc_dropout,
            num_layers=punct_cfg.punct_num_fc_layers,
            use_transformer_init=punct_cfg.use_transformer_init,
        )

        # Capitalization head.
        self.capit_classifier = TokenClassifier(
            hidden_size=encoder_width,
            num_classes=len(self._cfg.capit_label_ids),
            activation=capit_cfg.activation,
            log_softmax=False,
            dropout=capit_cfg.fc_dropout,
            num_layers=capit_cfg.capit_num_fc_layers,
            use_transformer_init=capit_cfg.use_transformer_init,
        )

        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # Metric trackers, one per task.
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
Пример #8
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up the BERT text classifier: encoder, head, loss and metrics.

        Args:
            cfg: model configuration. Reads ``dataset``, ``tokenizer``,
                ``language_model``, ``classifier_head`` and ``train_ds``;
                ``class_labels`` is optional.
            trainer: optional Trainer forwarded to the parent constructor.
        """
        # Dataset settings shared by the datasets and data loaders.
        self.dataset_cfg = cfg.dataset
        # The tokenizer must exist before super().__init__(), because the
        # data loaders and datasets need it to process the data.
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        lm_cfg = cfg.language_model
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=lm_cfg.config_file,
            config_dict=lm_cfg.config,
            checkpoint_file=lm_cfg.lm_checkpoint,
        )

        head_cfg = cfg.classifier_head
        self.classifier = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=head_cfg.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=head_cfg.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        # Class weights are computed only when weighted loss is requested and
        # a training file is configured.
        class_weights = None
        wants_weighted_loss = cfg.dataset.class_balancing == 'weighted_loss'
        if wants_weighted_loss and cfg.train_ds.file_path:
            class_weights = calc_class_weights(
                cfg.train_ds.file_path, cfg.dataset.num_classes
            )
        elif wants_weighted_loss:
            logging.info(
                'Class_balancing feature is enabled but no train file is given. Calculating the class weights is skipped.'
            )

        # More epochs may be needed for convergence when the weighted loss is used.
        self.loss = (
            CrossEntropyLoss(weight=class_weights) if class_weights else CrossEntropyLoss()
        )

        # Metric tracker used for evaluation.
        self.classification_report = ClassificationReport(
            num_classes=cfg.dataset.num_classes,
            mode='micro',
            dist_sync_on_step=True,
        )

        # Register the labels file as an artifact so it is stored in the
        # '.nemo' file later; skipped when it is not configured or is empty.
        labels_file = None
        if 'class_labels' in cfg and 'class_labels_file' in cfg.class_labels:
            labels_file = cfg.class_labels.class_labels_file
        if labels_file:
            self.register_artifact('class_labels', labels_file)
Пример #9
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Set up the BERT text classifier: encoder, head, loss and metrics.

        Args:
            cfg: model configuration. Reads ``dataset``, ``tokenizer``,
                ``language_model`` (``nemo_file`` is optional) and
                ``classifier_head``; ``class_labels`` is optional.
            trainer: optional Trainer, forwarded to the parent constructor and
                to ``get_lm_model``.
        """
        # Dataset settings shared by the datasets and data loaders.
        self.dataset_cfg = cfg.dataset
        # The tokenizer must exist before super().__init__(), because the
        # data loaders and datasets need it to process the data.
        self.setup_tokenizer(cfg.tokenizer)

        self.class_weights = None

        super().__init__(cfg=cfg, trainer=trainer)

        lm_cfg = cfg.language_model
        # Files are registered as artifacts in the same order as before:
        # config file, nemo file, vocab file.
        config_file = self.register_artifact(
            'language_model.config_file', lm_cfg.config_file
        )
        nemo_file = self.register_artifact(
            'language_model.nemo_file', lm_cfg.get('nemo_file', None)
        )
        vocab_file = self.register_artifact(
            'tokenizer.vocab_file', cfg.tokenizer.vocab_file
        )
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=config_file,
            config_dict=lm_cfg.config,
            checkpoint_file=lm_cfg.lm_checkpoint,
            nemo_file=nemo_file,
            vocab_file=vocab_file,
            trainer=trainer,
        )

        # The hidden size is read from `.cfg` when a nemo_file is configured
        # and from `.config` otherwise.
        uses_nemo_file = cfg.language_model.get('nemo_file', None) is not None
        hidden_size = (
            self.bert_model.cfg.hidden_size
            if uses_nemo_file
            else self.bert_model.config.hidden_size
        )

        head_cfg = cfg.classifier_head
        self.classifier = SequenceClassifier(
            hidden_size=hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=head_cfg.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=head_cfg.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        self.create_loss_module()

        # Metric tracker used for evaluation.
        self.classification_report = ClassificationReport(
            num_classes=cfg.dataset.num_classes,
            mode='micro',
            dist_sync_on_step=True,
        )

        # Register the labels file as an artifact so it is stored in the
        # '.nemo' file later; skipped when it is not configured or is empty.
        labels_file = None
        if 'class_labels' in cfg and 'class_labels_file' in cfg.class_labels:
            labels_file = cfg.class_labels.class_labels_file
        if labels_file:
            self.register_artifact('class_labels.class_labels_file', labels_file)
 def setup(self, stage):
     """Create the classification metric tracker.

     Args:
         stage: value supplied by the caller; not used in this method.
     """
     # The metric is built here rather than in the constructor because
     # data_parallel_group is only initialized when fit or test is called.
     app_state = AppState()
     report_options = {
         'num_classes': len(self.classes),
         'label_ids': self.label_ids,
         'mode': 'micro',
         'dist_sync_on_step': True,
         'process_group': app_state.data_parallel_group,
     }
     self.classification_report = ClassificationReport(**report_options)
Пример #11
0
    def _reconfigure_classifier(self) -> None:
        """Rebuild the classifier head, losses and metrics from ``self.cfg``.

        The intent and slot label lists, class weights and label id mappings
        are all taken from ``self.cfg.data_desc``.
        """
        encoder_width = self.bert_model.config.hidden_size
        intent_count = len(self.cfg.data_desc.intent_labels)
        slot_count = len(self.cfg.data_desc.slot_labels)

        self.classifier = SequenceTokenClassifier(
            hidden_size=encoder_width,
            num_intents=intent_count,
            num_slots=slot_count,
            dropout=self.cfg.head.fc_dropout,
            num_layers=self.cfg.head.num_output_layers,
            log_softmax=False,
        )

        # Losses. With "weighted_loss" the weights come from the data
        # description; more epochs may be needed for convergence in that case.
        intent_loss_args = {'logits_ndim': 2}
        slot_loss_args = {'logits_ndim': 3}
        if self.cfg.class_balancing == "weighted_loss":
            intent_loss_args['pos_weight'] = self.cfg.data_desc.intent_weights
            slot_loss_args['weight'] = self.cfg.data_desc.slot_weights
        self.intent_loss = BCEWithLogitsLoss(**intent_loss_args)
        self.slot_loss = CrossEntropyLoss(**slot_loss_args)

        # The two losses are combined with complementary weights.
        intent_share = self.cfg.intent_loss_weight
        slot_share = 1.0 - self.cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(
            num_inputs=2,
            weights=[intent_share, slot_share],
        )

        # Metric trackers: a multi-label report for intents, a regular one for slots.
        self.intent_classification_report = MultiLabelClassificationReport(
            num_classes=len(self.cfg.data_desc.intent_labels),
            label_ids=self.cfg.data_desc.intent_label_ids,
            dist_sync_on_step=True,
            mode="micro",
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=len(self.cfg.data_desc.slot_labels),
            label_ids=self.cfg.data_desc.slot_label_ids,
            dist_sync_on_step=True,
            mode="micro",
        )
Пример #12
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the Token Classification Model.

        Sets up the tokenizer, data directory and loss, calls the parent
        constructor, then builds the language model, the token classifier
        head and the classification metric.

        Args:
            cfg: model configuration. Reads ``tokenizer``, ``dataset``
                (``data_dir``, ``class_balancing``), ``language_model``,
                ``head`` and ``label_ids``.
            trainer: optional Trainer forwarded to the parent constructor.
        """

        self._setup_tokenizer(cfg.tokenizer)

        # The statements below run before the parent constructor.
        # NOTE(review): presumably update_data_dir() fills in data_desc and/or
        # label_ids, which are read further down; confirm against its definition.
        self._cfg = cfg
        self.data_desc = None
        self.update_data_dir(cfg.dataset.data_dir)
        # NOTE(review): the return value is discarded here and setup_loss() is
        # called again below with the same argument; confirm whether this
        # first call is needed for its side effects.
        self.setup_loss(class_balancing=self._cfg.dataset.class_balancing)

        super().__init__(cfg=cfg, trainer=trainer)
        # Language model encoder; the config section is converted to a plain
        # container only when it is set.
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        # Token-level head with one output class per entry in label_ids.
        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=self._cfg.head.log_softmax,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        # The loss module kept on the model is the one returned by this call.
        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)
        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)
Пример #13
0
class TextClassificationModel(NLPModel, Exportable):
    """Text classification model: a language model encoder plus a SequenceClassifier head."""

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types, taken from the underlying ``bert_model``."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types, taken from the classifier head."""
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the BERTTextClassifier model.

        Args:
            cfg: model configuration. Reads ``dataset``, ``tokenizer``,
                ``language_model`` and ``classifier_head``; ``class_labels``
                is optional.
            trainer: optional Trainer forwarded to the parent constructor.
        """

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        # tokenizer needs to get initialized before the super.__init__()
        # as dataloaders and datasets need it to process the data
        self.setup_tokenizer(cfg.tokenizer)

        self.class_weights = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=cfg.language_model.config,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        self.classifier = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=cfg.classifier_head.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=cfg.classifier_head.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        # No-op when setup_training_data() already created the loss.
        self.create_loss_module()

        # setup to track metrics
        self.classification_report = ClassificationReport(
            num_classes=cfg.dataset.num_classes,
            mode='micro',
            dist_sync_on_step=True)

        # register the file containing the labels into the artifacts to get stored in the '.nemo' file later
        if 'class_labels' in cfg and 'class_labels_file' in cfg.class_labels and cfg.class_labels.class_labels_file:
            self.register_artifact('class_labels.class_labels_file',
                                   cfg.class_labels.class_labels_file)

    def create_loss_module(self):
        """Create ``self.loss`` if it does not exist yet.

        Uses ``self.class_weights`` as the loss weights when they are set and
        truthy. An already existing ``self.loss`` is left untouched.
        """
        # create the loss module if it is not yet created by the training data loader
        if not hasattr(self, 'loss'):
            if hasattr(self, 'class_weights') and self.class_weights:
                # You may need to increase the number of epochs for convergence when using weighted_loss
                self.loss = CrossEntropyLoss(weight=self.class_weights)
            else:
                self.loss = CrossEntropyLoss()

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        Runs the language model on the inputs and returns the classifier logits.
        """
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        logits = self.classifier(hidden_states=hidden_states)
        return logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns a dict with the training loss and the current learning rate.
        """
        # forward pass
        input_ids, input_type_ids, input_mask, labels = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)

        train_loss = self.loss(logits=logits, labels=labels)

        # learning rate of the first parameter group, logged for monitoring
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': train_loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Returns a dict with the validation loss and the tp/fn/fp values
        produced by the classification report for this batch.
        """
        input_ids, input_type_ids, input_mask, labels = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)

        val_loss = self.loss(logits=logits, labels=labels)

        preds = torch.argmax(logits, axis=-1)

        tp, fn, fp, _ = self.classification_report(preds, labels)

        return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        if not outputs:
            return {}
        # test_epoch_end() delegates here, so the log prefix depends on the trainer state
        if self.trainer.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and classification report
        precision, recall, f1, report = self.classification_report.compute()

        logging.info(f'{prefix}_report: {report}')

        self.log(f'{prefix}_loss', avg_loss, prog_bar=True)
        self.log(f'{prefix}_precision', precision)
        self.log(f'{prefix}_f1', f1)
        self.log(f'{prefix}_recall', recall)

        # start the next evaluation run from a clean metric state
        self.classification_report.reset()

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Create the training data loader and, if requested, the class weights.

        When the config or its ``file_path`` is missing, the training data
        loader is set to None and nothing else is done.
        """
        if not train_data_config or not train_data_config.file_path:
            logging.info(
                "Dataloader config or file_path for the train is missing, so no data loader for train is created!"
            )
            # Reset the training loader only; the test loader is not touched.
            self._train_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

        # calculate the class weights to be used in the loss function
        if self.cfg.dataset.class_balancing == 'weighted_loss':
            self.class_weights = calc_class_weights(
                train_data_config.file_path, self.cfg.dataset.num_classes)
        else:
            self.class_weights = None
        # Create the loss module with the weights calculated from the training data.
        # NOTE: create_loss_module() does not replace a loss that already exists.
        self.create_loss_module()

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Create the validation data loader.

        When the config or its ``file_path`` is missing, the validation data
        loader is set to None.
        """
        if not val_data_config or not val_data_config.file_path:
            logging.info(
                "Dataloader config or file_path for the validation is missing, so no data loader for validation is created!"
            )
            # Reset the validation loader only; the test loader is not touched.
            self._validation_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Create the test data loader.

        When the config or its ``file_path`` is missing, the test data loader
        is set to None.
        """
        if not test_data_config or not test_data_config.file_path:
            logging.info(
                "Dataloader config or file_path for the test is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(
            self, cfg: Dict) -> 'torch.utils.data.DataLoader':
        """Build a DataLoader over a TextClassificationDataset read from ``cfg.file_path``.

        Args:
            cfg: data loader config with ``file_path``, ``shuffle`` and
                ``batch_size``; ``num_samples``, ``num_workers``,
                ``pin_memory`` and ``drop_last`` are optional.
        Returns:
            A pytorch DataLoader.
        Raises:
            FileNotFoundError: if ``cfg.file_path`` does not exist.
        """
        input_file = cfg.file_path
        if not os.path.exists(input_file):
            raise FileNotFoundError(
                f'{input_file} not found! The data should be be stored in TAB-separated files \n\
                "validation_ds.file_path" and "train_ds.file_path" for train and evaluation respectively. \n\
                Each line of the files contains text sequences, where words are separated with spaces. \n\
                The label of the example is separated with TAB at the end of each line. \n\
                Each line of the files should follow the format: \n\
                [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]')

        dataset = TextClassificationDataset(
            tokenizer=self.tokenizer,
            input_file=input_file,
            max_seq_length=self.dataset_cfg.max_seq_length,
            num_samples=cfg.get("num_samples", -1),
            shuffle=cfg.shuffle,
            use_cache=self.dataset_cfg.use_cache,
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
            collate_fn=dataset.collate_fn,
        )

    @torch.no_grad()
    def classifytext(self,
                     queries: List[str],
                     batch_size: int = 1,
                     max_seq_length: int = -1) -> List[int]:
        """
        Get prediction for the queries
        Args:
            queries: text sequences
            batch_size: batch size to use during inference
            max_seq_length: sequences longer than max_seq_length will get truncated. default -1 disables truncation.
        Returns:
            all_preds: model predictions
        """
        # store predictions for all queries in a single list
        all_preds = []
        mode = self.training
        device = next(self.parameters()).device
        # Captured before the try block so the finally clause can always restore it.
        logging_level = logging.get_verbosity()
        try:
            # Switch model to evaluation mode
            self.eval()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {
                "batch_size": batch_size,
                "num_workers": 3,
                "pin_memory": False
            }
            infer_datalayer = self._setup_infer_dataloader(
                dataloader_cfg, queries, max_seq_length)

            for batch in infer_datalayer:
                # the fourth batch element is not needed for classification
                input_ids, input_type_ids, input_mask, _ = batch

                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                preds = tensor2list(torch.argmax(logits, axis=-1))
                all_preds.extend(preds)
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            logging.set_verbosity(logging_level)
        return all_preds

    def _setup_infer_dataloader(
            self,
            cfg: Dict,
            queries: List[str],
            max_seq_length: int = -1) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a infer data loader.

        Args:
            cfg: config dictionary containing data loader params like batch_size, num_workers and pin_memory
            queries: text
            max_seq_length: maximum length of queries, default is -1 for no limit
        Returns:
            A pytorch DataLoader.
        """
        dataset = TextClassificationDataset(tokenizer=self.tokenizer,
                                            queries=queries,
                                            max_seq_length=max_seq_length)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg["batch_size"],
            shuffle=False,
            num_workers=cfg.get("num_workers", 0),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=False,
            collate_fn=dataset.collate_fn,
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """Placeholder: returns None."""
        pass

    @classmethod
    def from_pretrained(cls, name: str):
        """Placeholder: does nothing and returns None."""
        pass
Пример #14
0
class PunctuationCapitalizationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types of the model, taken directly from the underlying BERT model."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types: one ('B', 'T', 'C') logits entry for each of the two heads."""
        return {
            head_output: NeuralType(('B', 'T', 'C'), LogitsType())
            for head_output in ("punct_logits", "capit_logits")
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Punctuation and Capitalization model.

        Args:
            cfg: model config; the constructor reads its ``tokenizer``, ``language_model``,
                ``punct_head`` and ``capit_head`` sections
            trainer: optional Trainer, forwarded to the parent constructor
        """
        # the tokenizer is set up before the parent constructor is called
        # NOTE(review): presumably because the parent constructor needs it for data setup - confirm
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        # shared encoder; its hidden size defines the input size of both heads below
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        # punctuation head: one class per entry of the punctuation label map stored in self._cfg
        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=cfg.punct_head.activation,
            log_softmax=False,
            dropout=cfg.punct_head.fc_dropout,
            num_layers=cfg.punct_head.punct_num_fc_layers,
            use_transformer_init=cfg.punct_head.use_transformer_init,
        )

        # capitalization head: one class per entry of the capitalization label map stored in self._cfg
        self.capit_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=cfg.capit_head.activation,
            log_softmax=False,
            dropout=cfg.capit_head.fc_dropout,
            num_layers=cfg.capit_head.capit_num_fc_layers,
            use_transformer_init=cfg.capit_head.use_transformer_init,
        )

        # the same loss module is applied to each head; agg_loss combines the two task losses
        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Encode the batch with the BERT model and feed the result to both classification heads.

        Returns:
            A tuple ``(punct_logits, capit_logits)``.
        """
        encoded = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        return (
            self.punct_classifier(hidden_states=encoded),
            self.capit_classifier(hidden_states=encoded),
        )

    def _make_step(self, batch):
        """
        Run the model on a 7-element batch and compute the aggregated loss of both tasks.

        Returns:
            A tuple ``(loss, punct_logits, capit_logits)``.
        """
        input_ids, input_type_ids, input_mask, _, loss_mask, punct_labels, capit_labels = batch
        punct_logits, capit_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # punctuation loss first, capitalization loss second
        task_losses = [
            self.loss(logits=task_logits, labels=task_labels, loss_mask=loss_mask)
            for task_logits, task_labels in ((punct_logits, punct_labels), (capit_logits, capit_labels))
        ]
        total_loss = self.agg_loss(loss_1=task_losses[0], loss_2=task_losses[1])
        return total_loss, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
        """
        Compute the training loss for `batch` and log it together with the current learning rate.

        Returns:
            A dict with the keys 'loss' and 'lr'.
        """
        train_loss = self._make_step(batch)[0]
        # learning rate of the first parameter group of the optimizer
        current_lr = self._optimizer.param_groups[0]['lr']

        self.log('lr', current_lr, prog_bar=True)
        self.log('train_loss', train_loss)

        return {'loss': train_loss, 'lr': current_lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Compute the validation loss for `batch` and update both classification reports.

        Only positions where `subtokens_mask` is greater than 0.5 are passed to the reports.

        Returns:
            A dict with 'val_loss' and the current tp/fn/fp states of both reports.
        """
        subtokens_mask, punct_labels, capit_labels = batch[3], batch[5], batch[6]
        val_loss, punct_logits, capit_logits = self._make_step(batch)

        keep = subtokens_mask > 0.5
        per_task = (
            (self.punct_class_report, punct_logits, punct_labels),
            (self.capit_class_report, capit_logits, capit_labels),
        )
        for report, task_logits, task_labels in per_task:
            task_preds = torch.argmax(task_logits, axis=-1)[keep]
            report.update(task_preds, task_labels[keep])

        step_output = {'val_loss': val_loss}
        step_output['punct_tp'] = self.punct_class_report.tp
        step_output['punct_fn'] = self.punct_class_report.fn
        step_output['punct_fp'] = self.punct_class_report.fp
        step_output['capit_tp'] = self.capit_class_report.tp
        step_output['capit_fn'] = self.capit_class_report.fn
        step_output['capit_fp'] = self.capit_class_report.fp
        return step_output

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Compute the test loss for `batch` and update both classification reports.

        Only positions where `subtokens_mask` is greater than 0.5 are passed to the reports.

        Returns:
            A dict with 'test_loss' and the current tp/fn/fp states of both reports.
        """
        subtokens_mask, punct_labels, capit_labels = batch[3], batch[5], batch[6]
        test_loss, punct_logits, capit_logits = self._make_step(batch)

        keep = subtokens_mask > 0.5
        per_task = (
            (self.punct_class_report, punct_logits, punct_labels),
            (self.capit_class_report, capit_logits, capit_labels),
        )
        for report, task_logits, task_labels in per_task:
            task_preds = torch.argmax(task_logits, axis=-1)[keep]
            report.update(task_preds, task_labels[keep])

        step_output = {'test_loss': test_loss}
        step_output['punct_tp'] = self.punct_class_report.tp
        step_output['punct_fn'] = self.punct_class_report.fn
        step_output['punct_fp'] = self.punct_class_report.fp
        step_output['capit_tp'] = self.capit_class_report.tp
        step_output['capit_fn'] = self.capit_class_report.fn
        step_output['capit_fp'] = self.capit_class_report.fp
        return step_output

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of validation to aggregate outputs.

        Averages the per-step 'val_loss' values, computes and logs the punctuation and
        capitalization classification reports, and then resets both reports so that the
        statistics accumulated here do not leak into the next evaluation run.

        Args:
            outputs: list of individual outputs of each validation step.
            dataloader_idx: index of the validation data loader (not used in the body).
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

        # start the next epoch from clean metric states
        self.punct_class_report.reset()
        self.capit_class_report.reset()

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of test to aggregate outputs.

        Averages the per-step 'test_loss' values, computes and logs the punctuation and
        capitalization classification reports, and then resets both reports so that the
        statistics accumulated here do not leak into the next evaluation run.

        Args:
            outputs: list of individual outputs of each test step.
            dataloader_idx: index of the test data loader (not used in the body).
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

        # start the next evaluation from clean metric states
        self.punct_class_report.reset()
        self.capit_class_report.reset()

    def update_data_dir(self, data_dir: str) -> None:
        """
        Point the model config at a new data directory.

        Args:
            data_dir: path to data directory
        Raises:
            ValueError: if `data_dir` does not exist
        """
        if not os.path.exists(data_dir):
            raise ValueError(f'{data_dir} not found')

        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
        self._cfg.dataset.data_dir = data_dir

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        """
        Setup training data

        Creates the training data loader and, when torch.distributed is not initialized or this
        process has rank 0, registers the label id files as artifacts and stores the label maps
        of the training dataset in the model config.

        Args:
            train_data_config: training data config; defaults to ``self._cfg.train_ds``
        """
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        # for older(pre - 1.0.0.b3) configs compatibility
        if not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None:
            # struct mode is switched off so the new key can be added; it is not switched back on here
            OmegaConf.set_struct(self._cfg, False)
            self._cfg.class_labels = {}
            # the empty value assigned above is replaced right away by the default file names
            self._cfg.class_labels = OmegaConf.create(
                {'punct_labels_file': 'punct_label_ids.csv', 'capit_labels_file': 'capit_label_ids.csv'}
            )

        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

        # everything below runs only without torch.distributed or on the rank 0 process
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            self.register_artifact(
                self._cfg.class_labels.punct_labels_file, self._train_dl.dataset.punct_label_ids_file
            )
            self.register_artifact(
                self._cfg.class_labels.capit_labels_file, self._train_dl.dataset.capit_label_ids_file
            )

            # save label maps to the config
            self._cfg.punct_label_ids = OmegaConf.create(self._train_dl.dataset.punct_label_ids)
            self._cfg.capit_label_ids = OmegaConf.create(self._train_dl.dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Create the validation data loader and store it in ``self._validation_dl``.

        Args:
            val_data_config: validation data config; defaults to ``self._cfg.validation_ds``
        """
        effective_cfg = self._cfg.validation_ds if val_data_config is None else val_data_config
        self._validation_dl = self._setup_dataloader_from_config(cfg=effective_cfg)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        """
        Create the test data loader and store it in ``self._test_dl``.

        Args:
            test_data_config: test data config; defaults to ``self._cfg.test_ds``
        """
        effective_cfg = self._cfg.test_ds if test_data_config is None else test_data_config
        self._test_dl = self._setup_dataloader_from_config(cfg=effective_cfg)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a BertPunctuationCapitalizationDataset from `cfg` and wrap it in a DataLoader.

        Args:
            cfg: data config providing ``text_file``, ``labels_file``, ``num_samples``,
                ``batch_size`` and ``shuffle``, and optionally ``ds_item``
        Returns:
            A pytorch DataLoader over the created dataset.
        """
        # use data_dir specified in the ds_item to run evaluation on multiple datasets
        has_ds_item = 'ds_item' in cfg and cfg.ds_item is not None
        data_dir = cfg.ds_item if has_ds_item else self._cfg.dataset.data_dir

        dataset_kwargs = dict(
            tokenizer=self.tokenizer,
            text_file=os.path.join(data_dir, cfg.text_file),
            label_file=os.path.join(data_dir, cfg.labels_file),
            pad_label=self._cfg.dataset.pad_label,
            punct_label_ids=self._cfg.punct_label_ids,
            capit_label_ids=self._cfg.capit_label_ids,
            max_seq_length=self._cfg.dataset.max_seq_length,
            ignore_extra_tokens=self._cfg.dataset.ignore_extra_tokens,
            ignore_start_end=self._cfg.dataset.ignore_start_end,
            use_cache=self._cfg.dataset.use_cache,
            num_samples=cfg.num_samples,
            # fall back to the default label file names when the config has no class_labels section
            punct_label_ids_file=(
                self._cfg.class_labels.punct_labels_file if 'class_labels' in self._cfg else 'punct_label_ids.csv'
            ),
            capit_label_ids_file=(
                self._cfg.class_labels.capit_labels_file if 'class_labels' in self._cfg else 'capit_label_ids.csv'
            ),
        )
        dataset = BertPunctuationCapitalizationDataset(**dataset_kwargs)

        loader_kwargs = dict(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=self._cfg.dataset.drop_last,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def _setup_infer_dataloader(self, queries: List[str], batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Build the data loader used for inference over raw queries.

        The loader keeps the query order and never drops the last batch; worker count and
        memory pinning come from the dataset section of the model config.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader.
        """
        dataset_cfg = self._cfg.dataset
        infer_dataset = BertPunctuationCapitalizationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=dataset_cfg.max_seq_length
        )
        loader_kwargs = dict(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=False,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def add_punctuation_capitalization(self, queries: List[str], batch_size: int = None) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for debugging and prototyping.

        The model is switched to evaluation mode and moved to CUDA when available (CPU otherwise).
        The original train/eval mode is restored before returning; the device is not changed back.
        Inference runs under ``torch.no_grad()`` so no autograd graph is kept for the forward passes.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference; defaults to the number of queries
        Returns:
            result: text with added capitalization and punctuation, one string per query
        Raises:
            ValueError: if the number of predictions differs from the number of whitespace-separated words
        """
        if queries is None or len(queries) == 0:
            return []
        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')

        # We will store the output here
        result = []

        # Model's mode and device
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            # store predictions for all queries in a single list
            all_punct_preds = []
            all_capit_preds = []

            # inference only: gradients are never needed here
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, subtokens_mask = batch

                    punct_logits, capit_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # keep only the positions selected by the subtokens mask
                    subtokens_mask = subtokens_mask > 0.5
                    all_punct_preds.extend(tensor2list(torch.argmax(punct_logits, axis=-1)[subtokens_mask]))
                    all_capit_preds.extend(tensor2list(torch.argmax(capit_logits, axis=-1)[subtokens_mask]))

            queries = [q.strip().split() for q in queries]
            total_words = sum(len(q) for q in queries)

            if total_words != len(all_punct_preds) or total_words != len(all_capit_preds):
                raise ValueError('Pred and words must have the same length')

            punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
            capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}
            pad_label = self._cfg.dataset.pad_label

            start_idx = 0
            for query in queries:
                end_idx = start_idx + len(query)

                # extract predictions for the current query from the list of all predictions
                punct_preds = all_punct_preds[start_idx:end_idx]
                capit_preds = all_capit_preds[start_idx:end_idx]
                start_idx = end_idx

                processed_words = []
                for word, punct_id, capit_id in zip(query, punct_preds, capit_preds):
                    # any label other than the pad label means "capitalize" / "append this punctuation"
                    if capit_ids_to_labels[capit_id] != pad_label:
                        word = word.capitalize()
                    punct_label = punct_ids_to_labels[punct_id]
                    if punct_label != pad_label:
                        word += punct_label
                    processed_words.append(word)

                result.append(' '.join(processed_words).strip())
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        Describe the pre-trained checkpoints of this model that are hosted on NVIDIA's NGC cloud.

        Returns:
            List of PretrainedModelInfo entries, one per available pre-trained model.
        """
        bert_info = PretrainedModelInfo(
            pretrained_model_name="Punctuation_Capitalization_with_BERT",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_BERT.nemo",
            description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
        )
        # NOTE(review): "DiltilBERT" in the description below looks like a typo for "DistilBERT"
        distilbert_info = PretrainedModelInfo(
            pretrained_model_name="Punctuation_Capitalization_with_DistilBERT",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Punctuation_Capitalization_with_DistilBERT.nemo",
            description="The model was trained with DiltilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
        )
        return [bert_info, distilbert_info]

    def _prepare_for_export(self):
        """Delegate export preparation to the underlying BERT model and return its result."""
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Export the model to ONNX as five files placed next to `output`:
        bert_<output> - common BERT neural net
        punct_classifier_<output> - Punctuation Classifier neural net
        capit_classifier_<output> - Capitalization Classifier neural net
        punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
        capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)

        `input_example` and `output_example` are ignored (a warning is logged when either is
        passed); all other options are forwarded unchanged to the export() of each sub-module.

        Returns:
            A tuple of two lists: the five output paths and their matching descriptions.
        """
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " PunctuationCapitalizationModel consists of three separate models with different"
                " inputs and outputs."
            )

        def prefixed(prefix):
            # same directory as `output`, file name extended with `prefix`
            return os.path.join(os.path.dirname(output), prefix + os.path.basename(output))

        # positional arguments shared by the three sub-module exports;
        # the two leading None values are the input/output examples
        shared_args = (
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )
        qual_name = self.__module__ + '.' + self.__class__.__qualname__

        bert_path = prefixed('bert_')
        bert_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(bert_path, *shared_args)

        punct_head_path = prefixed('punct_classifier_')
        punct_head_descr = qual_name + ' Punctuation Classifier exported to ONNX'
        punct_classifier_onnx = self.punct_classifier.export(punct_head_path, *shared_args)

        capit_head_path = prefixed('capit_classifier_')
        capit_head_descr = qual_name + ' Capitalization Classifier exported to ONNX'
        capit_classifier_onnx = self.capit_classifier.export(capit_head_path, *shared_args)

        # fused models: the BERT graph attached to each classifier graph
        punct_fused_path = prefixed('punct_')
        punct_fused_descr = qual_name + ' Punctuation BERT+Classifier exported to ONNX'
        onnx.save(attach_onnx_to_onnx(bert_model_onnx, punct_classifier_onnx, "PTCL"), punct_fused_path)

        capit_fused_path = prefixed('capit_')
        capit_fused_descr = qual_name + ' Capitalization BERT+Classifier exported to ONNX'
        onnx.save(attach_onnx_to_onnx(bert_model_onnx, capit_classifier_onnx, "CPCL"), capit_fused_path)

        paths = [bert_path, punct_head_path, capit_head_path, punct_fused_path, capit_fused_path]
        descriptions = [bert_descr, punct_head_descr, capit_head_descr, punct_fused_descr, capit_fused_descr]
        return (paths, descriptions)
Пример #15
0
class DuplexTaggerModel(NLPModel):
    """
    Transformer-based (duplex) tagger model for TN/ITN.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes the duplex tagger: tokenizer, token classification model, loss and metrics.

        Args:
            cfg: model config; ``tokenizer`` and ``transformer`` are read here, ``lang`` is optional
            trainer: optional Trainer, forwarded to the parent constructor
        """
        # the tokenizer is created before the parent constructor is called
        self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer,
                                                        add_prefix_space=True)
        super().__init__(cfg=cfg, trainer=trainer)
        # one output class per tag in constants.ALL_TAG_LABELS
        self.num_labels = len(constants.ALL_TAG_LABELS)
        self.model = AutoModelForTokenClassification.from_pretrained(
            cfg.transformer, num_labels=self.num_labels)
        self.transformer_name = cfg.transformer

        # Loss Functions
        # positions labelled with LABEL_PAD_TOKEN_ID do not contribute to the loss
        self.loss_fct = nn.CrossEntropyLoss(
            ignore_index=constants.LABEL_PAD_TOKEN_ID)

        # setup to track metrics
        label_ids = {l: idx for idx, l in enumerate(constants.ALL_TAG_LABELS)}
        self.classification_report = ClassificationReport(
            self.num_labels, label_ids, mode='micro', dist_sync_on_step=True)

        # Language
        self.lang = cfg.get('lang', None)

    # Training
    def training_step(self, batch, batch_idx):
        """
        Compute the tagging loss for `batch` and log it together with the current learning rate.

        Returns:
            A dict with the keys 'loss' and 'lr'.
        """
        # Apply Transformer
        model_output = self.model(batch['input_ids'], batch['attention_mask'])

        # Loss over the flattened logits and labels
        flat_logits = model_output.logits.view(-1, self.num_labels)
        flat_labels = batch['labels'].view(-1)
        train_loss = self.loss_fct(flat_logits, flat_labels)

        current_lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', train_loss)
        self.log('lr', current_lr, prog_bar=True)
        return {'loss': train_loss, 'lr': current_lr}

    # Validation and Testing
    def validation_step(self, batch, batch_idx):
        """
        Predict tags for `batch` and feed the classification report, one sequence at a time.

        Positions whose label equals constants.LABEL_PAD_TOKEN_ID are left out of the report.
        """
        # Apply Transformer
        tag_logits = self.model(batch['input_ids'],
                                batch['attention_mask']).logits
        tag_preds = torch.argmax(tag_logits, dim=2)

        # Update classification_report
        pad_id = constants.LABEL_PAD_TOKEN_ID
        pred_rows, label_rows = tag_preds.tolist(), batch['labels'].tolist()
        for pred_row, label_row in zip(pred_rows, label_rows):
            kept_pairs = [(p, l) for p, l in zip(pred_row, label_row) if l != pad_id]
            kept_preds = [p for p, _ in kept_pairs]
            kept_labels = [l for _, l in kept_pairs]
            self.classification_report(
                torch.tensor(kept_preds).to(self.device),
                torch.tensor(kept_labels).to(self.device))

    def validation_epoch_end(self, outputs):
        """
        Compute the classification report at the end of validation, log it and reset it.

        :param outputs: list of individual outputs of each validation step (not used in the body).
        """
        # only the precision and the textual report of the computed metrics are used
        token_precision, _, _, report_text = self.classification_report.compute()
        logging.info(report_text)
        self.log('val_token_precision', token_precision)

        # clear the accumulated statistics before the next evaluation
        self.classification_report.reset()

    def test_step(self, batch, batch_idx):
        """
        Handle a test batch exactly like a validation batch by delegating to `validation_step`.
        """
        step_result = self.validation_step(batch, batch_idx)
        return step_result

    def test_epoch_end(self, outputs):
        """
        Finish a test run exactly like a validation run by delegating to `validation_epoch_end`.

        :param outputs: list of individual outputs of each test step.
        """
        epoch_result = self.validation_epoch_end(outputs)
        return epoch_result

    # Functions for inference
    @torch.no_grad()
    def _infer(self, sents: List[List[str]], inst_directions: List[str]):
        """ Main function for Inference

        The model is switched to evaluation mode and is left in that mode.

        Args:
            sents: A list of inputs tokenized by a basic tokenizer.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

        Returns:
            all_tag_preds: A list of list where each list contains the raw tag predictions for the corresponding input.
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.

        Raises:
            ValueError: if a direction is neither INST_BACKWARD nor INST_FORWARD.
        """
        self.eval()

        # Nothing to tokenize or decode for an empty batch
        if not sents:
            return [], [], [], []

        # Append prefix
        texts = []
        for ix, sent in enumerate(sents):
            direction = inst_directions[ix]
            if direction == constants.INST_BACKWARD:
                prefix = constants.ITN_PREFIX
            elif direction == constants.INST_FORWARD:
                prefix = constants.TN_PREFIX
            else:
                # without this branch the prefix of the previous sentence would be silently reused
                # (or the name would be unbound for the first sentence)
                raise ValueError(f'Unknown instance direction {direction!r} for input {ix}')
            texts.append([prefix] + sent)

        # Apply the model
        encodings = self._tokenizer(texts,
                                    is_split_into_words=True,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt')
        logits = self.model(**encodings.to(self.device)).logits
        pred_indexes = torch.argmax(logits, dim=-1).tolist()

        # Extract all_tag_preds for words
        all_tag_preds = []
        batch_size = encodings['input_ids'].size(0)
        for ix in range(batch_size):
            raw_tag_preds = [
                constants.ALL_TAG_LABELS[p] for p in pred_indexes[ix][2:]
            ]  # remove first special token and task prefix token
            tag_preds, previous_word_idx = [], None
            word_ids = encodings.word_ids(batch_index=ix)[2:]
            for jx, word_idx in enumerate(word_ids):
                if word_idx is None:
                    continue
                # keep one prediction per word: the one at the position where the word id changes
                if word_idx != previous_word_idx:
                    tag_preds.append(raw_tag_preds[jx])
                previous_word_idx = word_idx
            all_tag_preds.append(tag_preds)

        # Postprocessing
        all_tag_preds = [
            self.postprocess_tag_preds(words, inst_dir, ps)
            for words, inst_dir, ps in zip(sents, inst_directions, all_tag_preds)
        ]

        # Decoding
        nb_spans, span_starts, span_ends = self.decode_tag_preds(all_tag_preds)

        return all_tag_preds, nb_spans, span_starts, span_ends

    def postprocess_tag_preds(self, words, inst_dir, preds):
        """ Clean up the raw tag predictions of one input by fixing obvious mistakes,
        e.g. a TRANSFORM span that begins with I_TRANSFORM_TAG instead of B_TRANSFORM_TAG.

        Args:
            words: The words in the input sentence
            inst_dir: The direction of the instance (i.e., INST_BACKWARD or INST_FORWARD).
            preds: The raw tag predictions

        Returns: The processed raw tag predictions
        """
        transform_tag = constants.TRANSFORM_TAG
        begin_transform = constants.B_PREFIX + transform_tag
        inside_transform = constants.I_PREFIX + transform_tag
        begin_task = constants.B_PREFIX + constants.TASK_TAG
        begin_same = constants.B_PREFIX + constants.SAME_TAG
        is_forward = inst_dir == constants.INST_FORWARD

        fixed = []
        for ix, tag in enumerate(preds):
            if tag == inside_transform and (ix == 0 or transform_tag not in fixed[ix - 1]):
                # an I- tag with no TRANSFORM tag right before it actually opens a span
                fixed.append(begin_transform)
            elif is_forward and has_numbers(words[ix]) and transform_tag not in tag:
                # TN only: a word with numbers that was not tagged as TRANSFORM
                fixed.append(begin_transform)
            elif tag == begin_task:
                # Convert B-TASK tag to B-SAME tag
                fixed.append(begin_same)
            else:
                fixed.append(tag)
        return fixed

    def decode_tag_preds(self, tag_preds):
        """ Locate the semiotic spans in the input texts from the tag predictions.

        A span opens at a B-TRANSFORM tag and extends over the I-TRANSFORM tags that follow it.

        Args:
            tag_preds: A list of list where each list contains the raw tag predictions for the corresponding input.

        Returns:
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
        """
        begin_tag = constants.B_PREFIX + constants.TRANSFORM_TAG
        inside_tag = constants.I_PREFIX + constants.TRANSFORM_TAG

        nb_spans, span_starts, span_ends = [], [], []
        for preds in tag_preds:
            starts, ends = [], []
            open_start = None
            # the trailing sentinel closes a span that reaches the end of the input
            for pos, tag in enumerate(preds + ['EOS']):
                if tag != inside_tag:
                    if open_start is not None:
                        starts.append(open_start)
                        ends.append(pos - 1)
                    open_start = None
                if tag == begin_tag:
                    open_start = pos
            nb_spans.append(len(starts))
            span_starts.append(starts)
            span_ends.append(ends)

        return nb_spans, span_starts, span_ends

    # Functions for processing data
    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """
        Create the training data loader, or set it to None when the config or its data_path is empty.
        """
        if train_data_config and train_data_config.data_path:
            self._train_dl = self._setup_dataloader_from_config(
                cfg=train_data_config, mode="train")
            return

        logging.info(
            f"Dataloader config or file_path for the train is missing, so no data loader for train is created!"
        )
        self._train_dl = None

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """
        Create the validation data loader, or set it to None when the config or its data_path is empty.
        """
        if val_data_config and val_data_config.data_path:
            self._validation_dl = self._setup_dataloader_from_config(
                cfg=val_data_config, mode="val")
            return

        logging.info(
            f"Dataloader config or file_path for the validation is missing, so no data loader for validation is created!"
        )
        self._validation_dl = None

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """
        Create the test data loader, or set it to None when the config or its data_path is empty.

        An empty data_path (not only None) is treated as missing, which matches the checks
        done for the training and validation data.
        """
        if not test_data_config or not test_data_config.data_path:
            logging.info(
                "Dataloader config or file_path for the test is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config, mode="test")

    def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str):
        """
        Build a TextNormalizationTaggerDataset from `cfg`, wrap it in a DataLoader and log
        how long that took.

        Args:
            cfg: data config; ``data_path``, ``mode``, ``do_basic_tokenize``, ``lang``,
                ``batch_size`` and ``shuffle`` are required, ``tagger_data_augmentation``,
                ``use_cache`` and ``max_insts`` default to False, False and -1
            mode: name of the split, used only in the log message
        Returns:
            A pytorch DataLoader over the created dataset.
        """
        started_at = perf_counter()
        logging.info(f'Creating {mode} dataset')

        tagger_dataset = TextNormalizationTaggerDataset(
            cfg.data_path,
            self._tokenizer,
            self.transformer_name,
            cfg.mode,
            cfg.do_basic_tokenize,
            cfg.get('tagger_data_augmentation', False),
            cfg.lang,
            cfg.get('use_cache', False),
            cfg.get('max_insts', -1),
        )
        collator = DataCollatorForTokenClassification(self._tokenizer)
        loader = torch.utils.data.DataLoader(
            dataset=tagger_dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            collate_fn=collator,
        )

        logging.info(f'Took {perf_counter() - started_at} seconds')
        return loader

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        List the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            An empty list: no pre-trained models are listed for this class.
        """
        return []
Пример #16
0
class IntentSlotClassificationModel(NLPModel):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types of the model, delegated to the underlying BERT module."""
        encoder = self.bert_model
        return encoder.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types of the model, delegated to the classifier head."""
        head = self.classifier
        return head.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the BERT joint intent and slot classification model.

        Args:
            cfg: model configuration; this constructor reads ``language_model``,
                ``tokenizer``, ``data_dir``, ``train_ds`` and ``validation_ds`` from it.
            trainer: optional trainer instance, passed through to the superclass.
        """
        self.max_seq_length = cfg.language_model.max_seq_length

        # The tokenizer is prepared before the superclass constructor runs.
        self.setup_tokenizer(cfg.tokenizer)

        if cfg.data_dir and os.path.exists(cfg.data_dir):
            # Data is available: remember the directory and copy its description into cfg.
            self.data_dir = cfg.data_dir
            self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds, cfg.validation_ds)
        else:
            # No usable data_dir: disable the setup methods and fall back to
            # placeholder data_desc values.
            IntentSlotClassificationModel._set_model_restore_state(is_being_restored=True)
            self._set_defaults_data_desc(cfg)

        super().__init__(cfg=cfg, trainer=trainer)

        # Re-enable the setup methods.
        IntentSlotClassificationModel._set_model_restore_state(is_being_restored=False)

        # Gather the language-model arguments one by one; the two artifacts are
        # registered in the order config file, then vocabulary file.
        lm_cfg = self.cfg.language_model
        model_name = lm_cfg.pretrained_model_name
        lm_config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file)
        lm_config_dict = OmegaConf.to_container(lm_cfg.config) if lm_cfg.config else None
        lm_checkpoint = lm_cfg.lm_checkpoint
        vocab_file = self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file)

        self.bert_model = get_lm_model(
            pretrained_model_name=model_name,
            config_file=lm_config_file,
            config_dict=lm_config_dict,
            checkpoint_file=lm_checkpoint,
            vocab_file=vocab_file,
        )

        # Build the classifier head, losses and metric reports from cfg.data_desc.
        self._reconfigure_classifier()

    def _set_defaults_data_desc(self, cfg):
        """
        Guarantees that ``cfg.data_desc`` exists.
        When the attribute is absent it is created and filled with "dummy" placeholder
        values; an existing ``data_desc`` is left untouched.
        """
        if hasattr(cfg, "data_desc"):
            return

        # Struct mode is switched off while the new keys are added and switched back on afterwards.
        OmegaConf.set_struct(cfg, False)
        cfg.data_desc = {}

        # Placeholder intent description: a single label " " with id 0 and weight 1.
        cfg.data_desc.intent_labels = " "
        cfg.data_desc.intent_label_ids = {" ": 0}
        cfg.data_desc.intent_weights = [1]

        # Placeholder slot description, same shape as the intents above.
        cfg.data_desc.slot_labels = " "
        cfg.data_desc.slot_label_ids = {" ": 0}
        cfg.data_desc.slot_weights = [1]

        cfg.data_desc.pad_label = "O"
        OmegaConf.set_struct(cfg, True)

    def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds):
        """Creates an IntentSlotDataDesc for ``data_dir`` and copies its values into ``cfg.data_desc``.

        The label-id maps are additionally written to the intent and slot label files
        inside ``data_dir`` and both files are registered as artifacts.

        Args:
            cfg: configuration to update; struct mode is disabled during the update
                and enabled again at the end.
            data_dir: directory holding the data; the label files are written into it.
            train_ds: training dataset section; only ``prefix`` is read.
            validation_ds: validation dataset section; only ``prefix`` is read.
        """
        # Storing the description in the config lets it be reused later, e.g. in inference.
        desc = IntentSlotDataDesc(data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix])
        intent_ids = desc.intents_label_ids
        slot_ids = desc.slots_label_ids

        OmegaConf.set_struct(cfg, False)
        if getattr(cfg, "data_desc", None) is None:
            cfg.data_desc = {}

        # Intent part of the description.
        cfg.data_desc.intent_labels = list(intent_ids.keys())
        cfg.data_desc.intent_label_ids = intent_ids
        cfg.data_desc.intent_weights = desc.intent_weights

        # Slot part of the description.
        cfg.data_desc.slot_labels = list(slot_ids.keys())
        cfg.data_desc.slot_label_ids = slot_ids
        cfg.data_desc.slot_weights = desc.slot_weights

        cfg.data_desc.pad_label = desc.pad_label

        # Older (pre-1.0.0.b3) configs carry no class_labels section; give them the default file names.
        if getattr(cfg, "class_labels", None) is None:
            cfg.class_labels = OmegaConf.create(
                {'intent_labels_file': 'intent_labels.csv', 'slot_labels_file': 'slot_labels.csv'}
            )

        slot_labels_path = os.path.join(data_dir, cfg.class_labels.slot_labels_file)
        intent_labels_path = os.path.join(data_dir, cfg.class_labels.intent_labels_file)
        self._save_label_ids(slot_ids, slot_labels_path)
        self._save_label_ids(intent_ids, intent_labels_path)

        self.register_artifact(cfg.class_labels.intent_labels_file, intent_labels_path)
        self.register_artifact(cfg.class_labels.slot_labels_file, slot_labels_path)
        OmegaConf.set_struct(cfg, True)

    def _save_label_ids(self, label_ids: Dict[str, int], filename: str) -> None:
        """Saves the labels of ``label_ids`` to ``filename``, one per line, ordered by id.

        Args:
            label_ids: mapping from label to its integer id.
            filename: path of the file to create or overwrite.

        An empty mapping yields an empty file; it no longer fails while unpacking
        the (empty) sorted sequence.
        """
        # Order the labels before opening the file, so the target is only touched
        # once there is something valid to write.
        ordered_labels = [label for label, _ in sorted(label_ids.items(), key=lambda item: item[1])]
        with open(filename, 'w') as out:
            out.write('\n'.join(ordered_labels))
            logging.info(f'Labels: {label_ids}')
            logging.info(f'Labels mapping saved to : {out.name}')

    def _reconfigure_classifier(self):
        """Rebuilds the classifier head, the losses and the metric reports from ``self.cfg.data_desc``."""
        model_cfg = self.cfg
        desc = model_cfg.data_desc
        intent_count = len(desc.intent_labels)
        slot_count = len(desc.slot_labels)

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=intent_count,
            num_slots=slot_count,
            dropout=model_cfg.head.fc_dropout,
            num_layers=model_cfg.head.num_output_layers,
            log_softmax=False,
        )

        # Losses: class weights are only passed when 'weighted_loss' balancing is requested.
        # You may need to increase the number of epochs for convergence when using weighted_loss.
        use_weights = model_cfg.class_balancing == 'weighted_loss'
        intent_loss_kwargs = {'weight': desc.intent_weights} if use_weights else {}
        slot_loss_kwargs = {'weight': desc.slot_weights} if use_weights else {}
        self.intent_loss = CrossEntropyLoss(logits_ndim=2, **intent_loss_kwargs)
        self.slot_loss = CrossEntropyLoss(logits_ndim=3, **slot_loss_kwargs)

        # The slot loss receives the complement of the intent loss weight.
        intent_share = model_cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(num_inputs=2, weights=[intent_share, 1.0 - model_cfg.intent_loss_weight])

        # Metric trackers, one per task.
        self.intent_classification_report = ClassificationReport(
            num_classes=intent_count, label_ids=desc.intent_label_ids, dist_sync_on_step=True, mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=slot_count, label_ids=desc.slot_label_ids, dist_sync_on_step=True, mode='micro',
        )

    def update_data_dir_for_training(self, data_dir: str, train_ds, validation_ds) -> None:
        """
        Points the model at a new data directory for training.
        The data description in ``self.cfg`` is regenerated from the new directory and the
        classifier is rebuilt, so data with e.g. a different number of slots can be handled.

        Args:
            data_dir: path to data directory
            train_ds: training dataset section, forwarded to ``_set_data_desc_to_cfg``
            validation_ds: validation dataset section, forwarded to ``_set_data_desc_to_cfg``
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir

        # First refresh cfg.data_desc, then rebuild everything that depends on it
        # (number of intents, slots etc.).
        model_cfg = self.cfg
        self._set_data_desc_to_cfg(model_cfg, data_dir, train_ds, validation_ds)
        self._reconfigure_classifier()

    def update_data_dir_for_testing(self, data_dir) -> None:
        """
        Points the model at a new data directory for testing.
        Only ``self.data_dir`` changes; the configuration and the classifier stay as they are.

        Args:
            data_dir: path to data directory
        """
        message = f'Setting data_dir to {data_dir}.'
        logging.info(message)
        self.data_dir = data_dir

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Encodes the inputs with the BERT model and feeds the hidden states to the classifier.
        Returns the pair ``(intent_logits, slot_logits)``.
        """
        encoded = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        head_output = self.classifier(hidden_states=encoded)
        intent_logits, slot_logits = head_output
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Runs one training batch (called by Lightning inside the training loop).
        Returns a dict with the combined ``loss`` and the current learning rate ``lr``.
        """
        input_ids, input_type_ids, input_mask, loss_mask, _subtokens_mask, intent_labels, slot_labels = batch

        # forward pass
        intent_logits, slot_logits = self(
            input_ids=input_ids,
            token_type_ids=input_type_ids,
            attention_mask=input_mask,
        )

        # intent loss and (masked) slot loss are merged into one training loss
        train_loss = self.total_loss(
            loss_1=self.intent_loss(logits=intent_logits, labels=intent_labels),
            loss_2=self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask),
        )

        # learning rate of the first parameter group
        first_group = self._optimizer.param_groups[0]
        lr = first_group['lr']

        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)

        return {'loss': train_loss, 'lr': lr}

    def validation_step(self, batch, batch_idx):
        """
        Runs one validation batch (called by Lightning inside the validation loop).
        Computes the combined loss, updates both classification reports and returns a
        dict with ``val_loss`` plus the tp/fn/fp counters of each report.
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids,
            token_type_ids=input_type_ids,
            attention_mask=input_mask,
        )

        # combined loss for intents and slots
        val_loss = self.total_loss(
            loss_1=self.intent_loss(logits=intent_logits, labels=intent_labels),
            loss_2=self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask),
        )

        intent_report = self.intent_classification_report
        slot_report = self.slot_classification_report

        # intents: every prediction is scored
        intent_report.update(torch.argmax(intent_logits, axis=-1), intent_labels)

        # slots: only positions whose subtokens_mask value exceeds 0.5 are scored
        selected = subtokens_mask > 0.5
        slot_report.update(torch.argmax(slot_logits, axis=-1)[selected], slot_labels[selected])

        return {
            'val_loss': val_loss,
            'intent_tp': intent_report.tp,
            'intent_fn': intent_report.fn,
            'intent_fp': intent_report.fp,
            'slot_tp': slot_report.tp,
            'slot_fn': slot_report.fn,
            'slot_fp': slot_report.fp,
        }

    def validation_epoch_end(self, outputs):
        """
        Aggregates the validation step outputs at the end of validation.
        :param outputs: list of individual outputs of each validation step.
        :return: dict with the mean ``val_loss`` and precision/recall/f1 for intents and slots.
        """
        step_losses = [step_output['val_loss'] for step_output in outputs]
        mean_loss = torch.stack(step_losses).mean()

        # the textual reports are logged separately for intents and slots
        intent_precision, intent_recall, intent_f1, intent_report = self.intent_classification_report.compute()
        logging.info(f'Intent report: {intent_report}')

        slot_precision, slot_recall, slot_f1, slot_report = self.slot_classification_report.compute()
        logging.info(f'Slot report: {slot_report}')

        # the same name/value pairs are both logged and returned
        metrics = {
            'val_loss': mean_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
        }
        for metric_name, metric_value in metrics.items():
            self.log(metric_name, metric_value)

        return metrics

    def test_step(self, batch, batch_idx):
        """
        Runs one test batch (called by Lightning inside the test loop).
        A test batch is handled exactly like a validation batch.
        """
        step_output = self.validation_step(batch, batch_idx)
        return step_output

    def test_epoch_end(self, outputs):
        """
        Aggregates the test step outputs at the end of testing, reusing the validation aggregation.
        :param outputs: list of individual outputs of each test step.
        """
        epoch_metrics = self.validation_epoch_end(outputs)
        return epoch_metrics

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Creates the training data loader from ``train_data_config`` and stores it in ``self._train_dl``."""
        train_loader = self._setup_dataloader_from_config(cfg=train_data_config)
        self._train_dl = train_loader

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Creates the validation data loader from ``val_data_config`` and stores it in ``self._validation_dl``."""
        validation_loader = self._setup_dataloader_from_config(cfg=val_data_config)
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Creates the test data loader from ``test_data_config`` and stores it in ``self._test_dl``."""
        test_loader = self._setup_dataloader_from_config(cfg=test_data_config)
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """Creates a DataLoader over the intent and slot files of one dataset split.

        Args:
            cfg: dataset section; ``prefix``, ``num_samples``, ``batch_size``, ``shuffle``,
                ``num_workers``, ``pin_memory`` and ``drop_last`` are read from it.

        Returns:
            A ``DataLoader`` over an ``IntentSlotClassificationDataset``.

        Raises:
            FileNotFoundError: if ``<data_dir>/<prefix>.tsv`` or
                ``<data_dir>/<prefix>_slots.tsv`` does not exist.
        """
        input_file = f'{self.data_dir}/{cfg.prefix}.tsv'
        slot_file = f'{self.data_dir}/{cfg.prefix}_slots.tsv'

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            # Adjacent literals keep the message on a single line; a backslash
            # continuation inside the literal would embed the source indentation
            # in the error text.
            raise FileNotFoundError(
                f'{input_file} or {slot_file} not found. Please refer to the documentation '
                'for the right format of Intents and Slots files.'
            )

        dataset = IntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.cfg.data_desc.pad_label,
            ignore_extra_tokens=self.cfg.ignore_extra_tokens,
            ignore_start_end=self.cfg.ignore_start_end,
        )

        # Batches are assembled by the dataset's own collate function.
        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    def _setup_infer_dataloader(self, queries: List[str], test_ds) -> 'torch.utils.data.DataLoader':
        """
        Builds the data loader used for inference on raw text queries.
        Args:
            queries: text
            test_ds: dataset configuration section; ``batch_size``, ``shuffle``, ``num_workers``,
                ``pin_memory`` and ``drop_last`` are read from it
        Returns:
            A pytorch DataLoader.
        """
        # NOTE(review): shuffle and drop_last are taken from test_ds as-is; if either is
        # enabled the batches presumably no longer line up one-to-one with `queries` -
        # confirm that callers pass an inference-safe configuration.
        infer_dataset = IntentSlotInferenceDataset(
            tokenizer=self.tokenizer,
            queries=queries,
            max_seq_length=-1,
            do_lower_case=False,
        )
        loader_kwargs = dict(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=test_ds.batch_size,
            shuffle=test_ds.shuffle,
            num_workers=test_ds.num_workers,
            pin_memory=test_ds.pin_memory,
            drop_last=test_ds.drop_last,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def predict_from_examples(self, queries: List[str], test_ds) -> List[List[str]]:
        """
        Predicts the intent and the slots of every query.
        The model is switched to evaluation mode and moved to CUDA when it is available;
        the previous train/eval mode is restored afterwards, the device is not.
        Args:
            queries: text sequences
            test_ds: Dataset configuration section.
        Returns:
            predicted_intents, predicted_slots: model intent and slot predictions;
            the slots of one query are joined into a single space-separated string
        """
        predicted_intents, predicted_slots = [], []
        was_training = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Intent and slot vocabularies come from the configuration.
            intent_labels = self.cfg.data_desc.intent_labels
            slot_labels = self.cfg.data_desc.slot_labels

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            infer_loader = self._setup_infer_dataloader(queries, test_ds)

            for input_ids, input_type_ids, input_mask, _loss_mask, subtokens_mask in infer_loader:
                intent_logits, slot_logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                # intents: map every predicted id to its label
                for intent_id in tensor2list(torch.argmax(intent_logits, axis=-1)):
                    if intent_id < len(intent_labels):
                        intent_name = intent_labels[int(intent_id)]
                    else:
                        # should not happen
                        intent_name = "Unknown Intent"
                    predicted_intents.append(intent_name)

                # slots: keep only the positions whose subtoken mask equals 1
                slot_ids = torch.argmax(slot_logits, axis=-1)
                for query_slot_ids, query_mask in zip(slot_ids, subtokens_mask):
                    slot_names = [
                        slot_labels[int(slot_id)] if slot_id < len(slot_labels) else 'Unknown_slot'
                        for slot_id, mask_value in zip(query_slot_ids, query_mask)
                        if mask_value == 1
                    ]
                    predicted_slots.append(' '.join(slot_names).strip())

        finally:
            # set mode back to its original value
            self.train(mode=was_training)

        return predicted_intents, predicted_slots

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        Lists the pre-trained models of this class that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models (a single entry).
        """
        return [
            PretrainedModelInfo(
                pretrained_model_name="Joint_Intent_Slot_Assistant",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
                description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
            )
        ]
class IntentSlotClassificationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types of the model, delegated to the underlying BERT module."""
        encoder = self.bert_model
        return encoder.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types of the model, delegated to the classifier head."""
        head = self.classifier
        return head.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the BERT joint intent and slot classification model.

        Args:
            cfg: model configuration; ``data_dir``, ``language_model``, ``train_ds``,
                ``validation_ds``, ``tokenizer``, ``head``, ``class_balancing`` and
                ``intent_loss_weight`` are read here.
            trainer: optional trainer instance, passed through to the superclass.
        """
        self.data_dir = cfg.data_dir
        self.max_seq_length = cfg.language_model.max_seq_length

        # Data description, built from the train and validation splits of data_dir.
        self.data_desc = IntentSlotDataDesc(
            data_dir=cfg.data_dir, modes=[cfg.train_ds.prefix, cfg.validation_ds.prefix]
        )

        # The tokenizer is prepared before the superclass constructor runs.
        self._setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        # BERT encoder
        lm_cfg = cfg.language_model
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_cfg.pretrained_model_name,
            config_file=lm_cfg.config_file,
            config_dict=OmegaConf.to_container(lm_cfg.config) if lm_cfg.config else None,
            checkpoint_file=lm_cfg.lm_checkpoint,
        )

        desc = self.data_desc
        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=desc.num_intents,
            num_slots=desc.num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # Losses: class weights are only passed when 'weighted_loss' balancing is requested.
        # You may need to increase the number of epochs for convergence when using weighted_loss.
        use_weights = cfg.class_balancing == 'weighted_loss'
        intent_loss_kwargs = {'weight': desc.intent_weights} if use_weights else {}
        slot_loss_kwargs = {'weight': desc.slot_weights} if use_weights else {}
        self.intent_loss = CrossEntropyLoss(logits_ndim=2, **intent_loss_kwargs)
        self.slot_loss = CrossEntropyLoss(logits_ndim=3, **slot_loss_kwargs)

        # The slot loss receives the complement of the intent loss weight.
        self.total_loss = AggregatorLoss(
            num_inputs=2, weights=[cfg.intent_loss_weight, 1.0 - cfg.intent_loss_weight]
        )

        # Metric trackers, one per task.
        self.intent_classification_report = ClassificationReport(
            num_classes=desc.num_intents, label_ids=desc.intents_label_ids, dist_sync_on_step=True, mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=desc.num_slots, label_ids=desc.slots_label_ids, dist_sync_on_step=True, mode='micro',
        )

    def update_data_dir(self, data_dir: str) -> None:
        """
        Stores a new data directory on the model and logs the change.
        Nothing else is recomputed here: ``self.data_desc`` keeps its current values.

        Args:
            data_dir: path to data directory
        """
        self.data_dir = data_dir
        message = f'Setting model.data_dir to {data_dir}.'
        logging.info(message)

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Encodes the inputs with the BERT model and feeds the hidden states to the classifier.
        Returns the pair ``(intent_logits, slot_logits)``.
        """
        encoded = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        head_output = self.classifier(hidden_states=encoded)
        intent_logits, slot_logits = head_output
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Runs one training batch (called by Lightning inside the training loop).
        Returns a dict with the combined ``loss`` and the current learning rate ``lr``.
        """
        input_ids, input_type_ids, input_mask, loss_mask, _subtokens_mask, intent_labels, slot_labels = batch

        # forward pass
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # intent loss and (masked) slot loss are merged into one training loss
        intent_part = self.intent_loss(logits=intent_logits, labels=intent_labels)
        slot_part = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        train_loss = self.total_loss(loss_1=intent_part, loss_2=slot_part)

        # learning rate of the first parameter group
        lr = self._optimizer.param_groups[0]['lr']

        for name, value, extra in (('train_loss', train_loss, {}), ('lr', lr, {'prog_bar': True})):
            self.log(name, value, **extra)

        return {'loss': train_loss, 'lr': lr}

    def validation_step(self, batch, batch_idx):
        """
        Runs one validation batch (called by Lightning inside the validation loop).
        Computes the combined loss, updates both classification reports and returns a
        dict with ``val_loss`` plus the tp/fn/fp counters of each report.
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # combined loss for intents and slots
        intent_part = self.intent_loss(logits=intent_logits, labels=intent_labels)
        slot_part = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_part, loss_2=slot_part)

        # intents: every prediction is scored
        intent_preds = torch.argmax(intent_logits, axis=-1)
        self.intent_classification_report.update(intent_preds, intent_labels)

        # slots: only positions whose subtokens_mask value exceeds 0.5 are scored
        selected = subtokens_mask > 0.5
        slot_preds = torch.argmax(slot_logits, axis=-1)[selected]
        self.slot_classification_report.update(slot_preds, slot_labels[selected])

        step_output = {'val_loss': val_loss}
        for task, report in (
            ('intent', self.intent_classification_report),
            ('slot', self.slot_classification_report),
        ):
            step_output[f'{task}_tp'] = report.tp
            step_output[f'{task}_fn'] = report.fn
            step_output[f'{task}_fp'] = report.fp
        return step_output

    def validation_epoch_end(self, outputs):
        """
        Aggregates the validation step outputs at the end of validation.
        :param outputs: list of individual outputs of each validation step.
        :return: dict with the mean ``val_loss`` and precision/recall/f1 for intents and slots.
        """
        mean_loss = torch.stack([step_output['val_loss'] for step_output in outputs]).mean()

        # the textual reports are logged separately for intents and slots
        intent_precision, intent_recall, intent_f1, intent_report = self.intent_classification_report.compute()
        logging.info(f'Intent report: {intent_report}')

        slot_precision, slot_recall, slot_f1, slot_report = self.slot_classification_report.compute()
        logging.info(f'Slot report: {slot_report}')

        # the same name/value pairs are both logged and returned
        metric_names = (
            'val_loss',
            'intent_precision',
            'intent_recall',
            'intent_f1',
            'slot_precision',
            'slot_recall',
            'slot_f1',
        )
        metric_values = (
            mean_loss,
            intent_precision,
            intent_recall,
            intent_f1,
            slot_precision,
            slot_recall,
            slot_f1,
        )
        metrics = dict(zip(metric_names, metric_values))
        for metric_name, metric_value in metrics.items():
            self.log(metric_name, metric_value)

        return metrics

    def test_step(self, batch, batch_idx):
        """
        Runs one test batch (called by Lightning inside the test loop).
        A test batch is handled exactly like a validation batch.
        """
        step_output = self.validation_step(batch, batch_idx)
        return step_output

    def test_epoch_end(self, outputs):
        """
        Aggregates the test step outputs at the end of testing, reusing the validation aggregation.
        :param outputs: list of individual outputs of each test step.
        """
        epoch_metrics = self.validation_epoch_end(outputs)
        return epoch_metrics

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Creates the training data loader from ``train_data_config`` and stores it in ``self._train_dl``."""
        train_loader = self._setup_dataloader_from_config(cfg=train_data_config)
        self._train_dl = train_loader

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Creates the validation data loader from ``val_data_config`` and stores it in ``self._validation_dl``."""
        validation_loader = self._setup_dataloader_from_config(cfg=val_data_config)
        self._validation_dl = validation_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Creates the test data loader from ``test_data_config`` and stores it in ``self._test_dl``."""
        test_loader = self._setup_dataloader_from_config(cfg=test_data_config)
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """Creates a DataLoader over the intent and slot files of one dataset split.

        Args:
            cfg: dataset section; ``prefix``, ``num_samples``, ``batch_size``, ``shuffle``,
                ``num_workers``, ``pin_memory`` and ``drop_last`` are read from it.

        Returns:
            A ``DataLoader`` over an ``IntentSlotClassificationDataset``.

        Raises:
            FileNotFoundError: if ``<data_dir>/<prefix>.tsv`` or
                ``<data_dir>/<prefix>_slots.tsv`` does not exist.
        """
        input_file = f'{self.data_dir}/{cfg.prefix}.tsv'
        slot_file = f'{self.data_dir}/{cfg.prefix}_slots.tsv'

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            # Adjacent literals keep the message on a single line; a backslash
            # continuation inside the literal would embed the source indentation
            # in the error text.
            raise FileNotFoundError(
                f'{input_file} or {slot_file} not found. Please refer to the documentation '
                'for the right format of Intents and Slots files.'
            )

        dataset = IntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.data_desc.pad_label,
            ignore_extra_tokens=self._cfg.ignore_extra_tokens,
            ignore_start_end=self._cfg.ignore_start_end,
        )

        # Batches are assembled by the dataset's own collate function.
        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    def _setup_infer_dataloader(self, queries: List[str], batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Builds the data loader used for inference on raw text queries.
        Batches are produced in query order (no shuffling) and the last, possibly
        smaller, batch is kept.
        Args:
            queries: text
            batch_size: batch size to use during inference
        Returns:
            A pytorch DataLoader.
        """
        infer_dataset = IntentSlotInferenceDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=-1, do_lower_case=False
        )
        loader_kwargs = dict(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self._cfg.test_ds.num_workers,
            pin_memory=self._cfg.test_ds.pin_memory,
            drop_last=False,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def predict_from_examples(self, queries: List[str], batch_size: int = 32) -> List[List[str]]:
        """
        Predicts the intent and the slots of every query.
        The model is switched to evaluation mode and moved to CUDA when it is available;
        the previous train/eval mode is restored afterwards, the device is not.
        Args:
            queries: text sequences
            batch_size: batch size to use during inference
        Returns:
            predicted_intents, predicted_slots: model intent and slot predictions;
            the slots of one query are joined into a single space-separated string
        """
        predicted_intents, predicted_slots = [], []
        was_training = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            # Switch model to evaluation mode
            self.eval()
            self.to(device)
            infer_loader = self._setup_infer_dataloader(queries, batch_size)

            # intent and slot labels are loaded from the dictionary files
            # (user should have them in a data directory)
            intent_labels, slot_labels = IntentSlotDataDesc.intent_slot_dicts(self.data_dir)

            for input_ids, input_type_ids, input_mask, _loss_mask, subtokens_mask in infer_loader:
                intent_logits, slot_logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                # intents: map every predicted id to its label
                for intent_id in tensor2list(torch.argmax(intent_logits, axis=-1)):
                    if intent_id < len(intent_labels):
                        intent_name = intent_labels[intent_id]
                    else:
                        # should not happen
                        intent_name = "Unknown Intent"
                    predicted_intents.append(intent_name)

                # slots: keep only the positions whose subtoken mask equals 1
                slot_ids = torch.argmax(slot_logits, axis=-1)
                for query_slot_ids, query_mask in zip(slot_ids, subtokens_mask):
                    slot_names = [
                        slot_labels[slot_id] if slot_id < len(slot_labels) else 'Unknown_slot'
                        for slot_id, mask_value in zip(query_slot_ids, query_mask)
                        if mask_value == 1
                    ]
                    predicted_slots.append(' '.join(slot_names).strip())

        finally:
            # set mode back to its original value
            self.train(mode=was_training)

        return predicted_intents, predicted_slots

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        Describe the pre-trained checkpoints of this model that are hosted on NVIDIA's NGC cloud.

        Returns:
            A one-element list holding the ``PretrainedModelInfo`` entry for the
            "Joint_Intent_Slot_Assistant" checkpoint.
        """
        return [
            PretrainedModelInfo(
                pretrained_model_name="Joint_Intent_Slot_Assistant",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
                description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
            )
        ]

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Export the model to ONNX as three files: the encoder, the classifier and the joined graph.

        The encoder and the classifier are exported separately next to ``output`` (file names prefixed
        with ``bert_`` and ``classifier_``), then attached to each other and saved to ``output``.

        Args:
            output: path of the combined ONNX file
            input_example: ignored (a warning is logged when it is not None)
            output_example: ignored (a warning is logged when it is not None)
            verbose, export_params, do_constant_folding, keep_initializers_as_inputs, onnx_opset_version,
            try_script, set_eval, check_trace, use_dynamic_axes: passed through unchanged to the ``export``
                call of both sub-models
        Returns:
            A tuple of two lists: the written paths ``[output, bert path, classifier path]`` and the
            matching descriptions.
        """
        examples_were_passed = not (input_example is None and output_example is None)
        if examples_were_passed:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " IntentSlotClassificationModel consists of two separate models with different"
                " inputs and outputs.")

        class_path = self.__module__ + '.' + self.__class__.__qualname__
        target_dir = os.path.dirname(output)
        target_name = os.path.basename(output)

        # Positional arguments shared by both sub-model exports. The two leading ``None`` values stand
        # for the input/output examples, which are computed by input_example().
        shared_args = (
            None,
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        bert_path = os.path.join(target_dir, 'bert_' + target_name)
        bert_model_onnx = self.bert_model.export(bert_path, *shared_args)

        classifier_path = os.path.join(target_dir, 'classifier_' + target_name)
        classifier_onnx = self.classifier.export(classifier_path, *shared_args)

        # join the two exported graphs and write the combined model to the requested path
        combined = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx, "ISC")
        onnx.save(combined, output)

        written_paths = [output, bert_path, classifier_path]
        descriptions = [
            class_path + ' BERT+Classifier exported to ONNX',
            class_path + ' BERT exported to ONNX',
            class_path + ' Classifier exported to ONNX',
        ]
        return (written_paths, descriptions)
Пример #18
0
class TextClassificationModel(NLPModel, Exportable):
    """
    Text classification model: a language-model encoder (``bert_model``) followed by a
    ``SequenceClassifier`` head, trained with a cross-entropy loss.
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types, delegated to the encoder."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Output neural types, delegated to the classifier head."""
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes the BERTTextClassifier model.

        Args:
            cfg: model config; the ``dataset``, ``tokenizer``, ``language_model``, ``classifier_head``
                and ``train_ds`` sections are read here
            trainer: forwarded unchanged to the parent constructor
        """

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        # tokenizer needs to get initialized before the super.__init__()
        # as dataloaders and datasets need it to process the data
        self._setup_tokenizer(cfg.tokenizer)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=cfg.language_model.config,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=cfg.classifier_head.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=cfg.classifier_head.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        # class weights are only computed when weighted loss is requested and a train file is given
        class_weights = None
        if cfg.dataset.class_balancing == 'weighted_loss':
            if cfg.train_ds.file_path:
                class_weights = calc_class_weights(cfg.train_ds.file_path,
                                                   cfg.dataset.num_classes)
            else:
                logging.info(
                    'Class_balancing feature is enabled but no train file is given. Calculating the class weights is skipped.'
                )

        if class_weights:
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.loss = CrossEntropyLoss(weight=class_weights)
        else:
            self.loss = CrossEntropyLoss()

        # setup to track metrics
        self.classification_report = ClassificationReport(
            cfg.dataset.num_classes)

    def _setup_tokenizer(self, cfg: DictConfig):
        """
        Create the tokenizer from the tokenizer config and store it in ``self.tokenizer``.

        The vocab file and the tokenizer model are passed through ``register_artifact``.

        Args:
            cfg: tokenizer config with ``tokenizer_name``, ``vocab_file``, ``special_tokens``
                and ``tokenizer_model``
        """
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            vocab_file=self.register_artifact(
                config_path='tokenizer.vocab_file', src=cfg.vocab_file),
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            tokenizer_model=self.register_artifact(
                config_path='tokenizer.tokenizer_model',
                src=cfg.tokenizer_model),
        )
        self.tokenizer = tokenizer

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        Runs the encoder and feeds its hidden states to the classifier head.

        Returns:
            the classifier logits
        """
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        logits = self.classifier(hidden_states=hidden_states)
        return logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            dict with the training loss under ``'loss'`` and the logged values under ``'log'``
        """
        # forward pass
        input_ids, input_type_ids, input_mask, labels = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)

        train_loss = self.loss(logits=logits, labels=labels)

        tensorboard_logs = {
            'train_loss': train_loss,
            'lr': self._optimizer.param_groups[0]['lr']
        }
        return {'loss': train_loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        The same code serves the test loop (see ``test_step``); ``self.testing`` selects whether the
        returned keys are prefixed with ``test`` or ``val``.
        """
        if self.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        input_ids, input_type_ids, input_mask, labels = batch
        logits = self.forward(input_ids=input_ids,
                              token_type_ids=input_type_ids,
                              attention_mask=input_mask)

        val_loss = self.loss(logits=logits, labels=labels)

        preds = torch.argmax(logits, axis=-1)
        tp, fp, fn = self.classification_report(preds, labels)

        tensorboard_logs = {
            f'{prefix}_loss': val_loss,
            f'{prefix}_tp': tp,
            f'{prefix}_fn': fn,
            f'{prefix}_fp': fp
        }

        return {f'{prefix}_loss': val_loss, 'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        :return: empty dict when ``outputs`` is empty, otherwise the averaged loss plus a ``'log'`` dict
            with loss, precision, recall and f1 (keys prefixed with ``test`` or ``val``).
        """
        if not outputs:
            return {}
        if self.testing:
            prefix = 'test'
        else:
            prefix = 'val'

        avg_loss = torch.stack([x[f'{prefix}_loss'] for x in outputs]).mean()
        # calculate metrics and log classification report
        tp = torch.sum(
            torch.stack([x['log'][f'{prefix}_tp'] for x in outputs]), 0)
        fn = torch.sum(
            torch.stack([x['log'][f'{prefix}_fn'] for x in outputs]), 0)
        fp = torch.sum(
            torch.stack([x['log'][f'{prefix}_fp'] for x in outputs]), 0)
        precision, recall, f1 = self.classification_report.get_precision_recall_f1(
            tp, fn, fp, mode='micro')

        tensorboard_logs = {
            f'{prefix}_loss': avg_loss,
            f'{prefix}_precision': precision,
            f'{prefix}_recall': recall,
            f'{prefix}_f1': f1,
        }
        return {f'{prefix}_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """
        Create the training data loader, or clear it when no config / ``file_path`` is given.

        Args:
            train_data_config: data loader config for the training set
        """
        if not train_data_config or not train_data_config.file_path:
            logging.info(
                "Dataloader config or file_path for the train is missing, so no data loader for train is created!"
            )
            # only the training loader is cleared; other loaders must not be touched here
            self._train_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """
        Create the validation data loader, or clear it when no config / ``file_path`` is given.

        Args:
            val_data_config: data loader config for the validation set
        """
        if not val_data_config or not val_data_config.file_path:
            logging.info(
                "Dataloader config or file_path for the validation is missing, so no data loader for validation is created!"
            )
            # only the validation loader is cleared; other loaders must not be touched here
            self._validation_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """
        Create the test data loader, or clear it when no config / ``file_path`` is given.

        Args:
            test_data_config: data loader config for the test set
        """
        if not test_data_config or not test_data_config.file_path:
            logging.info(
                f"Dataloader config or file_path for the test is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(
            self, cfg: Dict) -> 'torch.utils.data.DataLoader':
        """
        Build a data loader over the file referenced by ``cfg.file_path``.

        Args:
            cfg: data loader config; ``file_path``, ``shuffle`` and ``batch_size`` are read directly,
                ``num_samples``, ``num_workers``, ``pin_memory`` and ``drop_last`` are optional
        Returns:
            A pytorch DataLoader.
        Raises:
            FileNotFoundError: when ``cfg.file_path`` does not exist
        """
        input_file = cfg.file_path
        if not os.path.exists(input_file):
            raise FileNotFoundError(
                f'{input_file} not found! The data should be be stored in TAB-separated files \n\
                "validation_ds.file_path" and "train_ds.file_path" for train and evaluation respectively. \n\
                Each line of the files contains text sequences, where words are separated with spaces. \n\
                The label of the example is separated with TAB at the end of each line. \n\
                Each line of the files should follow the format: \n\
                [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]')

        dataset = TextClassificationDataset(
            tokenizer=self.tokenizer,
            input_file=input_file,
            max_seq_length=self.dataset_cfg.max_seq_length,
            num_samples=cfg.get("num_samples", -1),
            shuffle=cfg.shuffle,
            use_cache=self.dataset_cfg.use_cache,
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
            collate_fn=dataset.collate_fn,
        )

    @torch.no_grad()
    def classifytext(self,
                     queries: List[str],
                     batch_size: int = 1,
                     max_seq_length: int = -1) -> List[int]:
        """
        Get prediction for the queries
        Args:
            queries: text sequences
            batch_size: batch size to use during inference
            max_seq_length: sequences longer than max_seq_length will get truncated. default -1 disables truncation.
        Returns:
            all_preds: model predictions
        """
        # store predictions for all queries in a single list
        all_preds = []
        mode = self.training
        device = next(self.parameters()).device
        # Capture the verbosity before entering the try block so that the finally clause
        # can always restore it, even when switching to eval mode fails.
        logging_level = logging.get_verbosity()
        try:
            # Switch model to evaluation mode
            self.eval()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {
                "batch_size": batch_size,
                "num_workers": 3,
                "pin_memory": False
            }
            infer_datalayer = self._setup_infer_dataloader(
                dataloader_cfg, queries, max_seq_length)

            for batch in infer_datalayer:
                input_ids, input_type_ids, input_mask, subtokens_mask = batch

                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                preds = tensor2list(torch.argmax(logits, axis=-1))
                all_preds.extend(preds)
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            logging.set_verbosity(logging_level)
        return all_preds

    def _setup_infer_dataloader(
            self,
            cfg: Dict,
            queries: List[str],
            max_seq_length: int = -1) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a infer data loader.

        Args:
            cfg: config dictionary containing data loader params like batch_size, num_workers and pin_memory
            queries: text
            max_seq_length: maximum length of queries, default is -1 for no limit
        Returns:
            A pytorch DataLoader.
        """
        dataset = TextClassificationDataset(tokenizer=self.tokenizer,
                                            queries=queries,
                                            max_seq_length=max_seq_length)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg["batch_size"],
            shuffle=False,
            num_workers=cfg.get("num_workers", 0),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=False,
            collate_fn=dataset.collate_fn,
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """Placeholder: no pre-trained models are listed, the method returns None."""
        pass

    @classmethod
    def from_pretrained(cls, name: str):
        """Placeholder: does nothing and returns None."""
        pass

    def _prepare_for_export(self):
        """Delegate export preparation to the encoder."""
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Export the model to ONNX: the encoder and the classifier are exported separately, attached to
        each other and the combined graph is saved to ``output``.

        The intermediate files are written next to ``output`` with the file name prefixed by
        ``bert_`` and ``classifier_``.

        Args:
            output: path of the combined ONNX file
            input_example: ignored (a warning is logged when it is not None)
            output_example: ignored (a warning is logged when it is not None)
            verbose, export_params, do_constant_folding, keep_initializers_as_inputs, onnx_opset_version,
            try_script, set_eval, check_trace, use_dynamic_axes: passed through unchanged to the ``export``
                call of both sub-models
        """
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " TextClassificationModel consists of two separate models with different"
                " inputs and outputs.")

        # Prefix only the file name: prepending to the whole path would corrupt any
        # ``output`` that contains a directory component.
        output_dir, output_name = os.path.split(output)

        bert_model_onnx = self.bert_model.export(
            os.path.join(output_dir, 'bert_' + output_name),
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        classifier_onnx = self.classifier.export(
            os.path.join(output_dir, 'classifier_' + output_name),
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx,
                                           "CL")
        onnx.save(output_model, output)
Пример #19
0
class PunctuationCapitalizationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Input neural types of the model, delegated to the wrapped ``bert_model``."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """
        Output neural types of the model.

        Returns:
            dict with one ``('B', 'T', 'C')`` logits entry for each of ``punct_logits`` and ``capit_logits``
        """
        head_outputs = ("punct_logits", "capit_logits")
        return {name: NeuralType(('B', 'T', 'C'), LogitsType()) for name in head_outputs}

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Punctuation and Capitalization model: the encoder, one token classifier per
        task, the losses and the classification reports.

        Args:
            cfg: model config; the ``tokenizer``, ``language_model``, ``punct_head`` and ``capit_head``
                sections are read here
            trainer: forwarded unchanged to the parent constructor
        """
        # the tokenizer is set up ahead of the parent constructor
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        # --- encoder ---
        lm_cfg = cfg.language_model
        pretrained_name = lm_cfg.pretrained_model_name
        lm_config_file = self.register_artifact('language_model.config_file', lm_cfg.config_file)
        lm_config_dict = OmegaConf.to_container(lm_cfg.config) if lm_cfg.config else None
        lm_checkpoint = lm_cfg.lm_checkpoint
        vocab_file = self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file)
        self.bert_model = get_lm_model(
            pretrained_model_name=pretrained_name,
            config_file=lm_config_file,
            config_dict=lm_config_dict,
            checkpoint_file=lm_checkpoint,
            vocab_file=vocab_file,
        )

        # --- punctuation head ---
        punct_head = cfg.punct_head
        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=punct_head.activation,
            log_softmax=False,
            dropout=punct_head.fc_dropout,
            num_layers=punct_head.punct_num_fc_layers,
            use_transformer_init=punct_head.use_transformer_init,
        )

        # --- capitalization head ---
        capit_head = cfg.capit_head
        self.capit_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=capit_head.activation,
            log_softmax=False,
            dropout=capit_head.fc_dropout,
            num_layers=capit_head.capit_num_fc_layers,
            use_transformer_init=capit_head.use_transformer_init,
        )

        # one loss is shared by both tasks; the aggregator combines the two task losses
        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        punct_label_ids = self._cfg.punct_label_ids
        self.punct_class_report = ClassificationReport(
            num_classes=len(punct_label_ids), label_ids=punct_label_ids, mode='macro', dist_sync_on_step=True,
        )
        capit_label_ids = self._cfg.capit_label_ids
        self.capit_class_report = ClassificationReport(
            num_classes=len(capit_label_ids), label_ids=capit_label_ids, mode='macro', dist_sync_on_step=True,
        )

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Run the encoder once and score its hidden states with both classifier heads.

        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        Returns:
            ``(punct_logits, capit_logits)``
        """
        encoded = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        return (
            self.punct_classifier(hidden_states=encoded),
            self.capit_classifier(hidden_states=encoded),
        )

    def _make_step(self, batch):
        """
        Forward pass plus loss computation shared by the training, validation and test steps.

        Args:
            batch: 7-element batch ``(input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask,
                punct_labels, capit_labels)``
        Returns:
            ``(loss, punct_logits, capit_logits)`` where ``loss`` aggregates the two task losses
        """
        input_ids, segment_ids, attention_mask, _unused_subtokens_mask, loss_mask, punct_labels, capit_labels = batch

        punct_logits, capit_logits = self(
            input_ids=input_ids, token_type_ids=segment_ids, attention_mask=attention_mask
        )

        # both tasks are scored with the same loss module and the same loss mask
        punct_task_loss = self.loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
        capit_task_loss = self.loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
        combined_loss = self.agg_loss(loss_1=punct_task_loss, loss_2=capit_task_loss)

        return combined_loss, punct_logits, capit_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Logs the learning rate of the first optimizer param group and the training loss.

        Returns:
            dict with the loss under ``'loss'`` and the learning rate under ``'lr'``
        """
        step_loss, _punct_logits, _capit_logits = self._make_step(batch)
        current_lr = self._optimizer.param_groups[0]['lr']

        self.log('lr', current_lr, prog_bar=True)
        self.log('train_loss', step_loss)

        return {'loss': step_loss, 'lr': current_lr}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Updates both classification reports with the predictions selected by ``subtokens_mask``.

        Returns:
            dict with the validation loss and the tp / fn / fp attributes of both reports
        """
        _ids, _segments, _attention, subtokens_mask, _loss_mask, punct_labels, capit_labels = batch
        val_loss, punct_logits, capit_logits = self._make_step(batch)

        # boolean selector: positions whose mask value is above 0.5 are kept
        selected = subtokens_mask > 0.5

        # punctuation first, then capitalization
        self.punct_class_report.update(torch.argmax(punct_logits, axis=-1)[selected], punct_labels[selected])
        self.capit_class_report.update(torch.argmax(capit_logits, axis=-1)[selected], capit_labels[selected])

        punct_report, capit_report = self.punct_class_report, self.capit_class_report
        return {
            'val_loss': val_loss,
            'punct_tp': punct_report.tp,
            'punct_fn': punct_report.fn,
            'punct_fp': punct_report.fp,
            'capit_tp': capit_report.tp,
            'capit_fn': capit_report.fn,
            'capit_fp': capit_report.fp,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.

        Updates both classification reports with the predictions selected by ``subtokens_mask``.

        Returns:
            dict with the test loss and the tp / fn / fp attributes of both reports
        """
        _ids, _segments, _attention, subtokens_mask, _loss_mask, punct_labels, capit_labels = batch
        test_loss, punct_logits, capit_logits = self._make_step(batch)

        # boolean selector: positions whose mask value is above 0.5 are kept
        selected = subtokens_mask > 0.5

        # punctuation first, then capitalization
        self.punct_class_report.update(torch.argmax(punct_logits, axis=-1)[selected], punct_labels[selected])
        self.capit_class_report.update(torch.argmax(capit_logits, axis=-1)[selected], capit_labels[selected])

        punct_report, capit_report = self.punct_class_report, self.capit_class_report
        return {
            'test_loss': test_loss,
            'punct_tp': punct_report.tp,
            'punct_fn': punct_report.fn,
            'punct_fp': punct_report.fp,
            'capit_tp': capit_report.tp,
            'capit_fn': capit_report.fn,
            'capit_fp': capit_report.fp,
        }

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of validation to aggregate outputs.

        Averages the step losses, computes and logs the punctuation and capitalization reports,
        then resets both reports.

        outputs: list of individual outputs of each validation step.
        """
        mean_loss = torch.stack([step['val_loss'] for step in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        p_precision, p_recall, p_f1, p_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {p_report}')

        # calculate metrics and log classification report for Capitalization task
        c_precision, c_recall, c_f1, c_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {c_report}')

        self.log('val_loss', mean_loss, prog_bar=True)
        scalar_metrics = (
            ('punct_precision', p_precision),
            ('punct_f1', p_f1),
            ('punct_recall', p_recall),
            ('capit_precision', c_precision),
            ('capit_f1', c_f1),
            ('capit_recall', c_recall),
        )
        for metric_name, metric_value in scalar_metrics:
            self.log(metric_name, metric_value)

        # clear the accumulated statistics of both reports
        for report in (self.punct_class_report, self.capit_class_report):
            report.reset()

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of test to aggregate outputs.

        Averages the step losses, computes and logs the punctuation and capitalization reports,
        then resets both reports so that their accumulated statistics do not carry over into a
        later evaluation (the same is done at the end of validation).

        outputs: list of individual outputs of each test step.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for Capitalization task
        capit_precision, capit_recall, capit_f1, capit_report = self.capit_class_report.compute()
        logging.info(f'Capitalization report: {capit_report}')

        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('capit_precision', capit_precision)
        self.log('capit_f1', capit_f1)
        self.log('capit_recall', capit_recall)

        # start the next evaluation from clean statistics
        self.punct_class_report.reset()
        self.capit_class_report.reset()

    def update_data_dir(self, data_dir: str) -> None:
        """
        Point ``model.dataset.data_dir`` in the model config at a new location.

        Args:
            data_dir: path to data directory
        Raises:
            ValueError: when ``data_dir`` does not exist
        """
        # guard clause: refuse paths that are not present on disk
        if not os.path.exists(data_dir):
            raise ValueError(f'{data_dir} not found')

        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
        self._cfg.dataset.data_dir = data_dir

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        """
        Setup training data

        Creates the training data loader and, when ``torch.distributed`` is not initialized or the
        current rank is 0, registers the label-id files as artifacts and stores the label maps in
        the model config.

        Args:
            train_data_config: training data config; ``self._cfg.train_ds`` is used when None
        """
        ds_config = self._cfg.train_ds if train_data_config is None else train_data_config

        # for older(pre - 1.0.0.b3) configs compatibility
        lacks_class_labels = not hasattr(self._cfg, "class_labels") or self._cfg.class_labels is None
        if lacks_class_labels:
            OmegaConf.set_struct(self._cfg, False)
            self._cfg.class_labels = {}
            self._cfg.class_labels = OmegaConf.create(
                {'punct_labels_file': 'punct_label_ids.csv', 'capit_labels_file': 'capit_label_ids.csv'}
            )

        self._train_dl = self._setup_dataloader_from_config(cfg=ds_config)

        is_rank_zero = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        if not is_rank_zero:
            return

        train_dataset = self._train_dl.dataset
        self.register_artifact('class_labels.punct_labels_file', train_dataset.punct_label_ids_file)
        self.register_artifact('class_labels.capit_labels_file', train_dataset.capit_label_ids_file)

        # save label maps to the config
        self._cfg.punct_label_ids = OmegaConf.create(train_dataset.punct_label_ids)
        self._cfg.capit_label_ids = OmegaConf.create(train_dataset.capit_label_ids)

    def setup_validation_data(self, val_data_config: Optional[Dict] = None):
        """
        Setup validation data

        Args:
            val_data_config: validation data config; ``self._cfg.validation_ds`` is used when None
        """
        ds_config = self._cfg.validation_ds if val_data_config is None else val_data_config
        self._validation_dl = self._setup_dataloader_from_config(cfg=ds_config)

    def setup_test_data(self, test_data_config: Optional[Dict] = None):
        """
        Setup test data

        Args:
            test_data_config: test data config; ``self._cfg.test_ds`` is used when None
        """
        ds_config = self._cfg.test_ds if test_data_config is None else test_data_config
        self._test_dl = self._setup_dataloader_from_config(cfg=ds_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a data loader over a ``BertPunctuationCapitalizationDataset``.

        Args:
            cfg: data loader config; ``text_file``, ``labels_file``, ``num_samples``, ``batch_size`` and
                ``shuffle`` are read, plus the optional ``ds_item``
        Returns:
            A pytorch DataLoader.
        """
        # use data_dir specified in the ds_item to run evaluation on multiple datasets
        has_ds_item = 'ds_item' in cfg and cfg.ds_item is not None
        data_dir = cfg.ds_item if has_ds_item else self._cfg.dataset.data_dir

        text_file = os.path.join(data_dir, cfg.text_file)
        label_file = os.path.join(data_dir, cfg.labels_file)

        model_cfg = self._cfg
        dataset_cfg = model_cfg.dataset

        # label-id file names come from the config when it has a class_labels section
        if 'class_labels' in model_cfg:
            punct_ids_file = model_cfg.class_labels.punct_labels_file
            capit_ids_file = model_cfg.class_labels.capit_labels_file
        else:
            punct_ids_file = 'punct_label_ids.csv'
            capit_ids_file = 'capit_label_ids.csv'

        dataset = BertPunctuationCapitalizationDataset(
            tokenizer=self.tokenizer,
            text_file=text_file,
            label_file=label_file,
            pad_label=dataset_cfg.pad_label,
            punct_label_ids=model_cfg.punct_label_ids,
            capit_label_ids=model_cfg.capit_label_ids,
            max_seq_length=dataset_cfg.max_seq_length,
            ignore_extra_tokens=dataset_cfg.ignore_extra_tokens,
            ignore_start_end=dataset_cfg.ignore_start_end,
            use_cache=dataset_cfg.use_cache,
            num_samples=cfg.num_samples,
            punct_label_ids_file=punct_ids_file,
            capit_label_ids_file=capit_ids_file,
        )

        # batch size and shuffling come from the per-split config, the rest from the dataset section
        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=dataset_cfg.drop_last,
        )

    def _setup_infer_dataloader(
        self, queries: List[str], batch_size: int, max_seq_length: int, step: int, margin: int,
    ) -> torch.utils.data.DataLoader:
        """
        Setup function for a infer data loader.

        ``max_seq_length``, ``step`` and ``margin`` fall back to the values in ``self._cfg.dataset``
        when they are passed as None.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference
            max_seq_length: length of segments into which queries are split. ``max_seq_length`` includes ``[CLS]`` and
                ``[SEP]`` so every segment contains at most ``max_seq_length-2`` tokens from input a query.
            step: number of tokens by which a segment is offset to a previous segment. Parameter ``step`` cannot be greater
                than ``max_seq_length-2``.
            margin: number of tokens near the edge of a segment which label probabilities are not used in final prediction
                computation.
        Returns:
            A pytorch DataLoader.
        """
        # the config is only consulted for the values that were not supplied
        max_seq_length = self._cfg.dataset.max_seq_length if max_seq_length is None else max_seq_length
        step = self._cfg.dataset.step if step is None else step
        margin = self._cfg.dataset.margin if margin is None else margin

        infer_dataset = BertPunctuationCapitalizationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=max_seq_length, step=step, margin=margin
        )

        dataset_cfg = self._cfg.dataset
        # inference keeps the input order and every sample: no shuffling, no dropped batch
        return torch.utils.data.DataLoader(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=False,
        )

    @staticmethod
    def _remove_margins(tensor, margin_size, keep_left, keep_right):
        tensor = tensor.detach().clone()
        if not keep_left:
            tensor = tensor[margin_size + 1 :]  # remove left margin and CLS token
        if not keep_right:
            tensor = tensor[: tensor.shape[0] - margin_size - 1]  # remove right margin and SEP token
        return tensor

    def _transform_logit_to_prob_and_remove_margins_and_extract_word_probs(
        self,
        punct_logits: torch.Tensor,
        capit_logits: torch.Tensor,
        subtokens_mask: torch.Tensor,
        start_word_ids: Tuple[int],
        margin: int,
        is_first: Tuple[bool],
        is_last: Tuple[bool],
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[int]]:
        """
        Turns segment logits into per-word probabilities.

        For every segment in the batch the margins are cut off (the left margin of the first segment of a query and
        the right margin of the last segment of a query are kept), positions selected by ``subtokens_mask`` are
        extracted, and softmax is applied over the label dimension. If the left margin of a segment is removed, its
        start word index is increased by the number of words (nonzero ``subtokens_mask`` values) which fell into
        the removed part.

        Args:
            punct_logits: a float tensor of shape ``[batch_size, segment_length, number_of_punctuation_labels]``
            capit_logits: a float tensor of shape ``[batch_size, segment_length, number_of_capitalization_labels]``
            subtokens_mask: a float tensor of shape ``[batch_size, segment_length]``
            start_word_ids: indices of segment first words in a query
            margin: number of tokens near edges of a segment which probabilities are discarded
            is_first: is segment the first segment in a query
            is_last: is segment the last segment in a query
        Returns:
            b_punct_probs: list of ``batch_size`` numpy arrays of shape
                ``[number_of_word_in_this_segment, number_of_punctuation_labels]`` with word punctuation
                probabilities.
            b_capit_probs: list of ``batch_size`` numpy arrays of shape
                ``[number_of_word_in_this_segment, number_of_capitalization_labels]`` with word capitalization
                probabilities.
            new_start_word_ids: indices of segment first words in a query after margin removal
        """
        new_start_word_ids = list(start_word_ids)
        word_mask = subtokens_mask > 0.5
        b_punct_probs, b_capit_probs = [], []
        segments = zip(is_first, is_last, punct_logits, capit_logits, word_mask)
        for seg_idx, (first, last, seg_punct_logits, seg_capit_logits, seg_mask) in enumerate(segments):
            if not first:
                # Words located in the removed left part (margin and [CLS] token) shift the segment start.
                words_in_left_margin = torch.count_nonzero(seg_mask[: margin + 1]).numpy()
                new_start_word_ids[seg_idx] += words_in_left_margin
            seg_mask = self._remove_margins(seg_mask, margin, keep_left=first, keep_right=last)
            for collected, seg_logits in ((b_punct_probs, seg_punct_logits), (b_capit_probs, seg_capit_logits)):
                trimmed_logits = self._remove_margins(seg_logits, margin, keep_left=first, keep_right=last)
                word_probs = torch.nn.functional.softmax(trimmed_logits[seg_mask], dim=-1)
                collected.append(word_probs.detach().cpu().numpy())
        return b_punct_probs, b_capit_probs, new_start_word_ids

    @staticmethod
    def _move_acc_probs_to_token_preds(
        pred: List[int], acc_prob: np.ndarray, number_of_probs_to_move: int
    ) -> Tuple[List[int], np.ndarray]:
        """
        ``number_of_probs_to_move`` rows in the beginning are removed from ``acc_prob``. From every remove row the label
        with the largest probability is selected and appended to ``pred``.
        Args:
            pred: list with ready label indices for a query
            acc_prob: numpy array of shape ``[number_of_words_for_which_probabilities_are_accumulated, number_of_labels]``
            number_of_probs_to_move: int
        Returns:
            pred: list with ready label indices for a query
            acc_prob: numpy array of shape
                ``[number_of_words_for_which_probabilities_are_accumulated - number_of_probs_to_move, number_of_labels]``
        """
        if number_of_probs_to_move > acc_prob.shape[0]:
            raise ValueError(
                f"Not enough accumulated probabilities. Number_of_probs_to_move={number_of_probs_to_move} "
                f"acc_prob.shape={acc_prob.shape}"
            )
        if number_of_probs_to_move > 0:
            pred = pred + list(np.argmax(acc_prob[:number_of_probs_to_move], axis=-1))
        acc_prob = acc_prob[number_of_probs_to_move:]
        return pred, acc_prob

    @staticmethod
    def _update_accumulated_probabilities(acc_prob: np.ndarray, update: np.ndarray) -> np.ndarray:
        """
        Args:
            acc_prob: numpy array of shape ``[A, L]``
            update: numpy array of shape ``[A + N, L]``
        Returns:
            numpy array of shape ``[A + N, L]``
        """
        acc_prob = np.concatenate([acc_prob * update[: acc_prob.shape[0]], update[acc_prob.shape[0] :]], axis=0)
        return acc_prob

    def apply_punct_capit_predictions(self, query: str, punct_preds: List[int], capit_preds: List[int]) -> str:
        """
        Restores punctuation and capitalization in ``query``.

        The query is split on whitespace. A word is capitalized if its capitalization label differs from the pad
        label, and its punctuation label is appended to it if that label differs from the pad label. Words are
        joined back with single spaces.

        Args:
            query: a string without punctuation and capitalization
            punct_preds: ids of predicted punctuation labels, one per word
            capit_preds: ids of predicted capitalization labels, one per word
        Returns:
            a query with restored punctuation and capitalization
        """
        query = query.strip().split()
        assert len(query) == len(
            punct_preds
        ), f"len(query)={len(query)} len(punct_preds)={len(punct_preds)}, query[:30]={query[:30]}"
        assert len(query) == len(
            capit_preds
        ), f"len(query)={len(query)} len(capit_preds)={len(capit_preds)}, query[:30]={query[:30]}"
        # The config maps label -> id; predictions are ids, so both maps are inverted.
        punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
        capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}

        def restore(word: str, punct_id: int, capit_id: int) -> str:
            punct_label = punct_ids_to_labels[punct_id]
            capit_label = capit_ids_to_labels[capit_id]
            restored = word if capit_label == self._cfg.dataset.pad_label else word.capitalize()
            if punct_label == self._cfg.dataset.pad_label:
                return restored
            return restored + punct_label

        return ' '.join(restore(word, punct_preds[j], capit_preds[j]) for j, word in enumerate(query))

    def get_labels(self, punct_preds: List[int], capit_preds: List[int]) -> str:
        """
        Returns punctuation and capitalization labels in NeMo format (see https://docs.nvidia.com/deeplearning/nemo/
        user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#nemo-data-format).

        For every word the punctuation label is immediately followed by the capitalization label; the label pairs of
        consecutive words are separated by a single space.

        Args:
            punct_preds: ids of predicted punctuation labels
            capit_preds: ids of predicted capitalization labels
        Returns:
            labels in NeMo format
        """
        assert len(capit_preds) == len(
            punct_preds
        ), f"len(capit_preds)={len(capit_preds)} len(punct_preds)={len(punct_preds)}"
        # The config maps label -> id; predictions are ids, so both maps are inverted.
        punct_ids_to_labels = {v: k for k, v in self._cfg.punct_label_ids.items()}
        capit_ids_to_labels = {v: k for k, v in self._cfg.capit_label_ids.items()}
        word_labels = [
            punct_ids_to_labels[punct_id] + capit_ids_to_labels[capit_id]
            for capit_id, punct_id in zip(capit_preds, punct_preds)
        ]
        return ' '.join(word_labels)

    def add_punctuation_capitalization(
        self,
        queries: List[str],
        batch_size: int = None,
        max_seq_length: int = 64,
        step: int = 8,
        margin: int = 16,
        return_labels: bool = False,
    ) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for inference.

        Parameters ``max_seq_length``, ``step``, ``margin`` are for controlling the way queries are split into segments
        which then processed by the model. Parameter ``max_seq_length`` is a length of a segment after tokenization
        including special tokens [CLS] in the beginning and [SEP] in the end of a segment. Parameter ``step`` is shift
        between consecutive segments. Parameter ``margin`` is used to exclude negative effect of subtokens near
        borders of segments which have only one side context.

        If segments overlap, probabilities of overlapping predictions are multiplied and then the label
        corresponding to the maximum probability is selected.

        The model is switched to evaluation mode for the duration of the call and its previous training mode is
        restored afterwards, even if an error occurs. The forward passes run with gradient tracking disabled.

        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference. If ``None``, all queries form a single batch size
                (``len(queries)``).
            max_seq_length: maximum sequence length of segment after tokenization.
            step: relative shift of consecutive segments into which long queries are split. Long queries are split into
                segments which can overlap. Parameter ``step`` controls such overlapping. Imagine that queries are
                tokenized into characters, ``max_seq_length=5``, and ``step=2``. In such a case query "hello" is
                tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``.
            margin: number of subtokens in the beginning and the end of segments which are not used for prediction
                computation. The first segment does not have left margin and the last segment does not have right
                margin. For example, if input sequence is tokenized into characters, ``max_seq_length=5``,
                ``step=1``, and ``margin=1``, then query "hello" will be tokenized into segments
                ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'],
                ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions
                computation, margins are removed. In the next list, subtokens which logits are not used for final
                predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*],
                ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``.
            return_labels: whether to return labels in NeMo format (see https://docs.nvidia.com/deeplearning/nemo/
                user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#nemo-data-format) instead of queries
                with restored punctuation and capitalization.
        Returns:
            result: text with added capitalization and punctuation or punctuation and capitalization labels
        """
        if len(queries) == 0:
            return []
        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')
        result: List[str] = []
        mode = self.training
        try:
            self.eval()
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size, max_seq_length, step, margin)
            # Predicted labels for queries. List of labels for every query
            all_punct_preds: List[List[int]] = [[] for _ in queries]
            all_capit_preds: List[List[int]] = [[] for _ in queries]
            # Accumulated probabilities (or product of probabilities acquired from different segments) of punctuation
            # and capitalization. Probabilities for words in a query are extracted using `subtokens_mask`. Probabilities
            # for newly processed words are appended to the accumulated probabilities. If probabilities for a word are
            # already present in `acc_probs`, old probabilities are replaced with a product of old probabilities
            # and probabilities acquired from new segment. Segments are processed in an order they appear in an
            # input query. When all segments with a word are processed, a label with the highest probability
            # (or product of probabilities) is chosen and appended to an appropriate list in `all_preds`. After adding
            # prediction to `all_preds`, probabilities for a word are removed from `acc_probs`.
            acc_punct_probs: List[Optional[np.ndarray]] = [None for _ in queries]
            acc_capit_probs: List[Optional[np.ndarray]] = [None for _ in queries]
            d = self.device
            # Every model output is detached and converted to numpy below, so gradients are never needed here.
            # Disabling autograd avoids building a computation graph for each batch.
            with torch.no_grad():
                for batch in tqdm(
                    infer_datalayer, total=ceil(len(infer_datalayer.dataset) / batch_size), unit="batch"
                ):
                    inp_ids, inp_type_ids, inp_mask, subtokens_mask, start_word_ids, query_ids, is_first, is_last = batch
                    punct_logits, capit_logits = self.forward(
                        input_ids=inp_ids.to(d), token_type_ids=inp_type_ids.to(d), attention_mask=inp_mask.to(d),
                    )
                    _res = self._transform_logit_to_prob_and_remove_margins_and_extract_word_probs(
                        punct_logits, capit_logits, subtokens_mask, start_word_ids, margin, is_first, is_last
                    )
                    punct_probs, capit_probs, start_word_ids = _res
                    for q_i, start_word_id, bpp_i, bcp_i in zip(query_ids, start_word_ids, punct_probs, capit_probs):
                        for all_preds, acc_probs, b_probs_i in [
                            (all_punct_preds, acc_punct_probs, bpp_i),
                            (all_capit_preds, acc_capit_probs, bcp_i),
                        ]:
                            if acc_probs[q_i] is None:
                                # First segment of the query: nothing accumulated yet.
                                acc_probs[q_i] = b_probs_i
                            else:
                                # Words located before the start of the new segment will not be seen again:
                                # finalize their labels, then merge the new segment into the remainder.
                                all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds(
                                    all_preds[q_i], acc_probs[q_i], start_word_id - len(all_preds[q_i]),
                                )
                                acc_probs[q_i] = self._update_accumulated_probabilities(acc_probs[q_i], b_probs_i)
            # All segments are processed: finalize labels for the words which are still accumulated.
            for all_preds, acc_probs in [(all_punct_preds, acc_punct_probs), (all_capit_preds, acc_capit_probs)]:
                for q_i, (pred, prob) in enumerate(zip(all_preds, acc_probs)):
                    if prob is not None:
                        all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds(pred, prob, len(prob))
            for i, query in enumerate(queries):
                result.append(
                    self.get_labels(all_punct_preds[i], all_capit_preds[i])
                    if return_labels
                    else self.apply_punct_capit_predictions(query, all_punct_preds[i], all_capit_preds[i])
                )
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        """
        Lists the pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of ``PretrainedModelInfo`` entries, one per available pre-trained model.
        """
        return [
            PretrainedModelInfo(
                pretrained_model_name="punctuation_en_bert",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/punctuation_en_bert/versions/1.0.0rc1/files/punctuation_en_bert.nemo",
                description="The model was trained with NeMo BERT base uncased checkpoint on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
            ),
            PretrainedModelInfo(
                pretrained_model_name="punctuation_en_distilbert",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/punctuation_en_distilbert/versions/1.0.0rc1/files/punctuation_en_distilbert.nemo",
                description="The model was trained with DiltilBERT base uncased checkpoint from HuggingFace on a subset of data from the following sources: Tatoeba sentences, books from Project Gutenberg, Fisher transcripts.",
            ),
        ]

    @property
    def input_module(self):
        """Return ``self.bert_model``, which this model exposes as its input module."""
        return self.bert_model

    @property
    def output_module(self):
        """Return the model itself, which is exposed as its own output module."""
        return self
Пример #20
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None:
        """
        Initializes the model: loads the label map and semiotic classes, and creates metrics, classification heads,
        the loss function and the example builder.

        Args:
            cfg: model config. The keys read here are ``label_map``, ``semiotic_classes``, ``hidden_size`` and the
                optional ``max_sequence_len``.
            trainer: trainer passed through to the parent class
        """
        super().__init__(cfg=cfg, trainer=trainer)

        # Both vocabulary files are registered as artifacts (their existence is verified) and then loaded.
        label_map_file = self.register_artifact("label_map", cfg.label_map, verify_src_exists=True)
        semiotic_classes_file = self.register_artifact(
            "semiotic_classes", cfg.semiotic_classes, verify_src_exists=True
        )
        self.label_map = read_label_map(label_map_file)
        self.semiotic_classes = read_semiotic_classes(semiotic_classes_file)

        self.num_labels = len(self.label_map)
        self.num_semiotic_labels = len(self.semiotic_classes)
        # Inverse lookups: id -> tag object and id -> semiotic class name.
        self.id_2_tag = {idx: tagging.Tag(name) for name, idx in self.label_map.items()}
        self.id_2_semiotic = {idx: name for name, idx in self.semiotic_classes.items()}
        # Fall back to the tokenizer limit when the config does not set a maximum sequence length.
        self.max_sequence_len = cfg.get('max_sequence_len', self.tokenizer.tokenizer.model_max_length)

        # setup to track metrics
        # we will have (len(self.semiotic_classes) + 1) labels
        # last one stands for WRONG (span in which the predicted tags don't match the labels)
        # this is needed to feed the sequence of classes to classification_report during validation
        label_ids = self.semiotic_classes.copy()
        label_ids["WRONG"] = len(self.semiotic_classes)
        num_report_classes = len(self.semiotic_classes) + 1
        report_options = dict(label_ids=label_ids, mode='micro', dist_sync_on_step=True)
        self.tag_classification_report = ClassificationReport(num_report_classes, **report_options)
        self.tag_multiword_classification_report = ClassificationReport(num_report_classes, **report_options)
        self.semiotic_classification_report = ClassificationReport(num_report_classes, **report_options)

        self.hidden_size = cfg.hidden_size

        # Two single-layer token classification heads which differ only in the number of output classes.
        head_options = dict(num_layers=1, log_softmax=False, dropout=0.1)
        self.logits = TokenClassifier(self.hidden_size, num_classes=self.num_labels, **head_options)
        self.semiotic_logits = TokenClassifier(
            self.hidden_size, num_classes=self.num_semiotic_labels, **head_options
        )

        self.loss_fn = CrossEntropyLoss(logits_ndim=3)

        self.builder = bert_example.BertExampleBuilder(
            self.label_map, self.semiotic_classes, self.tokenizer.tokenizer, self.max_sequence_len
        )
Пример #21
0
class IntentSlotClassificationModel(NLPModel):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Return the input neural types of the model, delegated to ``self.bert_model.input_types``."""
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Return the output neural types of the model, delegated to ``self.classifier.output_types``."""
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Joint Intent and Slot model.

        When ``cfg.data_dir`` is set and exists, data statistics are computed from it and stored in
        ``cfg.data_desc``. Otherwise the model is put into the "being restored" state while the parent class is
        initialized and ``cfg.data_desc`` is filled with placeholder defaults.

        Args:
            cfg: model config
            trainer: trainer passed through to the parent class
        """
        self.max_seq_length = cfg.language_model.max_seq_length

        # Setup tokenizer.
        self.setup_tokenizer(cfg.tokenizer)
        self.cfg = cfg
        # Check the presence of data_dir.
        data_dir_available = cfg.data_dir and os.path.exists(cfg.data_dir)
        if data_dir_available:
            self.data_dir = cfg.data_dir
            # Update configuration of data_desc.
            self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds, cfg.validation_ds)
        else:
            # Disable setup methods.
            IntentSlotClassificationModel._set_model_restore_state(is_being_restored=True)
            # Set default values of data_desc.
            self._set_defaults_data_desc(cfg)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # Initialize Bert model. The arguments are computed in the order in which they are passed, because
        # ``register_artifact`` is a call with side effects.
        pretrained_model_name = self.cfg.language_model.pretrained_model_name
        config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file)
        if self.cfg.language_model.config:
            config_dict = OmegaConf.to_container(self.cfg.language_model.config)
        else:
            config_dict = None
        checkpoint_file = self.cfg.language_model.lm_checkpoint
        vocab_file = self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file)
        self.bert_model = get_lm_model(
            pretrained_model_name=pretrained_model_name,
            config_file=config_file,
            config_dict=config_dict,
            checkpoint_file=checkpoint_file,
            vocab_file=vocab_file,
        )

        # Enable setup methods.
        IntentSlotClassificationModel._set_model_restore_state(is_being_restored=False)

        # Initialize Classifier.
        self._reconfigure_classifier()

    def _set_defaults_data_desc(self, cfg):
        """
        Method makes sure that cfg.data_desc params are set.
        If not, set's them to "dummy" defaults.
        """
        if not hasattr(cfg, "data_desc"):
            OmegaConf.set_struct(cfg, False)
            cfg.data_desc = {}
            # Intents.
            cfg.data_desc.intent_labels = " "
            cfg.data_desc.intent_label_ids = {" ": 0}
            cfg.data_desc.intent_weights = [1]
            # Slots.
            cfg.data_desc.slot_labels = " "
            cfg.data_desc.slot_label_ids = {" ": 0}
            cfg.data_desc.slot_weights = [1]

            cfg.data_desc.pad_label = "O"
            OmegaConf.set_struct(cfg, True)

    def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds):
        """
        Creates an ``IntentSlotDataDesc`` for ``data_dir`` and copies the generated values to ``cfg.data_desc``.

        The label id maps are also saved into ``data_dir`` (file names are taken from ``cfg.class_labels``) and the
        saved files are registered as model artifacts. Struct mode of ``cfg`` is switched off for the update and is
        always switched back on, even if the update fails half-way.

        Args:
            cfg: model config to update in place
            data_dir: path to the data directory
            train_ds: training dataset config; its ``prefix`` is passed to the data descriptor
            validation_ds: validation dataset config; its ``prefix`` is passed to the data descriptor
        """
        # Save data from data desc to config - so it can be reused later, e.g. in inference.
        data_desc = IntentSlotDataDesc(data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix])
        OmegaConf.set_struct(cfg, False)
        try:
            if not hasattr(cfg, "data_desc") or cfg.data_desc is None:
                cfg.data_desc = {}
            # Intents.
            cfg.data_desc.intent_labels = list(data_desc.intents_label_ids.keys())
            cfg.data_desc.intent_label_ids = data_desc.intents_label_ids
            cfg.data_desc.intent_weights = data_desc.intent_weights
            # Slots.
            cfg.data_desc.slot_labels = list(data_desc.slots_label_ids.keys())
            cfg.data_desc.slot_label_ids = data_desc.slots_label_ids
            cfg.data_desc.slot_weights = data_desc.slot_weights

            cfg.data_desc.pad_label = data_desc.pad_label

            # for older(pre - 1.0.0.b3) configs compatibility
            if not hasattr(cfg, "class_labels") or cfg.class_labels is None:
                cfg.class_labels = OmegaConf.create(
                    {'intent_labels_file': 'intent_labels.csv', 'slot_labels_file': 'slot_labels.csv'}
                )

            slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file)
            intent_labels_file = os.path.join(data_dir, cfg.class_labels.intent_labels_file)
            self._save_label_ids(data_desc.slots_label_ids, slot_labels_file)
            self._save_label_ids(data_desc.intents_label_ids, intent_labels_file)

            self.register_artifact('class_labels.intent_labels_file', intent_labels_file)
            self.register_artifact('class_labels.slot_labels_file', slot_labels_file)
        finally:
            # Restore struct mode even when something above raised, so the config does not stay permissive.
            OmegaConf.set_struct(cfg, True)

    def _save_label_ids(self, label_ids: Dict[str, int],
                        filename: str) -> None:
        """ Saves label ids map to a file """
        with open(filename, 'w') as out:
            labels, _ = zip(*sorted(label_ids.items(), key=lambda x: x[1]))
            out.write('\n'.join(labels))
            logging.info(f'Labels: {label_ids}')
            logging.info(f'Labels mapping saved to : {out.name}')

    def _reconfigure_classifier(self):
        """
        (Re)creates the classifier head, the losses and the metrics from the current ``self.cfg.data_desc``.

        The numbers of intents and slots are taken from ``data_desc``; when ``cfg.class_balancing`` equals
        ``'weighted_loss'``, the intent and slot losses are weighted with the class weights from ``data_desc``.
        """
        data_desc = self.cfg.data_desc
        num_intents = len(data_desc.intent_labels)
        num_slots = len(data_desc.slot_labels)

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=num_intents,
            num_slots=num_slots,
            dropout=self.cfg.head.fc_dropout,
            num_layers=self.cfg.head.num_output_layers,
            log_softmax=False,
        )

        # define losses
        # You may need to increase the number of epochs for convergence when using weighted_loss
        use_class_weights = self.cfg.class_balancing == 'weighted_loss'
        intent_loss_kwargs = {'weight': data_desc.intent_weights} if use_class_weights else {}
        slot_loss_kwargs = {'weight': data_desc.slot_weights} if use_class_weights else {}
        self.intent_loss = CrossEntropyLoss(logits_ndim=2, **intent_loss_kwargs)
        self.slot_loss = CrossEntropyLoss(logits_ndim=3, **slot_loss_kwargs)

        # The total loss is a weighted sum of both losses; the two weights add up to one.
        intent_loss_weight = self.cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(num_inputs=2, weights=[intent_loss_weight, 1.0 - intent_loss_weight])

        # setup to track metrics
        self.intent_classification_report = ClassificationReport(
            num_classes=num_intents, label_ids=data_desc.intent_label_ids, dist_sync_on_step=True, mode='micro',
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=num_slots, label_ids=data_desc.slot_label_ids, dist_sync_on_step=True, mode='micro',
        )

    def update_data_dir_for_training(self, data_dir: str, train_ds,
                                     validation_ds) -> None:
        """
        Update data directory and get data stats with Data Descriptor.
        Also, reconfigures the classifier - to cope with data with e.g. different number of slots.

        Note that the classifier, losses and metrics are created anew by ``_reconfigure_classifier``.

        Args:
            data_dir: path to data directory
            train_ds: training dataset config, forwarded to ``_set_data_desc_to_cfg``
            validation_ds: validation dataset config, forwarded to ``_set_data_desc_to_cfg``
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir
        # Update configuration with new data.
        self._set_data_desc_to_cfg(self.cfg, data_dir, train_ds, validation_ds)
        # Reconfigure the classifier for different settings (number of intents, slots etc.).
        self._reconfigure_classifier()

    def update_data_dir_for_testing(self, data_dir) -> None:
        """
        Update data directory.

        Only ``self.data_dir`` is changed; the config and the classifier are left untouched.

        Args:
            data_dir: path to data directory
        """
        logging.info(f'Setting data_dir to {data_dir}.')
        self.data_dir = data_dir

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Runs the BERT encoder and the joint intent/slot classifier.

        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        Args:
            input_ids: token ids, passed to ``self.bert_model``
            token_type_ids: token type ids, passed to ``self.bert_model``
            attention_mask: attention mask, passed to ``self.bert_model``
        Returns:
            a tuple ``(intent_logits, slot_logits)`` produced by ``self.classifier`` from the encoder hidden states
        """
        hidden_states = self.bert_model(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        attention_mask=attention_mask)
        intent_logits, slot_logits = self.classifier(
            hidden_states=hidden_states)
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Runs one training step: forward pass, combined intent/slot loss, logging.

        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Returns:
            a dict with the combined training loss under ``'loss'`` and the current learning rate under ``'lr'``
        """
        # forward pass
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # calculate combined loss for intents and slots
        train_loss = self.total_loss(
            loss_1=self.intent_loss(logits=intent_logits, labels=intent_labels),
            loss_2=self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask),
        )
        # learning rate of the first parameter group of the optimizer
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)

        return {'loss': train_loss, 'lr': lr}

    def validation_step(self, batch, batch_idx):
        """
        Runs one validation step: forward pass, combined loss and update of the classification reports.

        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Returns:
            a dict with the validation loss, the tp/fn/fp statistics of the intent and slot classification reports,
            predictions and labels for intents and slots, the input ids and the boolean subtokens mask
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )

        # calculate combined loss for intents and slots
        val_loss = self.total_loss(
            loss_1=self.intent_loss(logits=intent_logits, labels=intent_labels),
            loss_2=self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask),
        )

        # calculate accuracy metrics for intents and slot reporting
        intent_report = self.intent_classification_report
        slot_report = self.slot_classification_report

        # intents
        intent_preds = torch.argmax(intent_logits, axis=-1)
        intent_report.update(intent_preds, intent_labels)

        # slots: only positions selected by the subtokens mask are scored
        subtokens_mask = subtokens_mask > 0.5
        slot_preds = torch.argmax(slot_logits, axis=-1)
        slot_report.update(slot_preds[subtokens_mask], slot_labels[subtokens_mask])

        return {
            'val_loss': val_loss,
            'intent_tp': intent_report.tp,
            'intent_fn': intent_report.fn,
            'intent_fp': intent_report.fp,
            'slot_tp': slot_report.tp,
            'slot_fn': slot_report.fn,
            'slot_fp': slot_report.fp,
            'intent_preds': intent_preds,
            'intent_labels': intent_labels,
            'slot_preds': slot_preds,
            'slot_labels': slot_labels,
            'input': input_ids,
            'subtokens_mask': subtokens_mask,
        }

    @staticmethod
    def get_continuous_slots(slot_ids, utterance_tokens):
        """
        Collapse per-word slot labels into "slot(words)" strings, one per slot name.

        Consecutive words that share a label form one span. Spans labelled 'O' are ignored. When the same
        slot name labels several separate spans, only the last of those spans is kept (reported at the
        position where the name first appeared).

        Args:
            slot_ids: list of str, the slot label of each word token, e.g.
                ['O', 'email_address', 'email_address', 'email_address', 'O', 'O', 'O', 'O']
            utterance_tokens: list of str, the words the labels refer to, e.g.
                ['enter', 'atdfd@yahoo', 'dot', 'com', 'into', 'my', 'contact', 'list']
        Returns:
            list of str where each element is a slot name-value pair
            e.g. ['email_address(atdfd@yahoo dot com)']
        """
        span_by_slot = {}
        span_start = 0
        num_labels = len(slot_ids)
        for boundary in range(1, num_labels + 1):
            # a span closes at the end of the sequence or where the label changes
            if boundary < num_labels and slot_ids[boundary] == slot_ids[span_start]:
                continue
            slot_name = slot_ids[span_start]
            if slot_name != 'O':
                # [start, exclusive end) of the span; a later span with the same name overwrites this one
                span_by_slot[slot_name] = (span_start, boundary)
            span_start = boundary

        return [
            "{}({})".format(slot_name, ' '.join(utterance_tokens[start:end]))
            for slot_name, (start, end) in span_by_slot.items()
        ]

    def get_unified_metrics(self, outputs):
        """
        Compute utterance-level slot-filling metrics over all step outputs and save the predictions to disk.

        Args:
            outputs: list of step output dicts; the keys 'slot_preds', 'slot_labels', 'subtokens_mask',
                'input', 'intent_preds' and 'intent_labels' are read from each of them
        Returns:
            tuple (slot_precision, slot_recall, slot_f1, slot_joint_goal_accuracy) as computed by
            IntentSlotMetrics.get_slot_filling_metrics
        """
        # flatten the per-batch values into one per-example list for every key
        collected = {
            key: []
            for key in ('slot_preds', 'slot_labels', 'subtokens_mask', 'input', 'intent_preds', 'intent_labels')
        }
        for output in outputs:
            for key, bucket in collected.items():
                bucket.extend(output[key])

        ground_truth_labels = self.convert_intent_ids_to_intent_names(collected['intent_labels'])
        generated_labels = self.convert_intent_ids_to_intent_names(collected['intent_preds'])

        predicted_slots = self.mask_unused_subword_slots(collected['slot_preds'], collected['subtokens_mask'])
        ground_truth_slots = self.mask_unused_subword_slots(collected['slot_labels'], collected['subtokens_mask'])

        inputs = collected['input']
        all_generated_slots = []
        all_ground_truth_slots = []
        all_utterances = []

        for idx, predicted in enumerate(predicted_slots):
            utterance_tokens = self.tokenizer.tokenizer.decode(inputs[idx], skip_special_tokens=True).split()
            ground_truth_slot_names = ground_truth_slots[idx].split()
            predicted_slot_names = predicted.split()
            if len(utterance_tokens) != len(ground_truth_slot_names):
                # fix the bug that abc@xyz get tokenized to 3 tokens and @xyz to 2 tokens
                utterance_tokens = IntentSlotClassificationModel.join_tokens_containing_at_sign(
                    utterance_tokens, ground_truth_slot_names
                )
            all_ground_truth_slots.append(
                IntentSlotClassificationModel.get_continuous_slots(ground_truth_slot_names, utterance_tokens)
            )
            all_generated_slots.append(
                IntentSlotClassificationModel.get_continuous_slots(predicted_slot_names, utterance_tokens)
            )
            all_utterances.append(' '.join(utterance_tokens))

        example_dir = self.cfg.dataset.dialogues_example_dir
        os.makedirs(example_dir, exist_ok=True)
        filename = os.path.join(self.cfg.dataset.dialogues_example_dir, "predictions.jsonl")

        # the two positional arguments before the utterances are filled with one empty string per example
        num_examples = len(generated_labels)
        IntentSlotMetrics.save_predictions(
            filename,
            generated_labels,
            all_generated_slots,
            ground_truth_labels,
            all_ground_truth_slots,
            [''] * num_examples,
            [''] * num_examples,
            all_utterances,
        )

        slot_precision, slot_recall, slot_f1, slot_joint_goal_accuracy = IntentSlotMetrics.get_slot_filling_metrics(
            all_generated_slots, all_ground_truth_slots
        )

        return slot_precision, slot_recall, slot_f1, slot_joint_goal_accuracy

    @staticmethod
    def join_tokens_containing_at_sign(utterance_tokens, slot_names):
        """
        Re-join word tokens that were split around an "@" sign so the utterance lines up with its slot names.

        Assumes the utterance contains at most one "@" token. `diff` is the number of surplus utterance
        tokens compared to `slot_names` and selects which tokens are glued onto the preceding output token:
            diff == 1: the token after "@"
            diff == 2: "@" itself and the token after it
            diff == 3: the token before "@", "@" itself and the token after it

        Every input token is kept. (Earlier versions iterated over `utterance_tokens[:-1]` for diff 2 and 3
        and therefore silently dropped the final token.) Neighbour look-ups are bounds-checked, so the first
        token is never compared against the last one, and a merge with nothing before it simply starts a
        new token.

        Args:
            utterance_tokens: list of str, whitespace-separated tokens of the decoded utterance
            slot_names: list of str, one slot name per expected word
        Returns:
            new list of str with the tokens around "@" merged
        Raises:
            ValueError: if more than one "@" token is present, or if the length difference is not 1, 2 or 3
        """
        diff = len(utterance_tokens) - len(slot_names)
        at_sign_count = sum(1 for token in utterance_tokens if token == "@")
        if at_sign_count > 1:
            raise ValueError(
                "Current method does not support utterances with more than 1 @ sign ({} encountered), please extend this method for utterance {} with slot names {}"
                .format(at_sign_count, utterance_tokens, slot_names))
        if diff not in (1, 2, 3):
            raise ValueError(
                "Difference of more than 3 ({}) encountered. please extend this method for utterance {} with slots {}"
                .format(diff, utterance_tokens, slot_names))

        last_index = len(utterance_tokens) - 1
        new_tokens = []
        for index, token in enumerate(utterance_tokens):
            follows_at = index > 0 and utterance_tokens[index - 1] == "@"
            is_at = diff >= 2 and token == "@"
            precedes_at = diff == 3 and index < last_index and utterance_tokens[index + 1] == "@"
            if new_tokens and (follows_at or is_at or precedes_at):
                new_tokens[-1] += token
            else:
                new_tokens.append(token)

        return new_tokens

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        Computes the unified slot metrics, the mean validation loss and both classification reports,
        logs every metric under its own name, resets the reports and returns the metrics.

        :param outputs: list of individual outputs of each validation step.
        :return: dict mapping each logged metric name to its value.
        """
        unified_precision, unified_recall, unified_f1, unified_joint_goal_accuracy = self.get_unified_metrics(
            outputs
        )

        avg_loss = torch.stack([step_output['val_loss'] for step_output in outputs]).mean()

        # calculate metrics and log classification report (separately for intents and slots)
        intent_precision, intent_recall, intent_f1, intent_report = self.intent_classification_report.compute()
        logging.info(f'Intent report: {intent_report}')

        slot_precision, slot_recall, slot_f1, slot_report = self.slot_classification_report.compute()
        logging.info(f'Slot report: {slot_report}')

        # insertion order here is both the logging order and the key order of the returned dict
        metrics = {
            'val_loss': avg_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
            'unified_slot_precision': unified_precision,
            'unified_slot_recall': unified_recall,
            'unified_slot_f1': unified_f1,
            'unified_slot_joint_goal_accuracy': unified_joint_goal_accuracy,
        }
        for metric_name, metric_value in metrics.items():
            self.log(metric_name, metric_value)

        self.intent_classification_report.reset()
        self.slot_classification_report.reset()

        return metrics

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config, dataset_split='train')

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config, dataset_split='dev')

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config, dataset_split='test')

    def _setup_dataloader_from_config(self, cfg: DictConfig, dataset_split: str):
        """
        Create the DataLoader for one dataset split.

        Args:
            cfg: dataloader config; batch_size, shuffle, num_workers, pin_memory and drop_last are read from it
            dataset_split: name of the split handed to DialogueBERTDataset (callers in this class pass
                'train', 'dev' or 'test')
        Returns:
            A DataLoader over a DialogueBERTDataset using the dataset's own collate_fn.
        """
        tokenizer = self.tokenizer
        data_processor = DialogueAssistantDataProcessor(self.data_dir, tokenizer)

        # self.cfg.dataset is the model.dataset cfg, which is diff from train_ds cfg etc
        dataset = DialogueBERTDataset(dataset_split, data_processor, tokenizer, self.cfg.dataset)

        loader_kwargs = {
            'batch_size': cfg.batch_size,
            'shuffle': cfg.shuffle,
            'num_workers': cfg.num_workers,
            'pin_memory': cfg.pin_memory,
            'drop_last': cfg.drop_last,
        }
        return DataLoader(dataset=dataset, collate_fn=dataset.collate_fn, **loader_kwargs)

    def _setup_infer_dataloader(self, queries: List[str], test_ds) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a infer data loader.
        Args:
            queries: text
            test_ds: dataloader settings; batch_size, shuffle, num_workers, pin_memory and drop_last
                are read from it
        Returns:
            A pytorch DataLoader.
        """
        infer_dataset = IntentSlotInferenceDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=-1, do_lower_case=False
        )

        loader_kwargs = {
            'batch_size': test_ds.batch_size,
            'shuffle': test_ds.shuffle,
            'num_workers': test_ds.num_workers,
            'pin_memory': test_ds.pin_memory,
            'drop_last': test_ds.drop_last,
        }
        return torch.utils.data.DataLoader(
            dataset=infer_dataset, collate_fn=infer_dataset.collate_fn, **loader_kwargs
        )

    def update_data_dirs(self, data_dir: str, dialogues_example_dir: str):
        """
        Point the model's dataset config at new data directories.

        Args:
            data_dir: path to data directory; must already exist
            dialogues_example_dir: path to preprocessed dialogues example directory; it is only recorded
                here and is not checked for existence
        Raises:
            ValueError: if `data_dir` does not exist
        """
        data_dir_exists = os.path.exists(data_dir)
        if not data_dir_exists:
            raise ValueError(f"{data_dir} is not found")

        dataset_cfg = self.cfg.dataset
        dataset_cfg.data_dir = data_dir
        dataset_cfg.dialogues_example_dir = dialogues_example_dir

        for message in (
            f'Setting model.dataset.data_dir to {data_dir}.',
            f'Setting model.dataset.dialogues_example_dir to {dialogues_example_dir}.',
        ):
            logging.info(message)

    def predict_from_examples(self, queries: List[str], test_ds) -> List[List[str]]:
        """
        Get prediction for the queries (intent and slots)

        The model is switched to evaluation mode and moved to CUDA when available (otherwise CPU) for
        the duration of the call. Inference runs under `torch.no_grad()`, and the original
        training/evaluation mode is restored even if an error occurs part-way through.

        Args:
            queries: text sequences
            test_ds: Dataset configuration section.
        Returns:
            predicted_intents, predicted_slots: model intent and slot predictions
        """
        predicted_intents = []
        predicted_slots = []
        mode = self.training

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        try:
            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            # Dataset.
            infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

            # predictions only need argmax of the logits, so no autograd graph is required
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask = batch

                    intent_logits, slot_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # predict intents
                    intent_preds = tensor2list(torch.argmax(intent_logits, axis=-1))
                    predicted_intents += self.convert_intent_ids_to_intent_names(intent_preds)

                    # predict slots
                    slot_preds = torch.argmax(slot_logits, axis=-1)
                    predicted_slots += self.mask_unused_subword_slots(slot_preds, subtokens_mask)
        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return predicted_intents, predicted_slots

    def convert_intent_ids_to_intent_names(self, intent_preds):
        """
        Translate numeric intent ids into intent names.

        Args:
            intent_preds: iterable of intent ids; each one is converted with `int()` and used as an index
                into `cfg.data_desc.intent_labels`
        Returns:
            list with the intent label of every id, in input order
        """
        # Retrieve the intent vocabulary from configuration.
        intent_labels = self.cfg.data_desc.intent_labels
        return [intent_labels[int(intent_num)] for intent_num in intent_preds]

    def mask_unused_subword_slots(self, slot_preds, subtokens_mask):
        """
        Convert per-token slot ids into one space-separated string of slot names per query.

        Args:
            slot_preds: per-query sequences of slot ids
            subtokens_mask: per-query sequences aligned with `slot_preds`; only positions whose mask
                value equals 1 contribute a slot name
        Returns:
            list of str, one per query, with the selected slot names joined by single spaces
        """
        # Retrieve the slot vocabulary from configuration.
        slot_labels = self.cfg.data_desc.slot_labels

        def describe(slot_ids, mask_values):
            pieces = (
                slot_labels[int(slot)] + ' '
                for slot, mask in zip(slot_ids, mask_values)
                if mask == 1
            )
            return ''.join(pieces).strip()

        return [describe(slot_ids, mask_values) for slot_ids, mask_values in zip(slot_preds, subtokens_mask)]

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models (a single entry at the moment).
        """
        return [
            PretrainedModelInfo(
                pretrained_model_name="Joint_Intent_Slot_Assistant",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
                description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
            )
        ]
class PTuneTextClassificationModel(NLPModel, Exportable):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the `forward` inputs: a list of sentences and a list of their labels."""
        sentence_type = NeuralType(('T'), StringType())
        label_type = NeuralType(('T'), StringLabel())
        return {"sentences": [sentence_type], "labels": [label_type]}

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of the `forward` outputs: the loss, the predictions and the ground-truth labels."""
        loss_type = NeuralType((), LossType())
        pred_type = NeuralType(('B'), PredictionsType())
        label_type = NeuralType(('B'), PredictionsType())
        return {"floss": loss_type, "returned_pred": pred_type, "returned_label": label_type}

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes the PTune TextClassifier model.

        Sets up model parallelism, the tokenizer, the restored Megatron GPT language model (frozen unless
        `cfg.use_lm_finetune` is set), the class-name/vocabulary-id lookup tables and the prompt encoder.

        Args:
            cfg: model config; this constructor reads `dataset.classes`, `tokenizer.*`,
                `language_model.nemo_file`, `use_lm_finetune`, `prompt_encoder.*`, `pseudo_token` and the
                optional `tensor_model_parallel_size` / `seed` entries.
            trainer: Lightning trainer. NOTE(review): the default is None, but `trainer.world_size`,
                `trainer.global_rank` and `trainer.local_rank` are dereferenced below, so a real trainer
                appears to be required - confirm.
        """
        super().__init__(cfg=cfg, trainer=trainer)

        # model-parallel state is initialized from the trainer's rank information
        initialize_model_parallel_for_nemo(
            world_size=trainer.world_size,
            global_rank=trainer.global_rank,
            local_rank=trainer.local_rank,
            tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
            seed=cfg.get('seed', 1234),
        )

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        # NOTE(review): this comment used to say the tokenizer must be initialized before
        # super().__init__() because dataloaders and datasets need it, but super().__init__() has
        # already run above - confirm the intended order.
        self.tokenizer = get_nmt_tokenizer(
            library=cfg.tokenizer.library,
            model_name=cfg.tokenizer.type,
            tokenizer_model=self.register_artifact("tokenizer.model", cfg.tokenizer.model),
            vocab_file=self.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file),
            merges_file=self.register_artifact("tokenizer.merges_file", cfg.tokenizer.merge_file),
        )

        self.class_weights = None

        # restore the pretrained GPT language model from the configured .nemo file
        self.model = MegatronGPTModel.restore_from(
            self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)),
            trainer=trainer,
        )

        # unless LM fine-tuning is requested, the restored language model is frozen
        if not cfg.use_lm_finetune:
            self.model.freeze()

        hidden_size = self.model.cfg.hidden_size

        # class names taken from the dataset config
        self.classes = cfg.dataset.classes

        # word-embedding layer of the language model, used by embed_input
        self.embeddings = self.model.model.language_model.embedding.word_embeddings

        # vocabulary of the underlying tokenizer (looked up by wrapped class token below)
        self.vocab = self.tokenizer.tokenizer.get_vocab()

        # make sure classes are part of the vocab
        # NOTE(review): a missing class is only logged here; the lookups right below index self.vocab
        # with the same key and will presumably still fail for it - confirm.
        for k in cfg.dataset.classes:
            if token_wrapper(k) not in self.vocab:
                logging.error(f'class {k} is not part of the vocabulary. Please add it to your vocab')
        self.allowed_vocab_ids = set(self.vocab[token_wrapper(k)] for k in cfg.dataset.classes)

        # lookup tables: vocab id -> class index, class name -> class index, class index -> class name
        self.allowed_vocab = {}
        self.label_ids = {}
        self.id_to_label = {}
        for i, k in enumerate(cfg.dataset.classes):
            self.allowed_vocab[self.vocab[token_wrapper(k)]] = i
            self.label_ids[k] = i
            self.id_to_label[i] = k

        self.template = cfg.prompt_encoder.template

        self.prompt_encoder = PromptEncoder(
            template=cfg.prompt_encoder.template,
            hidden_size=hidden_size,
            lstm_dropout=cfg.prompt_encoder.dropout,
            num_layers=cfg.prompt_encoder.num_layers,
        )

        # remember the LM hidden size and register the pseudo token with the tokenizer
        self.hidden_size = hidden_size
        self.tokenizer.add_special_tokens({'additional_special_tokens': [cfg.pseudo_token]})

        # the vocab is re-read here because the pseudo token was only just added
        self.pseudo_token_id = self.tokenizer.tokenizer.get_vocab()[cfg.pseudo_token]
        # fall back to the unknown-token id when the tokenizer defines no pad token
        self.pad_token_id = (
            self.tokenizer.tokenizer.pad_token_id
            if self.tokenizer.tokenizer.pad_token_id is not None
            else self.tokenizer.tokenizer.unk_token_id
        )
        # total number of pseudo (prompt) tokens per query
        self.spell_length = sum(self.template)

    def setup(self, stage):
        """
        Create the classification report used to track metrics.

        This is done here rather than in `__init__` because data_parallel_group is initialized
        when calling `fit, or test function`.

        Args:
            stage: stage name passed in by the caller (not used here)
        """
        data_parallel_group = AppState().data_parallel_group
        report_kwargs = dict(
            num_classes=len(self.classes),
            label_ids=self.label_ids,
            mode='micro',
            dist_sync_on_step=True,
            process_group=data_parallel_group,
        )
        self.classification_report = ClassificationReport(**report_kwargs)

    def embed_input(self, queries):
        """
        Embed the query token ids, substituting prompt-encoder vectors at the pseudo-token positions.

        Args:
            queries: 2-D tensor of token ids (one row per example); every row is expected to contain
                exactly `self.spell_length` pseudo tokens, otherwise the reshape below fails
        Returns:
            the word embeddings of `queries` where the i-th pseudo token of every row is replaced by
            the i-th row of the prompt encoder output
        """
        batch_size = queries.shape[0]
        is_pseudo = queries == self.pseudo_token_id

        # look up ordinary embeddings with the pseudo tokens swapped for the pad id
        lookup_ids = queries.clone()
        lookup_ids[is_pseudo] = self.pad_token_id
        raw_embeds = self.embeddings(lookup_ids)

        # column index of every pseudo token, one row of indices per example
        pseudo_columns = is_pseudo.nonzero().reshape((batch_size, self.spell_length, 2))[:, :, 1]

        replace_embeds = self.prompt_encoder()
        num_prompt_vectors = self.prompt_encoder.spell_length
        for row in range(batch_size):
            for slot in range(num_prompt_vectors):
                raw_embeds[row, pseudo_columns[row, slot], :] = replace_embeds[slot, :]
        return raw_embeds

    def get_query(self, x_h, prompt_tokens, x_t=None):
        """
        Build the token-id query for one example: prompt tokens, input text, prompt tokens, optional tail.

        Both `x_h` and `x_t` are tokenized with the same underlying tokenizer (`self.tokenizer.tokenizer`).
        Previously the `x_t` branch called `tokenize` on the wrapper object instead, which was inconsistent
        with the `x_h` branch.

        Args:
            x_h: input text; a leading space is prepended before tokenization
            prompt_tokens: list of pseudo-token ids; it is repeated `template[0]` times before and
                `template[1]` times after the input ids
            x_t: optional tail text, tokenized like `x_h` and appended at the end
        Returns:
            a list containing a single list of token ids. If the input ids plus `sum(self.template)`
            prompt tokens exceed the LM's `encoder_seq_length`, the input ids are truncated from the
            front. The tail ids are not counted in that length check.
        """
        hf_tokenizer = self.tokenizer.tokenizer

        def to_ids(text):
            return hf_tokenizer.convert_tokens_to_ids(hf_tokenizer.tokenize(' ' + text))

        max_seq_len = self.model._cfg.encoder_seq_length
        input_token_ids = to_ids(x_h)

        # number of leading input ids to drop so the input plus prompt tokens fit the LM
        overflow = len(input_token_ids) + sum(self.template) - max_seq_len
        cut = 0
        if overflow > 0:
            logging.warning("Input sequence is longer than the LM model max seq, will cut it off to fit")
            cut = overflow

        tail_ids = to_ids(x_t) if x_t is not None else []
        return [
            prompt_tokens * self.template[0]
            + input_token_ids[cut:]  # head entity
            + prompt_tokens * self.template[1]
            + tail_ids
        ]

    def get_ground_truth_labels(self, batch_size, label_ids):
        """
        Map the vocabulary id of each example's label token to its class index.

        Args:
            batch_size: number of examples to read from `label_ids`
            label_ids: 2-D tensor; column 0 holds the label's vocabulary id for each example
        Returns:
            1-D tensor of class indices on `self.device`
        """
        class_indices = [self.allowed_vocab[label_ids[row, 0].item()] for row in range(batch_size)]
        return torch.tensor(class_indices).to(self.device)

    def get_prediction(self, batch_size, label_position, logits):
        """
        Pick, for every example, the best-scoring class token at its label position.

        Args:
            batch_size: number of examples
            label_position: 2-D tensor; column 0 is the sequence position whose logits are inspected
            logits: 3-D tensor indexed as [example, position, vocab id]
        Returns:
            tuple (candidates, returned_pred): `candidates` holds, per example, up to ten allowed vocab
            ids in descending score order; `returned_pred` is a tensor on `self.device` with the class
            index of the top candidate of every example
        """
        ranked_ids = torch.argsort(logits, dim=2, descending=True)
        candidates = []
        class_indices = []
        for row in range(batch_size):
            kept = []
            for vocab_id in ranked_ids[row, label_position[row, 0]].tolist():
                if vocab_id not in self.allowed_vocab_ids:
                    continue
                kept.append(vocab_id)
                if len(kept) >= 10:
                    break
            candidates.append(kept)
            class_indices.append(self.allowed_vocab[kept[0]])
        return candidates, torch.tensor(class_indices).to(self.device)

    def get_encoder_input(self, sentences):
        """
        Turn raw sentences into the embedded input, attention mask and label positions for the LM.

        Args:
            sentences: sequence of input texts
        Returns:
            tuple (encoder_input, new_atten, label_position):
            encoder_input is the prompt-substituted token embeddings plus position embeddings;
            new_atten is a boolean (bz, 1, seq_len, seq_len) mask that is True wherever the product of
            the causal mask and the padding mask is 0;
            label_position is a (bz, 1) tensor with the number of non-pad tokens minus one, i.e. the
            index of the last non-pad token when padding only occurs at the end of a row.
        """
        # single pseudo-token id, repeated by get_query according to the template
        pseudo_prompt = [self.pseudo_token_id]

        query_tensors = [
            torch.LongTensor(self.get_query(sentence, pseudo_prompt)).squeeze(0) for sentence in sentences
        ]
        queries = pad_sequence(query_tensors, True, padding_value=self.pad_token_id).long().to(self.device)

        # attention_mask indicates the boundary of attention (False on padding)
        attention_mask = queries != self.pad_token_id
        inputs_embeds = self.embed_input(queries)

        bz, seq_len, _ = inputs_embeds.shape

        # GPT causal (lower-triangular) mask, one copy per example
        causal_mask = torch.tril(torch.ones((bz, seq_len, seq_len), device=self.device)).view(bz, 1, seq_len, seq_len)
        # multiply in the padding mask, then restore the (bz, 1, seq_len, seq_len) layout
        combined = (causal_mask.permute((1, 2, 0, 3)) * attention_mask.int()).permute((2, 0, 1, 3))
        # convert it to the boolean
        new_atten = combined < 0.5

        # position embeddings for positions 0..seq_len-1, expanded over the batch
        position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device)
        position_ids = position_ids.unsqueeze(0).expand_as(inputs_embeds[:, :, 0])
        position_embeddings = self.model.model.language_model.embedding.position_embeddings(position_ids)

        encoder_input = inputs_embeds + position_embeddings

        # position of the output token
        label_position = (attention_mask.long().sum(dim=1) - 1).unsqueeze(1)
        return encoder_input, new_atten, label_position

    def get_label_input(self, labels, label_position, seq_len):
        """
        Build the label tensor fed to the LM together with the raw label token ids.

        Args:
            labels: sequence of class-name strings, one per example
            label_position: (batch_size, 1) tensor with the sequence position of each example's label
            seq_len: sequence length of the LM input
        Returns:
            tuple (labels, label_ids): `labels` is a (batch_size, seq_len) long tensor filled with
            SMALL_LOGITS except at `label_position`, where the label token id is written; `label_ids`
            holds the label token ids reshaped to (batch_size, -1)
        """
        batch_size, _ = label_position.shape
        wrapped_labels = [token_wrapper(label) for label in labels]

        # construct label ids
        token_ids = self.tokenizer.tokenizer.convert_tokens_to_ids(wrapped_labels)
        label_ids = torch.LongTensor(token_ids).reshape((batch_size, -1)).to(self.device)

        # bz * seq_len grid of filler values with the label id scattered into its position
        label_grid = torch.zeros(batch_size, seq_len).to(self.device).fill_(SMALL_LOGITS).long()
        label_grid = label_grid.scatter_(1, label_position, label_ids)
        return label_grid, label_ids

    def forward_eval(self, sentences):
        """
        Inference-only forward pass: run the LM without labels and return the predicted class indices.

        Args:
            sentences: sequence of input texts
        Returns:
            tensor with one predicted class index per sentence
        """
        encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
        batch_size, _, seq_len, _ = new_atten.shape

        # workaround to do auto-cast: read the LM dtype from its first encoder layer
        lm_dtype = self.model.model.language_model.encoder.layers[0].dtype

        def run_language_model():
            return self.model.model(
                None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)
            )

        if lm_dtype != torch.float32:
            with torch.autocast(device_type="cuda", dtype=lm_dtype):
                logits = run_language_model()
        else:
            logits = run_language_model()

        _, returned_pred = self.get_prediction(batch_size, label_position.to(self.device), logits)
        return returned_pred

    @typecheck()
    def forward(self, sentences, labels):
        """
        Training/validation forward pass.

        Args:
            sentences: sequence of input texts
            labels: sequence of class-name strings, one per sentence
        Returns:
            tuple (floss, returned_pred, returned_label): the mean LM loss over the label positions,
            the predicted class indices and the ground-truth class indices
        """
        encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
        batch_size, _, seq_len, _ = new_atten.shape
        labels_input, label_ids = self.get_label_input(labels, label_position, seq_len)

        # workaround to do auto-cast: read the LM dtype from its first encoder layer
        lm_dtype = self.model.model.language_model.encoder.layers[0].dtype

        def run_language_model():
            return self.model.model(
                None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
            )

        if lm_dtype != torch.float32:
            with torch.autocast(device_type="cuda", dtype=lm_dtype):
                loss, logits = run_language_model()
        else:
            loss, logits = run_language_model()

        # only positions that carry a real label (not the SMALL_LOGITS filler) contribute to the loss
        label_positions_mask = labels_input != SMALL_LOGITS
        floss = (loss[label_positions_mask]).mean()

        _, returned_pred = self.get_prediction(batch_size, label_position, logits)
        returned_label = self.get_ground_truth_labels(batch_size, label_ids)
        return floss, returned_pred, returned_label

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        Args:
            batch: pair (sentences, labels)
            batch_idx: index of the batch (not used here)
        Returns:
            dict with the training loss under 'loss' and the current learning rate under 'lr'
        """
        sentences, labels = batch
        # forward pass; predictions and ground-truth labels are not needed for training
        train_loss, _preds, _gt_labels = self.forward(sentences=sentences, labels=labels)

        # learning rate of the first parameter group
        lr = self._optimizer.param_groups[0]['lr']

        step_output = {'loss': train_loss, 'lr': lr}
        self.log('train_loss', step_output['loss'])
        self.log('lr', step_output['lr'], prog_bar=True)
        return step_output

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Args:
            batch: pair (sentences, labels)
            batch_idx: index of the batch (not used here)
        Returns:
            dict with the batch loss, the tp/fn/fp values returned by the classification report and
            'hit', the number of examples whose prediction equals the ground truth
        """
        sentences, labels = batch
        val_loss, preds, gt_labels = self.forward(sentences=sentences, labels=labels)

        # count exact matches between predictions and ground truth
        hit = sum(1 for pred, gt_label in zip(preds, gt_labels) if pred == gt_label)

        tp, fn, fp, _ = self.classification_report(preds, gt_labels)

        return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp, 'hit': hit}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        Logs loss, accuracy, precision, f1 and recall under a 'test_' prefix when the trainer is
        testing and a 'val_' prefix otherwise, then resets the classification report.

        :param outputs: list of individual outputs of each validation step.
        :return: an empty dict when `outputs` is empty, otherwise None.
        """
        if not outputs:
            return {}
        prefix = 'test' if self.trainer.testing else 'val'

        avg_loss = torch.stack([step_output['val_loss'] for step_output in outputs]).mean()
        total_hit = sum(step_output['hit'] for step_output in outputs)

        # calculate metrics and classification report
        report_metric = self.classification_report
        precision, recall, f1, report = report_metric.compute()

        # accuracy = correct predictions / number of examples counted by the report
        total_data = torch.sum(report_metric.num_examples_per_class)
        accuracy = total_hit / total_data.item()

        logging.info(f'{prefix}_report: {report}')
        logging.info(f'{total_hit} correct out of {total_data}, accuracy: {accuracy*100:.2f}')

        self.log(f'{prefix}_loss', avg_loss, prog_bar=True)
        for metric_name, metric_value in (
            ('accuracy', accuracy),
            ('precision', precision),
            ('f1', f1),
            ('recall', recall),
        ):
            self.log(f'{prefix}_{metric_name}', metric_value)

        report_metric.reset()

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config or not train_data_config.file_path:
            logging.info(
                f"Dataloader config or file_path for the train is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config or not val_data_config.file_path:
            logging.info(
                f"Dataloader config or file_path for the validation is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        if not test_data_config or not test_data_config.file_path:
            logging.info(
                f"Dataloader config or file_path for the test is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: Dict) -> 'torch.utils.data.DataLoader':
        input_file = cfg.file_path
        if not os.path.exists(input_file):
            raise FileNotFoundError(
                f'{input_file} not found! The data should be be stored in TAB-separated files \n\
                "validation_ds.file_path" and "train_ds.file_path" for train and evaluation respectively. \n\
                Each line of the files contains text sequences, where words are separated with spaces. \n\
                The label of the example is separated with TAB at the end of each line. \n\
                Each line of the files should follow the format: \n\
                [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]'
            )

        dataset = PTuneTextClassificationDataset(input_file)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
            collate_fn=dataset.collate_fn,
        )

    @torch.no_grad()
    def classifytext(self, queries: List[str], batch_size: int = 1, prompt: str = 'Sentiment') -> List[int]:
        """
        Get prediction for the queries

        The model is put into evaluation mode and the logging verbosity is lowered to WARNING for the
        duration of the call; both are restored afterwards, also when an error occurs.

        Args:
            queries: text sequences
            batch_size: batch size to use during inference
            prompt: the prompt string appended at the end of your input sentence
        Returns:
            all_preds: model predictions, one entry per query, each looked up in `self.id_to_label`
            NOTE(review): the return annotation says List[int], but `id_to_label` is filled with the
            entries of `cfg.dataset.classes` in `__init__` - confirm the intended element type.
        """
        # store predictions for all queries in a single list
        all_preds = []
        # capture everything the `finally` block restores before entering `try`, so that a failure
        # early in the block cannot leave `logging_level` unbound and mask the original error
        mode = self.training
        logging_level = logging.get_verbosity()
        try:
            # Switch model to evaluation mode
            self.eval()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {"batch_size": batch_size, "num_workers": 3, "pin_memory": False}
            infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries, prompt)
            for batch in infer_datalayer:
                sentences, _ = batch
                preds = self.forward_eval(sentences)
                all_preds.extend(self.id_to_label[pred.item()] for pred in preds)
        finally:
            # set mode and verbosity back to their original values
            self.train(mode=mode)
            logging.set_verbosity(logging_level)
        return all_preds

    def _setup_infer_dataloader(self, cfg: Dict, queries: List[str], prompt: str) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a infer data loader.

        Args:
            cfg: config dictionary containing data loader params like batch_size, num_workers and pin_memory
            queries: text
            prompt: the prompt string appended at the end of your input sentence
        Returns:
            A pytorch DataLoader that neither shuffles nor drops the last incomplete batch.
        """
        infer_dataset = PTuneTextClassificationDataset(None, queries, prompt)
        loader_kwargs = {
            'batch_size': cfg["batch_size"],
            'shuffle': False,
            'num_workers': cfg.get("num_workers", 0),
            'pin_memory': cfg.get("pin_memory", False),
            'drop_last': False,
        }
        return torch.utils.data.DataLoader(
            dataset=infer_dataset, collate_fn=infer_dataset.collate_fn, **loader_kwargs
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
Пример #23
0
class DuplexTaggerModel(NLPModel):
    """
    Transformer-based (duplex) tagger model for TN/ITN.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the tagger: tokenizer, token-classification transformer, loss and metric.

        Args:
            cfg: model config; the `tokenizer` and `transformer` entries are read directly,
                while `mode`, `max_sequence_len` and `lang` are optional
            trainer: optional trainer forwarded to the parent constructor
        """
        # The tokenizer is created before the parent constructor runs.
        # NOTE(review): presumably the parent's setup hooks already need it -- confirm.
        self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer,
                                                        add_prefix_space=True)
        super().__init__(cfg=cfg, trainer=trainer)
        # one output class per tag label
        self.num_labels = len(constants.ALL_TAG_LABELS)
        self.mode = cfg.get('mode', 'joint')

        self.model = AutoModelForTokenClassification.from_pretrained(
            cfg.transformer, num_labels=self.num_labels)
        self.transformer_name = cfg.transformer
        # fall back to the tokenizer's own limit when the config gives none
        self.max_sequence_len = cfg.get('max_sequence_len',
                                        self._tokenizer.model_max_length)

        # Loss Functions
        # positions labelled with the padding id do not contribute to the loss
        self.loss_fct = nn.CrossEntropyLoss(
            ignore_index=constants.LABEL_PAD_TOKEN_ID)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            self.num_labels,
            constants.LABEL_IDS,
            mode='micro',
            dist_sync_on_step=True)

        # Language
        self.lang = cfg.get('lang', None)

    # Training
    def training_step(self, batch, batch_idx):
        """
        Compute the tagging loss for one training batch (invoked by Lightning inside
        the training loop with a batch from the training dataloader).

        Returns a dict holding the loss and the current learning rate.
        """
        # Run the transformer and keep only the per-token logits
        model_output = self.model(batch['input_ids'], batch['attention_mask'])
        flat_logits = model_output.logits.view(-1, self.num_labels)
        flat_labels = batch['labels'].view(-1)

        # Loss over the flattened tokens
        loss_value = self.loss_fct(flat_logits, flat_labels)

        # learning rate of the first parameter group
        current_lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', loss_value)
        self.log('lr', current_lr, prog_bar=True)
        return {'loss': loss_value, 'lr': current_lr}

    # Validation and Testing
    def validation_step(self, batch, batch_idx):
        """
        Score one validation batch (invoked by Lightning inside the validation loop)
        and feed the non-padding positions into the classification report.
        """
        # Apply Transformer and take the highest-scoring tag per token
        logits = self.model(batch['input_ids'], batch['attention_mask']).logits
        batch_preds = torch.argmax(logits, dim=2).tolist()
        batch_golds = batch['labels'].tolist()

        pad_id = constants.LABEL_PAD_TOKEN_ID
        for sent_preds, sent_golds in zip(batch_preds, batch_golds):
            # keep only the positions whose gold label is not padding
            kept = [(pred, gold)
                    for pred, gold in zip(sent_preds, sent_golds)
                    if gold != pad_id]
            kept_preds = [pred for pred, _ in kept]
            kept_golds = [gold for _, gold in kept]
            self.classification_report(
                torch.tensor(kept_preds).to(self.device),
                torch.tensor(kept_golds).to(self.device))

    def validation_epoch_end(self, outputs):
        """
        Finalize the validation metrics after all validation steps have run.
        :param outputs: list of individual outputs of each validation step (not read here).
        """
        # only the precision and the printable report are used
        token_precision, _recall, _f1, report_text = self.classification_report.compute()

        logging.info(report_text)
        self.log('val_token_precision', token_precision)

        # start the next validation run from a clean state
        self.classification_report.reset()

    def test_step(self, batch, batch_idx):
        """
        Handle a test batch exactly like a validation batch (invoked by Lightning
        inside the test loop with a batch from the test dataloader).
        """
        step_result = self.validation_step(batch, batch_idx)
        return step_result

    def test_epoch_end(self, outputs):
        """
        Aggregate the test outputs by reusing the validation aggregation.
        :param outputs: list of individual outputs of each test step.
        """
        epoch_result = self.validation_epoch_end(outputs)
        return epoch_result

    # Functions for inference
    @torch.no_grad()
    def _infer(self,
               sents: List[List[str]],
               inst_directions: List[str],
               do_basic_tokenization=True):
        """ Main function for Inference

        Args:
            sents: A list of inputs tokenized by a basic tokenizer.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance
                (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).
            do_basic_tokenization: whether to do a pre-processing to separate punctuation marks, recommended to set to True

        Returns:
            all_tag_preds: A list of list where each list contains the raw tag predictions for the corresponding input words in sents.
            nb_spans: A list of ints where each int indicates the number of semiotic spans in input words.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in input words.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in input words.

        Raises:
            ValueError: if an entry of inst_directions is neither constants.INST_BACKWARD nor constants.INST_FORWARD.
        """
        self.eval()

        # Append prefix
        texts = []
        for ix, sent in enumerate(sents):
            direction = inst_directions[ix]
            if direction == constants.INST_BACKWARD:
                prefix = constants.ITN_PREFIX
            elif direction == constants.INST_FORWARD:
                prefix = constants.TN_PREFIX
            else:
                # without this branch the prefix of the previous sentence would be reused silently
                raise ValueError(f'Unknown instance direction: {direction!r}')
            if do_basic_tokenization:
                texts.append([prefix] + sent)
            else:
                texts.append(prefix + " " + sent)

        # Apply the model
        encodings = self._tokenizer(texts,
                                    is_split_into_words=bool(do_basic_tokenization),
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt')

        inputs = encodings
        encodings_reduced = None

        # check that the length of the 'input_ids' equals as least the length of the original input
        # if an input symbol is missing in the tokenizer's vocabulary (such as emoji or a Chinese character), it could be skipped
        if do_basic_tokenization:
            len_texts = [len(x) for x in texts]
        else:
            len_texts = [len(x.split()) for x in texts]
        len_ids = [
            len(
                self._tokenizer.convert_ids_to_tokens(
                    x, skip_special_tokens=True))
            for x in encodings['input_ids']
        ]
        idx_valid = [
            i for i, (t, enc) in enumerate(zip(len_texts, len_ids)) if enc >= t
        ]
        # the list is kept for tensor indexing, the set gives O(1) membership tests
        valid_set = set(idx_valid)

        if len(idx_valid) != len(texts):
            logging.warning(
                'Some of the examples have symbols that were skipped during the tokenization. Such examples will be skipped.'
            )
            for i, text in enumerate(texts):
                if i not in valid_set:
                    logging.warning(f'Invalid input: {text}')
            # skip these sentences and fall back to the input
            # exclude invalid examples from the encodings
            encodings_reduced = {
                k: tensor[idx_valid, :]
                for k, tensor in encodings.items()
            }
            for k, tensor in encodings_reduced.items():
                if tensor.ndim == 1:
                    encodings_reduced[k] = tensor.unsqueeze(dim=0)
            inputs = BatchEncoding(data=encodings_reduced)

        # skip the batch if no valid inputs are present
        if encodings_reduced and encodings_reduced['input_ids'].numel() == 0:
            # -1 to exclude tag for the prompt token; len_texts counts words in both tokenization modes
            all_tag_preds = [[constants.SAME_TAG] * (n_words - 1)
                             for n_words in len_texts]
            nb_spans = [0] * len(texts)
            # one (empty) span list per input, matching the shape of the regular return value
            span_starts = [[] for _ in texts]
            span_ends = [[] for _ in texts]
            return all_tag_preds, nb_spans, span_starts, span_ends

        logits = self.model(**inputs.to(self.device)).logits
        pred_indexes = torch.argmax(logits, dim=-1).tolist()

        # Extract all_tag_preds for words
        all_tag_preds = []
        batch_size = encodings['input_ids'].size(0)
        # pred_idx walks the (possibly reduced) model output, ix walks the original batch
        pred_idx = 0
        for ix in range(batch_size):
            if ix in valid_set:
                # remove first special token and task prefix token
                raw_tag_preds = [
                    constants.ALL_TAG_LABELS[p]
                    for p in pred_indexes[pred_idx][2:]
                ]
                tag_preds, previous_word_idx = [], None
                word_ids = encodings.word_ids(batch_index=ix)[2:]
                for jx, word_idx in enumerate(word_ids):
                    if word_idx is None:
                        continue
                    # only the first sub-token of each word contributes a tag
                    if word_idx != previous_word_idx:
                        tag_preds.append(raw_tag_preds[jx])
                    previous_word_idx = word_idx
                pred_idx += 1
            else:
                # for excluded examples, use SAME tags for all words (prefix token excluded)
                tag_preds = [constants.SAME_TAG] * (len_texts[ix] - 1)
            all_tag_preds.append(tag_preds)

        # Post-correction of simple tagger mistakes, i.e. I- tag is proceeding the B- tag in a span
        all_tag_preds = [
            self._postprocess_tag_preds(words, inst_dir, ps)
            for words, inst_dir, ps in zip(sents, inst_directions, all_tag_preds)
        ]

        # Decoding
        nb_spans, span_starts, span_ends = self.decode_tag_preds(all_tag_preds)
        return all_tag_preds, nb_spans, span_starts, span_ends

    def _postprocess_tag_preds(self, words: List[str], inst_dir: str,
                               preds: List[str]):
        """ Repair obvious mistakes in the raw tag predictions of the model, such
        as a TRANSFORM span that opens with I_TRANSFORM_TAG instead of
        B_TRANSFORM_TAG.

        Args:
            words: The words in the input sentence
            inst_dir: The direction of the instance (i.e., constants.INST_BACKWARD or INST_FORWARD).
            preds: The raw tag predictions

        Returns: The processed raw tag predictions
        """
        begin_transform = constants.B_PREFIX + constants.TRANSFORM_TAG
        inside_transform = constants.I_PREFIX + constants.TRANSFORM_TAG
        begin_task = constants.B_PREFIX + constants.TASK_TAG
        begin_same = constants.B_PREFIX + constants.SAME_TAG
        is_forward = inst_dir == constants.INST_FORWARD

        fixed = []
        for pos, tag in enumerate(preds):
            if tag == inside_transform and (
                    pos == 0 or constants.TRANSFORM_TAG not in fixed[pos - 1]):
                # an inside tag with no TRANSFORM tag before it starts a new span
                fixed.append(begin_transform)
            elif (is_forward and has_numbers(words[pos])
                  and constants.TRANSFORM_TAG not in tag):
                # for TN, a word with numbers that was not tagged TRANSFORM
                fixed.append(begin_transform)
            elif tag == begin_task:
                # B-TASK is converted to B-SAME
                fixed.append(begin_same)
            else:
                fixed.append(tag)
        return fixed

    def decode_tag_preds(self, tag_preds: List[List[str]]):
        """ Locate the semiotic spans in the input texts from the raw tag
        predictions.

        Args:
            tag_preds: A list of list where each list contains the raw tag predictions for the corresponding input words.

        Returns:
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input words.
            span_ends: A list of lists where each list contains the inclusive ending locations of semiotic spans in an input words.
        """
        begin_tag = constants.B_PREFIX + constants.TRANSFORM_TAG
        inside_tag = constants.I_PREFIX + constants.TRANSFORM_TAG

        nb_spans, span_starts, span_ends = [], [], []
        for sentence_tags in tag_preds:
            starts, ends = [], []
            open_start = None
            # the trailing 'EOS' sentinel closes a span that runs to the last word
            for pos, tag in enumerate(sentence_tags + ['EOS']):
                if tag == inside_tag:
                    # an inside tag neither opens nor closes anything
                    continue
                if open_start is not None:
                    starts.append(open_start)
                    ends.append(pos - 1)
                open_start = pos if tag == begin_tag else None
            nb_spans.append(len(starts))
            span_starts.append(starts)
            span_ends.append(ends)
        return nb_spans, span_starts, span_ends

    # Functions for processing data
    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """Create the training data loader, or set it to None when the config or its data_path is missing."""
        if train_data_config and train_data_config.data_path:
            self._train_dl = self._setup_dataloader_from_config(
                cfg=train_data_config, data_split="train")
            return
        logging.info(
            "Dataloader config or file_path for the train is missing, so no data loader for train is created!"
        )
        self._train_dl = None

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """Create the validation data loader, or set it to None when the config or its data_path is missing."""
        if val_data_config and val_data_config.data_path:
            self._validation_dl = self._setup_dataloader_from_config(
                cfg=val_data_config, data_split="val")
            return
        logging.info(
            "Dataloader config or file_path for the validation is missing, so no data loader for validation is created!"
        )
        self._validation_dl = None

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Create the test data loader, or set it to None when the config is missing or its data_path is None."""
        if test_data_config and test_data_config.data_path is not None:
            self._test_dl = self._setup_dataloader_from_config(
                cfg=test_data_config, data_split="test")
            return
        logging.info(
            "Dataloader config or file_path for the test is missing, so no data loader for test is created!"
        )
        self._test_dl = None

    def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str):
        """
        Build the tagger dataset described by `cfg` and wrap it in a data loader.

        Args:
            cfg: data config; `data_path`, `do_basic_tokenize`, `batch_size` and `shuffle`
                are read directly, while `tagger_data_augmentation`, `use_cache` and
                `max_insts` are optional
            data_split: name of the split, used only for the log message
        Returns:
            A pytorch DataLoader over the dataset; the elapsed build time is logged.
        """
        started_at = perf_counter()
        logging.info(f'Creating {data_split} dataset')

        data_file = cfg.data_path
        use_augmentation = cfg.get('tagger_data_augmentation', False)
        tagger_dataset = TextNormalizationTaggerDataset(
            input_file=data_file,
            tokenizer=self._tokenizer,
            tokenizer_name=self.transformer_name,
            mode=self.mode,
            do_basic_tokenize=cfg.do_basic_tokenize,
            tagger_data_augmentation=use_augmentation,
            lang=self.lang,
            max_seq_length=self.max_sequence_len,
            use_cache=cfg.get('use_cache', False),
            max_insts=cfg.get('max_insts', -1),
        )
        collator = DataCollatorForTokenClassification(self._tokenizer)
        loader = torch.utils.data.DataLoader(
            dataset=tagger_dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            collate_fn=collator,
        )

        elapsed = perf_counter() - started_at
        logging.info(f'Took {elapsed} seconds')
        return loader

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        List the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            An empty list: no pre-trained models are registered here.
        """
        return []
Пример #24
0
    def validation_epoch_end(self, outputs):
        """
        Get metrics based on the candidate label with the highest predicted likelihood and the ground truth label for intent

        The decoded inputs are grouped by utterance; within each group the candidate with
        the highest entailment logit is the prediction and the candidate with the highest
        label value is the ground truth. Predictions are saved to "test_predictions.jsonl"
        and precision/recall/F1/accuracy plus the mean validation loss are logged.

        Args:
            outputs: list of per-step dicts; the 'logits', 'input_ids', 'labels' and
                'val_loss' entries are read

        Raises:
            ValueError: if cfg.library is neither 'huggingface' nor 'megatron'
        """
        output_logits = torch.cat([output['logits'] for output in outputs], dim=0)
        output_input_ids = torch.cat([output['input_ids'] for output in outputs], dim=0)
        output_labels = torch.cat([output['labels'] for output in outputs], dim=0)

        # the two libraries differ only in where the underlying tokenizer lives and in
        # which logit column is read as the entailment score
        library = self.cfg.library
        if library == 'huggingface':
            tokenizer, entail_column = self.tokenizer, 2
        elif library == 'megatron':
            tokenizer, entail_column = self.tokenizer.tokenizer, 1
        else:
            raise ValueError(
                f"Unsupported library {library!r}: expected 'huggingface' or 'megatron'")

        entail_logits = output_logits[..., entail_column]
        decoded_input_ids = [tokenizer.decode(ids) for ids in output_input_ids]
        utterance_candidate_pairs = [
            text.split(tokenizer.sep_token) for text in decoded_input_ids
        ]
        all_utterances = [
            pair[0].replace(tokenizer.bos_token, '').replace(tokenizer.eos_token, '')
            for pair in utterance_candidate_pairs
        ]

        # account for uncased tokenization
        prompt_template = self.cfg.dataset.prompt_template
        candidates = [
            pair[1].replace(prompt_template.lower(), '').replace(prompt_template, '').strip()
            for pair in utterance_candidate_pairs
        ]

        # group the row indices that belong to the same utterance
        utterance_to_idx = defaultdict(list)
        for idx, utterance in enumerate(all_utterances):
            utterance_to_idx[utterance].append(idx)

        predicted_labels = []
        ground_truth_labels = []
        utterances = []
        for utterance, idxs in utterance_to_idx.items():
            utterance_candidates = [candidates[idx] for idx in idxs]
            logits = [entail_logits[idx].item() for idx in idxs]
            labels = [output_labels[idx].item() for idx in idxs]
            predicted_labels.append(utterance_candidates[np.argmax(logits)])
            ground_truth_labels.append(utterance_candidates[np.argmax(labels)])
            utterances.append(utterance)

        os.makedirs(self.cfg.dataset.dialogues_example_dir, exist_ok=True)
        filename = os.path.join(self.cfg.dataset.dialogues_example_dir,
                                "test_predictions.jsonl")

        DialogueGenerationMetrics.save_predictions(
            filename,
            predicted_labels,
            ground_truth_labels,
            utterances,
        )

        # sorted() makes the label -> id mapping reproducible; plain set iteration order
        # for strings can differ between processes
        label_to_ids = {
            label: idx
            for idx, label in enumerate(
                sorted(set(predicted_labels + ground_truth_labels)))
        }
        device = output_logits.device
        self.classification_report = ClassificationReport(
            num_classes=len(label_to_ids),
            mode='micro',
            label_ids=label_to_ids,
            dist_sync_on_step=True).to(device)
        predicted_label_ids = torch.tensor(
            [label_to_ids[label] for label in predicted_labels]).to(device)
        ground_truth_label_ids = torch.tensor(
            [label_to_ids[label] for label in ground_truth_labels]).to(device)

        # the call updates the report state; its return value is not needed here
        self.classification_report(predicted_label_ids, ground_truth_label_ids)
        precision, recall, f1, report = self.classification_report.compute()
        label_acc = np.mean([
            int(predicted == truth)
            for predicted, truth in zip(predicted_labels, ground_truth_labels)
        ])

        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        logging.info(report)

        self.log('unified_precision', precision)
        self.log('unified_f1', f1)
        self.log('unified_recall', recall)
        # NOTE(review): the key spelling 'unfied_accuracy' is kept unchanged because
        # external configs may monitor it -- confirm before renaming
        self.log('unfied_accuracy', label_acc * 100)
        self.log('val_loss', avg_loss, prog_bar=True)

        self.classification_report.reset()
Пример #25
0
class DialogueZeroShotIntentModel(TextClassificationModel):
    """TextClassificationModel to be trained on two- or three-class textual entailment data, to be used for zero shot intent recognition."""
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Build the model and then replace its components depending on `cfg.library`.

        Args:
            cfg: model config; `library`, `original_nemo_checkpoint` and (for the
                huggingface branch) `dataset.max_seq_length` are read here
            trainer: optional trainer forwarded to the parent constructor
        """
        self.cfg = cfg
        super().__init__(cfg=cfg, trainer=trainer)

        if self.cfg.library == 'megatron':
            # zero shot intent classification loading
            # cannot directly load as .nemo uses the pre-refactor model
            # therefore transfer its attributes over
            if self.cfg.original_nemo_checkpoint is not None:
                original_model = DialogueZeroShotIntentModel.restore_from(
                    self.cfg.original_nemo_checkpoint)
                self.classifier = original_model.classifier
                self.bert_model = original_model.bert_model
                self.loss = original_model.loss
                self.classification_report = original_model.classification_report
        elif self.cfg.library == "huggingface":
            # encoder and classification head both come from the pretrained NLI model
            self.nli_model = AutoModelForSequenceClassification.from_pretrained(
                'facebook/bart-large-mnli')
            self.bert_model = self.nli_model.model
            self.classifier = self.nli_model.classification_head
            # loss and metric are still taken from the restored checkpoint
            # NOTE(review): unlike the megatron branch there is no None check on
            # original_nemo_checkpoint here -- confirm this is intended
            original_model = DialogueZeroShotIntentModel.restore_from(
                self.cfg.original_nemo_checkpoint)
            self.loss = original_model.loss
            self.classification_report = original_model.classification_report
            self.tokenizer = AutoTokenizer.from_pretrained(
                'facebook/bart-large-mnli')
            self.tokenizer.max_seq_length = self.cfg.dataset.max_seq_length

    def _setup_dataloader_from_config(
            self, cfg: DictConfig,
            dataset_split) -> 'torch.utils.data.DataLoader':
        """
        Create the data processor for the configured task, build the zero shot intent
        dataset for `dataset_split` and wrap it in a data loader.

        Args:
            cfg: data loader config; `batch_size` and `shuffle` are read directly, while
                `num_workers`, `pin_memory` and `drop_last` are optional
            dataset_split: the split handed to the dataset
        Returns:
            A pytorch DataLoader.
        Raises:
            ValueError: if the configured task is not zero_shot, design or sgd.
        """
        task = self._cfg.dataset.task
        if task == "zero_shot":
            # NOTE(review): this branch reads cfg.data_dir while the others read
            # cfg.dataset.data_dir -- confirm this is intended
            processor = DialogueAssistantDataProcessor(
                self.cfg.data_dir, self.tokenizer, cfg=self.cfg.dataset)
        elif task == "design":
            processor = DialogueDesignDataProcessor(
                data_dir=self._cfg.dataset.data_dir,
                tokenizer=self.tokenizer,
                cfg=self._cfg.dataset)
        elif task == 'sgd':
            processor = DialogueSGDDataProcessor(
                data_dir=self._cfg.dataset.data_dir,
                dialogues_example_dir=self._cfg.dataset.dialogues_example_dir,
                tokenizer=self.tokenizer,
                cfg=self._cfg.dataset,
            )
        else:
            raise ValueError(
                "Only zero_shot, design and sgd supported for Zero Shot Intent Model"
            )
        self.data_processor = processor

        # the last argument is the model.dataset cfg, which is diff from train_ds cfg etc
        intent_dataset = DialogueZeroShotIntentDataset(
            dataset_split,
            self.data_processor,
            self.tokenizer,
            self.cfg.dataset,
        )

        loader_kwargs = {
            "collate_fn": intent_dataset.collate_fn,
            "batch_size": cfg.batch_size,
            "shuffle": cfg.shuffle,
            "num_workers": cfg.get("num_workers", 0),
            "pin_memory": cfg.get("pin_memory", False),
            "drop_last": cfg.get("drop_last", False),
        }
        return torch.utils.data.DataLoader(dataset=intent_dataset, **loader_kwargs)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """
        Compute the classification logits with the configured library.

        Args:
            input_ids: token ids passed to the underlying model
            attention_mask: attention mask passed to the underlying model
            token_type_ids: token type ids; only the megatron branch passes them on

        Returns:
            The logits produced by the classifier (megatron) or the NLI model (huggingface).

        Raises:
            ValueError: if cfg.library is neither 'megatron' nor 'huggingface'
        """
        library = self.cfg.library
        if library == 'megatron':
            hidden_states = self.bert_model(input_ids=input_ids,
                                            token_type_ids=token_type_ids,
                                            attention_mask=attention_mask)
            # when the encoder returns a tuple, its first element is used as hidden states
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]
            return self.classifier(hidden_states=hidden_states)
        if library == 'huggingface':
            output = self.nli_model(input_ids=input_ids,
                                    attention_mask=attention_mask)
            return output['logits']
        # previously this fell through to an unbound local variable
        raise ValueError(
            f"Unsupported library {library!r}: expected 'megatron' or 'huggingface'")

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        """
        Create the training data loader and (re)build the loss module.

        When no config is given, the training data loader is set to None and nothing
        else happens. Otherwise class weights are computed from the new data loader if
        `cfg.dataset.class_balancing` is 'weighted_loss', and the loss module is recreated.

        Args:
            train_data_config: data loader config for the training split, or None
        """
        if not train_data_config:
            logging.info(
                "Dataloader config or file_name for the training set is missing, so no data loader for train is created!"
            )
            # clear the training loader itself; the test loader must stay untouched
            self._train_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(
            train_data_config, "train")

        # calculate the class weights to be used in the loss function
        if self.cfg.dataset.class_balancing == 'weighted_loss':
            self.class_weights = calc_class_weights_from_dataloader(
                self._train_dl, self.cfg.dataset.num_classes,
                self.cfg.dataset.data_dir)
        else:
            self.class_weights = None
        # we need to create/update the loss module by using the weights calculated from the training data
        self.create_loss_module()

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        """
        Create the validation data loader (split "dev"), or set it to None when no config is given.

        Args:
            val_data_config: data loader config for the validation split, or None
        """
        if not val_data_config:
            logging.info(
                "Dataloader config or file_path for the validation data set is missing, so no data loader for validation is created!"
            )
            # clear the validation loader itself; the test loader must stay untouched
            self._validation_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(
            val_data_config, "dev")

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        """Create the test data loader, or set it to None when no config is given."""
        if test_data_config:
            self._test_dl = self._setup_dataloader_from_config(
                test_data_config, "test")
            return
        logging.info(
            "Dataloader config or file_path for the test data set is missing, so no data loader for test is created!"
        )
        self._test_dl = None

    def _setup_infer_dataloader(
        self,
        queries: List[str],
        candidate_labels: List[str],
        hypothesis_template: str = 'This example is {}.',
        batch_size=1,
        max_seq_length: int = -1,
    ) -> 'torch.utils.data.DataLoader':
        """
        Setup method for inference data loader. Here the premise-hypothesis pairs are made from queries and candidate labels.

        Args:
            queries: the queries to classify
            candidate_labels: strings to be used as labels
            hypothesis_template: the template used to turn each label into an NLI-style hypothesis. Must include a {}
                or similar syntax for the candidate label to be inserted. The default matches the default
                template of `predict`; previously the default was the `str` type itself.
            batch_size: batch size to use during inference
            max_seq_length: maximum length of queries, default is -1 for no limit
        Returns:
            A pytorch DataLoader.
        """
        dataset = ZeroShotIntentInferenceDataset(
            queries=queries,
            candidate_labels=candidate_labels,
            tokenizer=self.tokenizer,
            max_seq_length=max_seq_length,
            hypothesis_template=hypothesis_template,
        )

        # inference keeps the input order and every sample
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=False,
            drop_last=False,
            collate_fn=dataset.collate_fn,
        )

    def validation_step(self, batch, batch_idx):
        """
        Evaluate one validation batch (invoked by Lightning inside the validation loop).

        Returns a dict with the loss, the tp/fn/fp values returned by the classification
        report, and the logits, input ids and labels of the batch.
        """
        input_ids, input_type_ids, input_mask, labels = batch
        logits = self.forward(
            input_ids=input_ids,
            token_type_ids=input_type_ids,
            attention_mask=input_mask,
        )
        val_loss = self.loss(logits=logits, labels=labels)

        # highest-scoring class per example
        preds = torch.argmax(logits, dim=-1)
        tp, fn, fp, _ = self.classification_report(preds, labels)

        step_output = {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp}
        step_output.update(logits=logits, input_ids=input_ids, labels=labels)
        return step_output

    def validation_epoch_end(self, outputs):
        """
        Get metrics based on the candidate label with the highest predicted likelihood and the ground truth label for intent

        The decoded inputs are grouped by utterance; within each group the candidate with
        the highest entailment logit is the prediction and the candidate with the highest
        label value is the ground truth. Predictions are saved to "test_predictions.jsonl"
        and precision/recall/F1/accuracy plus the mean validation loss are logged.

        Args:
            outputs: list of per-step dicts; the 'logits', 'input_ids', 'labels' and
                'val_loss' entries are read

        Raises:
            ValueError: if cfg.library is neither 'huggingface' nor 'megatron'
        """
        output_logits = torch.cat([output['logits'] for output in outputs], dim=0)
        output_input_ids = torch.cat([output['input_ids'] for output in outputs], dim=0)
        output_labels = torch.cat([output['labels'] for output in outputs], dim=0)

        # the two libraries differ only in where the underlying tokenizer lives and in
        # which logit column is read as the entailment score
        library = self.cfg.library
        if library == 'huggingface':
            tokenizer, entail_column = self.tokenizer, 2
        elif library == 'megatron':
            tokenizer, entail_column = self.tokenizer.tokenizer, 1
        else:
            raise ValueError(
                f"Unsupported library {library!r}: expected 'huggingface' or 'megatron'")

        entail_logits = output_logits[..., entail_column]
        decoded_input_ids = [tokenizer.decode(ids) for ids in output_input_ids]
        utterance_candidate_pairs = [
            text.split(tokenizer.sep_token) for text in decoded_input_ids
        ]
        all_utterances = [
            pair[0].replace(tokenizer.bos_token, '').replace(tokenizer.eos_token, '')
            for pair in utterance_candidate_pairs
        ]

        # account for uncased tokenization
        prompt_template = self.cfg.dataset.prompt_template
        candidates = [
            pair[1].replace(prompt_template.lower(), '').replace(prompt_template, '').strip()
            for pair in utterance_candidate_pairs
        ]

        # group the row indices that belong to the same utterance
        utterance_to_idx = defaultdict(list)
        for idx, utterance in enumerate(all_utterances):
            utterance_to_idx[utterance].append(idx)

        predicted_labels = []
        ground_truth_labels = []
        utterances = []
        for utterance, idxs in utterance_to_idx.items():
            utterance_candidates = [candidates[idx] for idx in idxs]
            logits = [entail_logits[idx].item() for idx in idxs]
            labels = [output_labels[idx].item() for idx in idxs]
            predicted_labels.append(utterance_candidates[np.argmax(logits)])
            ground_truth_labels.append(utterance_candidates[np.argmax(labels)])
            utterances.append(utterance)

        os.makedirs(self.cfg.dataset.dialogues_example_dir, exist_ok=True)
        filename = os.path.join(self.cfg.dataset.dialogues_example_dir,
                                "test_predictions.jsonl")

        DialogueGenerationMetrics.save_predictions(
            filename,
            predicted_labels,
            ground_truth_labels,
            utterances,
        )

        # sorted() makes the label -> id mapping reproducible; plain set iteration order
        # for strings can differ between processes
        label_to_ids = {
            label: idx
            for idx, label in enumerate(
                sorted(set(predicted_labels + ground_truth_labels)))
        }
        device = output_logits.device
        self.classification_report = ClassificationReport(
            num_classes=len(label_to_ids),
            mode='micro',
            label_ids=label_to_ids,
            dist_sync_on_step=True).to(device)
        predicted_label_ids = torch.tensor(
            [label_to_ids[label] for label in predicted_labels]).to(device)
        ground_truth_label_ids = torch.tensor(
            [label_to_ids[label] for label in ground_truth_labels]).to(device)

        # the call updates the report state; its return value is not needed here
        self.classification_report(predicted_label_ids, ground_truth_label_ids)
        precision, recall, f1, report = self.classification_report.compute()
        label_acc = np.mean([
            int(predicted == truth)
            for predicted, truth in zip(predicted_labels, ground_truth_labels)
        ])

        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        logging.info(report)

        self.log('unified_precision', precision)
        self.log('unified_f1', f1)
        self.log('unified_recall', recall)
        # NOTE(review): the key spelling 'unfied_accuracy' is kept unchanged because
        # external configs may monitor it -- confirm before renaming
        self.log('unfied_accuracy', label_acc * 100)
        self.log('val_loss', avg_loss, prog_bar=True)

        self.classification_report.reset()

    def predict(
        self,
        queries: Union[str, List[str]],
        candidate_labels: Union[str, List[str]],
        hypothesis_template='This example is {}.',
        batch_size=1,
        multi_label=True,
        entailment_idx=1,
        contradiction_idx=0,
    ) -> List[Dict]:
        """
        Given a list of queries and a list of candidate labels, return a ranked list of labels and scores for each query.

        Example usage:
            queries = ["I'd like a veggie burger, fries, and a coke", "Turn off the lights in the living room",]
            candidate_labels = ["Food order", "Change lighting"]
            model.predict(queries, candidate_labels)

        Example output:
            [{'sentence': "I'd like a veggie burger, fries, and a coke",
              'labels': ['Food order', 'Change lighting'],
              'scores': [0.8557153344154358, 0.12036784738302231]},
             {'sentence': 'Turn off the lights in the living room',
              'labels': ['Change lighting', 'Food order'],
              'scores': [0.8506497144699097, 0.06594637036323547]}]


        Args:
            queries: the query or list of queries to classify
            candidate_labels: string or list of strings to be used as labels
            hypothesis_template: the template used to turn each label into an NLI-style hypothesis. Must include a {}
            or similar syntax for the candidate label to be inserted.
            batch_size: the batch size to use for inference.
            multi_label: whether or not multiple candidate labels can be true. If False, the scores are normalized
            such that all class probabilities sum to 1. If True, the labels are
            considered independent and probabilities are normalized for each candidate by doing a softmax of
            the entailment score vs. the contradiction score. Forced to True when only one candidate label is given.
            entailment_idx: the index of the "entailment" class in the trained model; models trained on MNLI
             using NeMo's glue_benchmark.py or zero_shot_intent_model.py use an index of 1 by default.
            contradiction_idx: the index of the "contradiction" class in the trained model; models trained on MNLI
             using NeMo's glue_benchmark.py or zero_shot_intent_model.py use an index of 0 by default.

        Returns:
            list of dictionaries; one dict per input query. Each dict has keys "sentence", "labels", "scores".
            labels and scores are parallel lists (with each score corresponding to the label at the same index),
                 sorted from highest to lowest score.

        Raises:
            ValueError: if `queries` or `candidate_labels` is empty.
        """
        if not queries:
            raise ValueError("No queries were passed for classification!")
        if not candidate_labels:
            raise ValueError("No candidate labels were provided!")

        queries = [queries] if isinstance(queries, str) else queries
        candidate_labels = [candidate_labels] if isinstance(
            candidate_labels, str) else candidate_labels

        # a softmax over a single candidate would always yield 1.0, so score it independently instead
        if len(candidate_labels) == 1:
            multi_label = True

        def _softmax(logits):
            # subtracting the max keeps np.exp from overflowing to inf (inf / inf would give NaN scores)
            shifted = logits - logits.max(axis=-1, keepdims=True)
            exponentials = np.exp(shifted)
            return exponentials / exponentials.sum(axis=-1, keepdims=True)

        mode = self.training
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            infer_datalayer = self._setup_infer_dataloader(
                queries,
                candidate_labels,
                hypothesis_template=hypothesis_template,
                batch_size=batch_size,
                max_seq_length=self._cfg.dataset.max_seq_length,
            )

            all_batch_logits = []
            # inference only: no autograd graph is needed for the forward passes
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, _ = batch

                    logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )
                    all_batch_logits.append(logits.detach().cpu().numpy())

            all_logits = np.concatenate(all_batch_logits)
            # one row of class logits per (query, candidate label) pair
            outputs = all_logits.reshape(
                (len(queries), len(candidate_labels), -1))

            if not multi_label:
                # softmax the "entailment" logits over all candidate labels
                scores = _softmax(outputs[..., entailment_idx])
            else:
                # softmax over the entailment vs. contradiction dim for each label independently
                entail_contr_logits = outputs[
                    ..., [contradiction_idx, entailment_idx]]
                scores = _softmax(entail_contr_logits)[..., 1]

            result = []
            for query, query_scores in zip(queries, scores):
                # ascending argsort, reversed -> indices from highest to lowest score
                sorted_idxs = list(reversed(query_scores.argsort()))
                result.append({
                    "sentence": query,
                    "labels": [candidate_labels[j] for j in sorted_idxs],
                    "scores": query_scores[sorted_idxs].tolist(),
                })

        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        Describe the pre-trained checkpoints of this model that are hosted on NVIDIA's NGC cloud.

        Returns:
            List with one PretrainedModelInfo entry per available pre-trained model.
        """
        bert_base_uncased = PretrainedModelInfo(
            pretrained_model_name="zeroshotintent_en_bert_base_uncased",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/zeroshotintent_en_bert_base_uncased/versions/1.4.1/files/zeroshotintent_en_bert_base_uncased.nemo",
            description="DialogueZeroShotIntentModel trained by fine tuning BERT-base-uncased on the MNLI (Multi-Genre Natural Language Inference) dataset, which achieves an accuracy of 84.9% and 84.8% on the matched and mismatched dev sets, respectively.",
        )
        megatron_uncased = PretrainedModelInfo(
            pretrained_model_name="zeroshotintent_en_megatron_uncased",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/zeroshotintent_en_megatron_uncased/versions/1.4.1/files/zeroshotintent_en_megatron_uncased.nemo",
            description="DialogueZeroShotIntentModel trained by fine tuning Megatron-BERT-345m=M-uncased on the MNLI (Multi-Genre Natural Language Inference) dataset, which achieves an accuracy of 90.0% and 89.9% on the matched and mismatched dev sets, respectively",
        )
        return [bert_base_uncased, megatron_uncased]
    def validation_epoch_end(self, outputs):
        """
        Get metrics based on the decoded predicted label text and the decoded ground truth label text for intent.

        The 'preds', 'labels' and 'inputs' tensors of all step outputs are concatenated and decoded with the
        tokenizer. The first len(prompt_template.strip()) characters are cut from every decoded prediction and
        label, the resulting texts are saved to "test_predictions.jsonl" inside cfg.dataset.dialogues_example_dir,
        and precision / recall / F1 from a freshly built ClassificationReport plus exact-match accuracy are logged.

        Args:
            outputs: list of per-step outputs; each is indexed with the keys 'preds', 'labels' and 'inputs'
        """
        output_preds = torch.cat([output['preds'] for output in outputs], dim=0)
        output_labels = torch.cat([output['labels'] for output in outputs], dim=0)
        inputs = torch.cat([output['inputs'] for output in outputs], dim=0)

        tokenizer = self.tokenizer.tokenizer
        decoded_preds = tokenizer.batch_decode(output_preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(output_labels, skip_special_tokens=True)
        decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)

        # drop the leading prompt so that only the label text is compared
        prompt_len = len(self.cfg.dataset.prompt_template.strip())
        predicted_labels = [text[prompt_len:].strip() for text in decoded_preds]
        ground_truth_labels = [text[prompt_len:].strip() for text in decoded_labels]

        os.makedirs(self.cfg.dataset.dialogues_example_dir, exist_ok=True)
        filename = os.path.join(self.cfg.dataset.dialogues_example_dir, "test_predictions.jsonl")

        DialogueGenerationMetrics.save_predictions(
            filename,
            predicted_labels,
            ground_truth_labels,
            decoded_inputs,
        )

        # sorted() makes the label -> id assignment reproducible; plain set iteration order of strings
        # changes between interpreter processes because of hash randomization
        label_to_ids = {
            label: idx for idx, label in enumerate(sorted(set(predicted_labels + ground_truth_labels)))
        }
        device = output_preds.device
        self.classification_report = ClassificationReport(
            num_classes=len(label_to_ids),
            mode='micro',
            label_ids=label_to_ids,
            dist_sync_on_step=True).to(device)

        predicted_label_ids = torch.tensor([label_to_ids[label] for label in predicted_labels]).to(device)
        ground_truth_label_ids = torch.tensor([label_to_ids[label] for label in ground_truth_labels]).to(device)

        # the call feeds the report; its return value is not needed here
        self.classification_report(predicted_label_ids, ground_truth_label_ids)

        precision, recall, f1, report = self.classification_report.compute()
        label_acc = np.mean([
            int(predicted == expected) for predicted, expected in zip(predicted_labels, ground_truth_labels)
        ])

        logging.info(report)

        self.log('unified_precision', precision)
        self.log('unified_f1', f1)
        self.log('unified_recall', recall)
        # NOTE(review): the key is misspelled ("unfied"); kept unchanged because the same key is logged by the
        # sibling validation_epoch_end and may be referenced by existing monitoring configs - confirm before renaming
        self.log('unfied_accuracy', label_acc * 100)

        self.classification_report.reset()
Пример #27
0
class MultiLabelIntentSlotClassificationModel(IntentSlotClassificationModel):
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Builds the multi-label variant of the BERT Joint Intent and Slot model.

        Args:
            cfg: configuration object
            trainer: trainer for Pytorch Lightning
        """
        # Intent decision threshold and the best F1 reached so far; optimize_threshold() updates both.
        self.threshold = 0.5
        self.max_f1 = 0

        self.max_seq_length = cfg.language_model.max_seq_length

        if cfg.data_dir and os.path.exists(cfg.data_dir):
            # A usable data directory exists: derive the data descriptor from it and store it in the config.
            self.data_dir = cfg.data_dir
            self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds, cfg.validation_ds)
        else:
            # No data directory available: fall back to the default data descriptor values.
            self._set_defaults_data_desc(cfg)

        # The parent constructor runs only after the data descriptor has been prepared above.
        super().__init__(cfg=cfg, trainer=trainer)

        # Replace the classifier, losses and metric trackers with the multi-label versions.
        self._reconfigure_classifier()

    def _set_data_desc_to_cfg(
        self, cfg: DictConfig, data_dir: str, train_ds: DictConfig, validation_ds: DictConfig
    ) -> None:
        """
        Builds a MultiLabelIntentSlotDataDesc for the data directory and mirrors its values into cfg.data_desc.

        The intent and slot label dictionaries are also written to label files inside data_dir and registered
        as model artifacts.

        Args:
            cfg: configuration object
            data_dir: data directory
            train_ds: training dataset configuration (its `prefix` is used)
            validation_ds: validation dataset configuration (its `prefix` is used)

        Returns:
            None
        """
        prefixes = [train_ds.prefix, validation_ds.prefix]
        descriptor = MultiLabelIntentSlotDataDesc(data_dir=data_dir, modes=prefixes)

        # Struct mode is switched off while the config is extended and switched back on at the very end.
        OmegaConf.set_struct(cfg, False)
        if not hasattr(cfg, "data_desc") or cfg.data_desc is None:
            cfg.data_desc = {}

        # Keep a copy of the descriptor in the config so it can be reused later, e.g. in inference.
        target = cfg.data_desc
        intent_ids = descriptor.intents_label_ids
        slot_ids = descriptor.slots_label_ids

        target.intent_labels = list(intent_ids.keys())
        target.intent_label_ids = intent_ids
        target.intent_weights = descriptor.intent_weights

        target.slot_labels = list(slot_ids.keys())
        target.slot_label_ids = slot_ids
        target.slot_weights = descriptor.slot_weights

        target.pad_label = descriptor.pad_label

        # Configs older than 1.0.0.b3 carry no class_labels section; supply the default file names.
        if not hasattr(cfg, "class_labels") or cfg.class_labels is None:
            cfg.class_labels = {}
            default_label_files = {"intent_labels_file": "intent_labels.csv", "slot_labels_file": "slot_labels.csv"}
            cfg.class_labels = OmegaConf.create(default_label_files)

        slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file)
        intent_labels_file = os.path.join(data_dir, cfg.class_labels.intent_labels_file)

        # Write the label dictionaries first, then register the written files as artifacts.
        self._save_label_ids(slot_ids, slot_labels_file)
        self._save_label_ids(intent_ids, intent_labels_file)

        self.register_artifact("class_labels.intent_labels_file", intent_labels_file)
        self.register_artifact("class_labels.slot_labels_file", slot_labels_file)

        OmegaConf.set_struct(cfg, True)

    def _reconfigure_classifier(self) -> None:
        """ Rebuilds the classification head, the losses and the metric trackers from the model cfg.data_desc """
        cfg = self.cfg
        data_desc = cfg.data_desc
        num_intents = len(data_desc.intent_labels)
        num_slots = len(data_desc.slot_labels)

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=num_intents,
            num_slots=num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # Losses: class weights are passed only when weighted_loss balancing is requested.
        # You may need to increase the number of epochs for convergence when using weighted_loss
        intent_loss_kwargs, slot_loss_kwargs = {}, {}
        if cfg.class_balancing == "weighted_loss":
            intent_loss_kwargs["pos_weight"] = data_desc.intent_weights
            slot_loss_kwargs["weight"] = data_desc.slot_weights
        self.intent_loss = BCEWithLogitsLoss(logits_ndim=2, **intent_loss_kwargs)
        self.slot_loss = CrossEntropyLoss(logits_ndim=3, **slot_loss_kwargs)

        # The two losses are combined with weights intent_loss_weight and (1 - intent_loss_weight).
        intent_weight = cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(num_inputs=2, weights=[intent_weight, 1.0 - intent_weight])

        # Metric trackers for intents (multi-label) and slots.
        report_options = {"dist_sync_on_step": True, "mode": "micro"}
        self.intent_classification_report = MultiLabelClassificationReport(
            num_classes=num_intents, label_ids=data_desc.intent_label_ids, **report_options
        )
        self.slot_classification_report = ClassificationReport(
            num_classes=num_slots, label_ids=data_desc.slot_label_ids, **report_options
        )

    def validation_step(self, batch, batch_idx) -> dict:
        """
        Validation Loop. Pytorch Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        Args:
            batch: batches of data from DataLoader; unpacked as (input_ids, input_type_ids, input_mask, loss_mask,
                subtokens_mask, intent_labels, slot_labels)
            batch_idx: batch idx from DataLoader

        Returns:
            dict with the combined loss under "val_loss" and the current tp / fn / fp attributes of the intent
            and slot classification reports under "intent_tp", "intent_fn", "intent_fp", "slot_tp", "slot_fn"
            and "slot_fp"
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask,
        )

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits, labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits, labels=slot_labels, loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        # intents: sigmoid outputs are rounded to 0 / 1 before they are handed to the report
        intent_predictions = torch.round(torch.sigmoid(intent_logits))
        self.intent_classification_report.update(intent_predictions, intent_labels)

        # slots: only positions whose subtokens_mask value exceeds 0.5 are scored
        keep_mask = subtokens_mask > 0.5
        slot_predictions = torch.argmax(slot_logits, dim=-1)[keep_mask]
        self.slot_classification_report.update(slot_predictions, slot_labels[keep_mask])

        return {
            "val_loss": val_loss,
            "intent_tp": self.intent_classification_report.tp,
            "intent_fn": self.intent_classification_report.fn,
            "intent_fp": self.intent_classification_report.fp,
            "slot_tp": self.slot_classification_report.tp,
            "slot_fn": self.slot_classification_report.fn,
            "slot_fp": self.slot_classification_report.fp,
        }

    def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader:
        """
        Creates the DataLoader from the configuration object

        The files "<prefix>.tsv" and "<prefix>_slots.tsv" are looked up in self.data_dir; the number of intents
        is the number of non-blank lines of "dict.intents.csv" in the same directory.

        Args:
            cfg: configuration object; `prefix`, `num_samples`, `batch_size`, `shuffle`, `num_workers`,
                `pin_memory` and `drop_last` are read from it

        Returns:
            DataLoader for model's data

        Raises:
            FileNotFoundError: if the input file or the slot file does not exist
        """

        input_file = f"{self.data_dir}/{cfg.prefix}.tsv"
        slot_file = f"{self.data_dir}/{cfg.prefix}_slots.tsv"
        intent_dict_file = self.data_dir + "/dict.intents.csv"

        # the context manager closes the file again; blank lines do not count as intents
        with open(intent_dict_file, "r") as intent_dict:
            num_intents = sum(1 for line in intent_dict if line.strip())

        if not (os.path.exists(input_file) and os.path.exists(slot_file)):
            # adjacent literals keep the source indentation out of the message text
            raise FileNotFoundError(
                f"{input_file} or {slot_file} not found. Please refer to the documentation for the right format "
                "of Intents and Slots files."
            )

        dataset = MultiLabelIntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            num_intents=num_intents,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.cfg.data_desc.pad_label,
            ignore_extra_tokens=self.cfg.ignore_extra_tokens,
            ignore_start_end=self.cfg.ignore_start_end,
        )

        return DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
            drop_last=cfg.drop_last,
            collate_fn=dataset.collate_fn,
        )

    def prediction_probabilities(self, queries: List[str], test_ds: DictConfig) -> npt.NDArray:
        """
        Get intent prediction probabilities for the queries

        The model is switched to evaluation mode and moved to CUDA when available (CPU otherwise); the
        previous training mode is restored afterwards. Slot logits are computed by the forward pass but ignored.

        Args:
            queries: text sequences
            test_ds: Dataset configuration section.

        Returns:
            numpy array of intent probabilities (sigmoid of the intent logits), batches concatenated in order
        """

        batch_probabilities = []

        mode = self.training
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            # Dataset.
            infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

            # inference only: no autograd graph is needed for the forward passes
            with torch.no_grad():
                for batch in infer_datalayer:
                    # loss_mask and subtokens_mask are part of the batch but not needed for intent scores
                    input_ids, input_type_ids, input_mask, _, _ = batch

                    intent_logits, _ = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # predict intents for these examples
                    batch_probabilities.append(torch.sigmoid(intent_logits).detach().cpu().numpy())

            probabilities = np.concatenate(batch_probabilities)

        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return probabilities

    def optimize_threshold(self, test_ds: DictConfig, file_name: str) -> None:
        """
        Set the optimal threshold of the model from performance on validation set. This threshold is used to
        turn intent probabilities into 0 or 1.

        Every candidate threshold from np.arange(0.5, 0.96, 0.01) is scored with micro-averaged precision,
        recall and F1. self.threshold and self.max_f1 are only updated when the best F1 found exceeds the
        value currently stored in self.max_f1.

        Args:
            test_ds: Dataset configuration section, forwarded to prediction_probabilities
            file_name: name (without the ".tsv" extension) of the file in self.data_dir holding the validation
                set; the first line is treated as a header, every other line is "<sentence>\\t<comma-separated
                intent ids>"

        Returns:
            None
        """

        input_file = f"{self.data_dir}/{file_name}.tsv"
        num_intents = len(self.cfg.data_desc.intent_labels)

        metrics_labels, sentences = [], []

        with open(input_file, "r") as f:
            next(f, None)  # Skipping headers at index 0
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # blank lines (e.g. a trailing newline at the end of the file) carry no example
                    continue
                fields = line.split("\t")
                sentences.append(fields[0])
                # multi-hot encode the comma-separated intent ids of the second column
                active_intents = set(map(int, fields[1].split(",")))
                metrics_labels.append([1 if label in active_intents else 0 for label in range(num_intents)])

        # Retrieve class probabilities for each sentence
        intent_probabilities = self.prediction_probabilities(sentences, test_ds)

        metrics_dict = {}
        # Find optimal probability rounding threshold for intents
        for candidate in np.arange(0.5, 0.96, 0.01):
            predictions = (intent_probabilities >= candidate).tolist()
            precision = precision_score(metrics_labels, predictions, average='micro')
            recall = recall_score(metrics_labels, predictions, average='micro')
            f1 = f1_score(metrics_labels, predictions, average='micro')
            metrics_dict[candidate] = [precision, recall, f1]

        # max() keeps the first (i.e. lowest) threshold when several reach the same score
        max_precision = max(metrics_dict, key=lambda x: metrics_dict[x][0])
        max_recall = max(metrics_dict, key=lambda x: metrics_dict[x][1])
        max_f1_score = max(metrics_dict, key=lambda x: metrics_dict[x][2])

        logging.info(
            f'Best Threshold for F1-Score: {max_f1_score}, [Precision, Recall, F1-Score]: {metrics_dict[max_f1_score]}'
        )
        logging.info(
            f'Best Threshold for Precision: {max_precision}, [Precision, Recall, F1-Score]: {metrics_dict[max_precision]}'
        )
        logging.info(
            f'Best Threshold for Recall: {max_recall}, [Precision, Recall, F1-Score]: {metrics_dict[max_recall]}'
        )

        if metrics_dict[max_f1_score][2] > self.max_f1:
            self.max_f1 = metrics_dict[max_f1_score][2]

            logging.info(f'Setting Threshold to: {max_f1_score}')

            self.threshold = max_f1_score

    def predict_from_examples(
        self, queries: List[str], test_ds: DictConfig, threshold: float = None
    ) -> Tuple[List[List[Tuple[str, float]]], List[str], List[List[int]]]:
        """
        Get prediction for the queries (intent and slots)

        The model is switched to evaluation mode and moved to CUDA when available (CPU otherwise); the
        previous training mode is restored afterwards.

        Args:
            queries: text sequences
            test_ds: Dataset configuration section.
            threshold: minimum probability for an intent to count as predicted; self.threshold is used when None

        Returns:
            predicted_intents: model intent predictions with their probabilities (rounded to 2 decimals)
                Example:  [[('flight', 0.84)], [('airfare', 0.54),
                            ('flight', 0.73), ('meal', 0.24)]]
            predicted_slots: model slot predictions
                Example:  ['O B-depart_date.month_name B-depart_date.day_number',
                           'O O B-flight_stop O O O']

            predicted_vector: model intent predictions for each individual query. Binary values within each list
                indicate whether a class is predicted for the given query (1 for True, 0 for False)
                Example: [[1,0,0,0,0,0], [0,0,1,0,0,0]]
        """
        if threshold is None:
            threshold = self.threshold
        logging.info(f'Using threshold = {threshold}')

        predicted_intents = []
        predicted_slots = []
        predicted_vector = []

        mode = self.training
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Retrieve intent and slot vocabularies from configuration.
            intent_labels = self.cfg.data_desc.intent_labels
            slot_labels = self.cfg.data_desc.slot_labels
            num_slot_labels = len(slot_labels)

            # Switch model to evaluation mode
            self.eval()
            self.to(device)

            # Dataset.
            infer_datalayer = self._setup_infer_dataloader(queries, test_ds)

            # inference only: no autograd graph is needed for the forward passes
            with torch.no_grad():
                for batch in infer_datalayer:
                    input_ids, input_type_ids, input_mask, _, subtokens_mask = batch

                    intent_logits, slot_logits = self.forward(
                        input_ids=input_ids.to(device),
                        token_type_ids=input_type_ids.to(device),
                        attention_mask=input_mask.to(device),
                    )

                    # intents: every class whose probability reaches the threshold is predicted
                    for probabilities in tensor2list(torch.sigmoid(intent_logits)):
                        intent_lst = []
                        intent_vector = []
                        for intent_num, probability in enumerate(probabilities):
                            if probability >= threshold:
                                intent_lst.append((intent_labels[intent_num], round(probability, 2)))
                                intent_vector.append(1)
                            else:
                                intent_vector.append(0)

                        predicted_vector.append(intent_vector)
                        predicted_intents.append(intent_lst)

                    # slots: one tolist() call replaces per-element tensor indexing and comparisons
                    slot_preds = torch.argmax(slot_logits, dim=-1).tolist()

                    for slot_preds_query, mask_query in zip(slot_preds, subtokens_mask):
                        # only positions whose mask equals 1 contribute a slot label
                        query_slots = [
                            slot_labels[slot] if slot < num_slot_labels else "Unknown_slot"
                            for slot, mask in zip(slot_preds_query, mask_query)
                            if mask == 1
                        ]
                        predicted_slots.append(" ".join(query_slots).strip())

        finally:
            # set mode back to its original value
            self.train(mode=mode)

        return predicted_intents, predicted_slots, predicted_vector

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        No pre-trained models are listed for this class yet; an empty list is returned.
        """
        return []
class IntentSlotClassificationModel(NLPModel):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Set up the BERT joint intent and slot classification model.

        The data descriptor and the tokenizer are created before the base-class
        constructor runs; the encoder, classification head, losses, metrics and
        optimizer are created after it.

        Args:
            cfg: model configuration. The fields read here are `data_dir`,
                `language_model`, `train_ds.prefix`, `validation_ds.prefix`,
                `tokenizer`, `head`, `class_balancing`, `intent_loss_weight` and `optim`.
            trainer: optional Trainer instance, forwarded to the base class.
        """
        self.data_dir = cfg.data_dir
        self.max_seq_length = cfg.language_model.max_seq_length

        # Data descriptor for the train and validation splits; this class reads the label
        # counts, label ids, class weights and pad label from it.
        split_prefixes = [cfg.train_ds.prefix, cfg.validation_ds.prefix]
        self.data_desc = IntentSlotDataDesc(data_dir=cfg.data_dir, modes=split_prefixes)

        self._setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        # BERT encoder
        lm_cfg = cfg.language_model
        lm_name = lm_cfg.pretrained_model_name
        lm_config_file = lm_cfg.config_file
        if lm_cfg.config:
            lm_config_dict = OmegaConf.to_container(lm_cfg.config)
        else:
            lm_config_dict = None
        self.bert_model = get_lm_model(
            pretrained_model_name=lm_name,
            config_file=lm_config_file,
            config_dict=lm_config_dict,
            checkpoint_file=lm_cfg.lm_checkpoint,
        )

        # Joint head: `forward` unpacks its output into intent logits and slot logits
        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=self.data_desc.num_intents,
            num_slots=self.data_desc.num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # Losses. With 'weighted_loss' the class weights from the data descriptor are used;
        # you may need to increase the number of epochs for convergence in that case.
        intent_loss_kwargs = {'logits_ndim': 2}
        slot_loss_kwargs = {'logits_ndim': 3}
        if cfg.class_balancing == 'weighted_loss':
            intent_loss_kwargs['weight'] = self.data_desc.intent_weights
            slot_loss_kwargs['weight'] = self.data_desc.slot_weights
        self.intent_loss = CrossEntropyLoss(**intent_loss_kwargs)
        self.slot_loss = CrossEntropyLoss(**slot_loss_kwargs)

        # The two losses are mixed with weights that sum to one.
        intent_weight = cfg.intent_loss_weight
        self.total_loss = AggregatorLoss(num_inputs=2, weights=[intent_weight, 1.0 - intent_weight])

        # Metrics, tracked separately for intents and slots
        self.intent_classification_report = ClassificationReport(
            self.data_desc.num_intents, self.data_desc.intents_label_ids
        )
        self.slot_classification_report = ClassificationReport(
            self.data_desc.num_slots, self.data_desc.slots_label_ids
        )

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Run the BERT encoder and then the joint classification head.

        Args:
            input_ids: forwarded to the encoder as `input_ids`.
            token_type_ids: forwarded to the encoder as `token_type_ids`.
            attention_mask: forwarded to the encoder as `attention_mask`.

        Returns:
            A tuple `(intent_logits, slot_logits)` as produced by `self.classifier`.
        """
        encoded = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        head_output = self.classifier(hidden_states=encoded)
        intent_logits, slot_logits = head_output
        return intent_logits, slot_logits

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(input_ids=input_ids,
                                          token_type_ids=input_type_ids,
                                          attention_mask=input_mask)

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits,
                                       labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits,
                                   labels=slot_labels,
                                   loss_mask=loss_mask)
        train_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        tensorboard_logs = {
            'train_loss': train_loss,
            'lr': self._optimizer.param_groups[0]['lr']
        }
        return {'loss': train_loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels = batch
        intent_logits, slot_logits = self(input_ids=input_ids,
                                          token_type_ids=input_type_ids,
                                          attention_mask=input_mask)

        # calculate combined loss for intents and slots
        intent_loss = self.intent_loss(logits=intent_logits,
                                       labels=intent_labels)
        slot_loss = self.slot_loss(logits=slot_logits,
                                   labels=slot_labels,
                                   loss_mask=loss_mask)
        val_loss = self.total_loss(loss_1=intent_loss, loss_2=slot_loss)

        # calculate accuracy metrics for intents and slot reporting
        # intents
        preds = torch.argmax(intent_logits, axis=-1)
        intent_tp, intent_fp, intent_fn = self.intent_classification_report(
            preds, intent_labels)
        # slots
        subtokens_mask = subtokens_mask > 0.5
        preds = torch.argmax(slot_logits, axis=-1)[subtokens_mask]
        slot_labels = slot_labels[subtokens_mask]
        slot_tp, slot_fp, slot_fn = self.slot_classification_report(
            preds, slot_labels)

        tensorboard_logs = {
            'val_loss': val_loss,
            'intent_tp': intent_tp,
            'intent_fn': intent_fn,
            'intent_fp': intent_fp,
            'slot_tp': slot_tp,
            'slot_fn': slot_fn,
            'slot_fp': slot_fp,
        }

        return {'val_loss': val_loss, 'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report (separately for intents and slots)
        tp = torch.sum(torch.stack([x['log']['intent_tp'] for x in outputs]),
                       0)
        fn = torch.sum(torch.stack([x['log']['intent_fn'] for x in outputs]),
                       0)
        fp = torch.sum(torch.stack([x['log']['intent_fp'] for x in outputs]),
                       0)
        intent_precision, intent_recall, intent_f1 = self.intent_classification_report.get_precision_recall_f1(
            tp, fn, fp, mode='micro')

        tp = torch.sum(torch.stack([x['log']['slot_tp'] for x in outputs]), 0)
        fn = torch.sum(torch.stack([x['log']['slot_fn'] for x in outputs]), 0)
        fp = torch.sum(torch.stack([x['log']['slot_fp'] for x in outputs]), 0)
        slot_precision, slot_recall, slot_f1 = self.slot_classification_report.get_precision_recall_f1(
            tp, fn, fp, mode='micro')

        tensorboard_logs = {
            'val_loss': avg_loss,
            'intent_precision': intent_precision,
            'intent_recall': intent_recall,
            'intent_f1': intent_f1,
            'slot_precision': slot_precision,
            'slot_recall': slot_recall,
            'slot_f1': slot_f1,
        }
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    def _setup_tokenizer(self, cfg: DictConfig):
        """
        Create the tokenizer described by `cfg` and store it on `self.tokenizer`.

        Args:
            cfg: tokenizer config providing `tokenizer_name`, `tokenizer_model`,
                `special_tokens` and `vocab_file`.
        """
        name = cfg.tokenizer_name
        model = cfg.tokenizer_model
        # special tokens are converted to a plain container only when configured
        if cfg.special_tokens:
            special_tokens = OmegaConf.to_container(cfg.special_tokens)
        else:
            special_tokens = None
        self.tokenizer = get_tokenizer(
            tokenizer_name=name,
            tokenizer_model=model,
            special_tokens=special_tokens,
            vocab_file=cfg.vocab_file,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig):
        """
        Build a DataLoader over the intent/slot files of one data split.

        The split is read from `<data_dir>/<prefix>.tsv` (queries with intents) and
        `<data_dir>/<prefix>_slots.tsv` (slots).

        Args:
            cfg: split config providing `prefix`, `num_samples`, `batch_size`, `shuffle`,
                `num_workers`, `pin_memory` and `drop_last`.

        Returns:
            A DataLoader that uses the dataset's own `collate_fn`.

        Raises:
            FileNotFoundError: if either of the two files is missing.
        """
        split_stem = f'{self.data_dir}/{cfg.prefix}'
        input_file = split_stem + '.tsv'
        slot_file = split_stem + '_slots.tsv'

        files_present = os.path.exists(input_file) and os.path.exists(slot_file)
        if not files_present:
            raise FileNotFoundError(
                f'{input_file} or {slot_file} not found. Please refer to the documentation for the right format \
                 of Intents and Slots files.')

        dataset = IntentSlotClassificationDataset(
            input_file=input_file,
            slot_file=slot_file,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            num_samples=cfg.num_samples,
            pad_label=self.data_desc.pad_label,
            ignore_extra_tokens=self._cfg.ignore_extra_tokens,
            ignore_start_end=self._cfg.ignore_start_end,
        )

        loader_options = {
            'batch_size': cfg.batch_size,
            'shuffle': cfg.shuffle,
            'num_workers': cfg.num_workers,
            'pin_memory': cfg.pin_memory,
            'drop_last': cfg.drop_last,
        }
        return DataLoader(dataset=dataset, collate_fn=dataset.collate_fn, **loader_options)

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        List the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            A list holding a single entry, the "Joint_Intent_Slot_Assistant" checkpoint.
        """
        assistant_model = PretrainedModelInfo(
            pretrained_model_name="Joint_Intent_Slot_Assistant",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/Joint_Intent_Slot_Assistant.nemo",
            description="This models is trained on this https://github.com/xliuhw/NLU-Evaluation-Data dataset which includes 64 various intents and 55 slots. Final Intent accuracy is about 87%, Slot accuracy is about 89%.",
        )
        return [assistant_model]
class DialogueNearestNeighbourModel(NLPModel):
    """Dialogue Nearest Neighbour Model identifies the intent of an utterance using the cosine similarity between sentence embeddings of the utterance and various label descriptions """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Store the config, initialise the base model and load the sentence encoder.

        Args:
            cfg: model config; `cfg.library` selects the backend and
                `cfg.language_model.pretrained_model_name` the checkpoint to load.
            trainer: optional Trainer instance, forwarded to the base class.
        """
        self.cfg = cfg
        super().__init__(cfg=cfg, trainer=trainer)
        # Only the HuggingFace backend creates `self.language_model`.
        uses_huggingface = self.cfg.library == "huggingface"
        if uses_huggingface:
            checkpoint_name = self.cfg.language_model.pretrained_model_name
            self.language_model = AutoModel.from_pretrained(checkpoint_name)

    def _setup_dataloader_from_config(
            self, cfg: DictConfig,
            dataset_split) -> 'torch.utils.data.DataLoader':
        """
        Create the task-specific data processor and a DataLoader for one dataset split.

        Args:
            cfg: data loader config providing `batch_size` and `shuffle`, and optionally
                `num_workers`, `pin_memory` and `drop_last`.
            dataset_split: split identifier handed to the dataset (this class passes
                "train", "dev" or "test").

        Returns:
            A DataLoader over a `DialogueNearestNeighbourDataset`.

        Raises:
            ValueError: if `dataset.task` is not one of zero_shot, design or sgd.
        """
        task = self._cfg.dataset.task
        if task == "zero_shot":
            processor = DialogueAssistantDataProcessor(
                self.cfg.data_dir, self.tokenizer, cfg=self.cfg.dataset)
        elif task == "design":
            processor = DialogueDesignDataProcessor(
                data_dir=self._cfg.dataset.data_dir,
                tokenizer=self.tokenizer,
                cfg=self._cfg.dataset)
        elif task == 'sgd':
            processor = DialogueSGDDataProcessor(
                data_dir=self._cfg.dataset.data_dir,
                dialogues_example_dir=self._cfg.dataset.dialogues_example_dir,
                tokenizer=self.tokenizer,
                cfg=self._cfg.dataset,
            )
        else:
            raise ValueError(
                "Only zero_shot, design and sgd supported for Zero Shot Intent Model"
            )
        self.data_processor = processor

        # The last argument is the model.dataset cfg, which differs from the train_ds cfg etc.
        dataset = DialogueNearestNeighbourDataset(
            dataset_split,
            self.data_processor,
            self.tokenizer,
            self.cfg.dataset,
        )

        loader_options = {
            'batch_size': cfg.batch_size,
            'shuffle': cfg.shuffle,
            'num_workers': cfg.get("num_workers", 0),
            'pin_memory': cfg.get("pin_memory", False),
            'drop_last': cfg.get("drop_last", False),
        }
        return torch.utils.data.DataLoader(
            dataset=dataset, collate_fn=dataset.collate_fn, **loader_options)

    def forward(self, input_ids, attention_mask):
        if self.cfg.library == 'huggingface':
            output = self.language_model(input_ids=input_ids,
                                         attention_mask=attention_mask)
        return output

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx, mode='test')

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[
            0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded,
                         1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def validation_step(self, batch, batch_idx, mode='val'):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_mask, labels = batch
        preds = []
        gts = []
        inputs = []
        for i in range(input_ids.size(0)):
            output = self.forward(input_ids=input_ids[i],
                                  attention_mask=input_mask[i])
            sentence_embeddings = DialogueNearestNeighbourModel.mean_pooling(
                output, input_mask[i])
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
            cos_sim = F.cosine_similarity(sentence_embeddings[:1, :],
                                          sentence_embeddings[1:, :])
            pred = torch.argmax(cos_sim).item() + 1
            gt = torch.argmax(labels[i][1:]).item() + 1

            preds.append(input_ids[i, pred])
            gts.append(input_ids[i, gt])
            inputs.append(input_ids[i, 0])

        return {
            'preds': torch.stack(preds),
            'labels': torch.stack(gts),
            'inputs': torch.stack(inputs)
        }

    def multi_test_epoch_end(self, outputs, dataloader_idx):
        return self.validation_epoch_end(outputs)

    def validation_epoch_end(self, outputs):
        """
        Get metrics based on the candidate label with the highest predicted likelihood and the
        ground truth label for intent.

        The token-id rows gathered by `validation_step` are decoded back to text, the prompt
        template prefix is cut off to obtain label names, the predictions are written to
        `<dialogues_example_dir>/test_predictions.jsonl`, and precision, recall, F1 and
        accuracy are logged.

        Args:
            outputs: list of dicts returned by `validation_step`, each holding the keys
                'preds', 'labels' and 'inputs'.
        """
        output_preds = torch.cat([output['preds'] for output in outputs], dim=0)
        output_labels = torch.cat([output['labels'] for output in outputs], dim=0)
        inputs = torch.cat([output['inputs'] for output in outputs], dim=0)
        device = output_preds.device

        batch_decode = self.tokenizer.tokenizer.batch_decode
        decoded_preds = batch_decode(output_preds, skip_special_tokens=True)
        decoded_labels = batch_decode(output_labels, skip_special_tokens=True)
        decoded_inputs = batch_decode(inputs, skip_special_tokens=True)

        # Everything after the (stripped) prompt template is taken as the label text.
        prompt_len = len(self.cfg.dataset.prompt_template.strip())
        predicted_labels = [text[prompt_len:].strip() for text in decoded_preds]
        ground_truth_labels = [text[prompt_len:].strip() for text in decoded_labels]

        os.makedirs(self.cfg.dataset.dialogues_example_dir, exist_ok=True)
        filename = os.path.join(self.cfg.dataset.dialogues_example_dir, "test_predictions.jsonl")

        DialogueGenerationMetrics.save_predictions(
            filename,
            predicted_labels,
            ground_truth_labels,
            decoded_inputs,
        )

        # Sorted so the label -> id mapping is deterministic: iteration order of a set of
        # strings can differ between runs and between processes.
        unique_labels = sorted(set(predicted_labels + ground_truth_labels))
        label_to_ids = {label: idx for idx, label in enumerate(unique_labels)}

        self.classification_report = ClassificationReport(
            num_classes=len(label_to_ids),
            mode='micro',
            label_ids=label_to_ids,
            dist_sync_on_step=True).to(device)

        predicted_label_ids = torch.tensor([label_to_ids[label] for label in predicted_labels]).to(device)
        ground_truth_label_ids = torch.tensor([label_to_ids[label] for label in ground_truth_labels]).to(device)

        # Update the metric state; the per-call return value is not needed here.
        self.classification_report(predicted_label_ids, ground_truth_label_ids)

        precision, recall, f1, report = self.classification_report.compute()
        label_acc = np.mean([
            int(predicted == expected)
            for predicted, expected in zip(predicted_labels, ground_truth_labels)
        ])

        logging.info(report)

        self.log('unified_precision', precision)
        self.log('unified_f1', f1)
        self.log('unified_recall', recall)
        # NOTE(review): the key is misspelled ('unfied'); kept as is because external
        # monitoring/checkpoint configs may reference it. Confirm before renaming.
        self.log('unfied_accuracy', label_acc * 100)

        self.classification_report.reset()

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config:
            logging.info(
                f"Dataloader config or file_name for the training set is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._train_dl = self._setup_dataloader_from_config(
            train_data_config, "train")

        # self.create_loss_module()

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config:
            logging.info(
                f"Dataloader config or file_path for the validation data set is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._validation_dl = self._setup_dataloader_from_config(
            val_data_config, "dev")

    def setup_multiple_test_data(self, test_data_config: Optional[DictConfig]):
        self.setup_test_data(test_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        if not test_data_config:
            logging.info(
                f"Dataloader config or file_path for the test data set is missing, so no data loader for test is created!"
            )
            self._test_dl = None
            return
        self._test_dl = self._setup_dataloader_from_config(
            test_data_config, "test")

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        List the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            An empty list; no pre-trained models are registered for this class.
        """
        return []
Пример #30
0
class TokenClassificationModel(NLPModel, Exportable):
    """Token Classification Model with BERT, applicable for tasks such as Named Entity Recognition"""
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.classifier.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes Token Classification Model.

        Args:
            cfg: model config. `cfg.label_ids` may be a path to a label mapping file, in
                which case it is replaced in place by the loaded name-to-id mapping.
            trainer: optional Trainer instance, forwarded to the base class.

        Raises:
            ValueError: if `cfg.label_ids` is a string that is not an existing path.
        """
        # extract str to int labels mapping if a mapping file provided
        if isinstance(cfg.label_ids, str):
            if not os.path.exists(cfg.label_ids):
                raise ValueError(f'{cfg.label_ids} not found.')
            logging.info(
                f'Reusing label_ids file found at {cfg.label_ids}.')
            label_ids = get_labels_to_labels_id_mapping(cfg.label_ids)
            # update the config to store name to id mapping
            cfg.label_ids = OmegaConf.create(label_ids)

        self._setup_tokenizer(cfg.tokenizer)

        # Default the class weights BEFORE the base constructor runs, as the text
        # classifier in this file does. `setup_training_data` assigns
        # `self.class_weights`; resetting it afterwards would discard those weights and
        # make 'weighted_loss' fall back to the unweighted loss.
        # NOTE(review): assumes the base constructor may call `setup_training_data` - confirm.
        self.class_weights = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=False,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        # Weighted only when class balancing is requested and weights are available.
        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)

    def update_data_dir(self, data_dir: str) -> None:
        """
        Update data directory and get data stats with Data Descriptor
        Weights are later used to setup loss

        Args:
            data_dir: path to data directory
        """
        self._cfg.dataset.data_dir = data_dir
        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')

    def setup_loss(self, class_balancing: str = None):
        """
        Build the token-level cross-entropy loss.

        Args:
            class_balancing: when equal to 'weighted_loss' and `self.class_weights` is
                set (truthy), the loss is weighted with those class weights; any other
                value gives the unweighted loss.

        Returns:
            A `CrossEntropyLoss` configured for 3-dimensional logits.
        """
        use_class_weights = class_balancing == 'weighted_loss' and self.class_weights
        if use_class_weights:
            # you may need to increase the number of epochs for convergence when using weighted_loss
            return CrossEntropyLoss(logits_ndim=3, weight=self.class_weights)
        return CrossEntropyLoss(logits_ndim=3)

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        Run the BERT encoder followed by the token classification head.

        Returns:
            The logits produced by `self.classifier` from the encoder's hidden states.
        """
        encoded = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        return self.classifier(hidden_states=encoded)

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch
        logits = self(input_ids=input_ids,
                      token_type_ids=input_type_ids,
                      attention_mask=input_mask)
        loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': loss,
            'lr': lr,
        }

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, labels = batch
        logits = self(input_ids=input_ids,
                      token_type_ids=input_type_ids,
                      attention_mask=input_mask)
        val_loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)

        subtokens_mask = subtokens_mask > 0.5

        preds = torch.argmax(logits, axis=-1)[subtokens_mask]
        labels = labels[subtokens_mask]
        tp, fn, fp, _ = self.classification_report(preds, labels)

        return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and classification report
        precision, recall, f1, report = self.classification_report.compute()

        logging.info(report)

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('precision', precision)
        self.log('f1', f1)
        self.log('recall', recall)

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        """
        return self.validation_epoch_end(outputs)

    def _setup_tokenizer(self, cfg: DictConfig):
        """
        Create the tokenizer described by `cfg` and store it on `self.tokenizer`.

        The vocabulary file and the tokenizer model are passed through
        `register_artifact` (under 'tokenizer.vocab_file' and 'tokenizer.tokenizer_model'),
        and the values it returns are what the tokenizer is built from.

        Args:
            cfg: tokenizer config providing `tokenizer_name`, `vocab_file`,
                `special_tokens` and `tokenizer_model`.
        """
        name = cfg.tokenizer_name
        # Registration order is kept: vocab file first, tokenizer model last.
        vocab_path = self.register_artifact(config_path='tokenizer.vocab_file', src=cfg.vocab_file)
        if cfg.special_tokens:
            special_tokens = OmegaConf.to_container(cfg.special_tokens)
        else:
            special_tokens = None
        model_path = self.register_artifact(config_path='tokenizer.tokenizer_model', src=cfg.tokenizer_model)

        self.tokenizer = get_tokenizer(
            tokenizer_name=name,
            vocab_file=vocab_path,
            special_tokens=special_tokens,
            tokenizer_model=model_path,
        )

    def setup_training_data(self,
                            train_data_config: Optional[DictConfig] = None):
        """
        Read the training labels, store the label mapping and build the training loader.

        Besides creating `self._train_dl`, this updates `self.class_weights`, saves the
        label-to-id mapping into the model config and registers the label ids file as
        the 'label_ids.csv' artifact.

        Args:
            train_data_config: training data config; defaults to `self._cfg.train_ds`
                when None.
        """
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        dataset_cfg = self._cfg.dataset
        labels_file = os.path.join(dataset_cfg.data_dir, train_data_config.labels_file)
        # class weights are only requested when the weighted loss is configured
        wants_weights = dataset_cfg.class_balancing == 'weighted_loss'
        label_ids, label_ids_filename, self.class_weights = get_label_ids(
            label_file=labels_file,
            is_training=True,
            pad_label=dataset_cfg.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=wants_weights,
        )

        # save label maps to the config
        self._cfg.label_ids = OmegaConf.create(label_ids)
        self.register_artifact('label_ids.csv', label_ids_filename)

        train_loader = self._setup_dataloader_from_config(cfg=train_data_config)
        self._train_dl = train_loader

    def setup_validation_data(self,
                              val_data_config: Optional[DictConfig] = None):
        """
        Run the validation labels through `get_label_ids` and build the validation loader.

        Args:
            val_data_config: validation data config; defaults to
                `self._cfg.validation_ds` when None.
        """
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds

        dataset_cfg = self._cfg.dataset
        labels_file = os.path.join(dataset_cfg.data_dir, val_data_config.labels_file)
        # The return value is not used here; the call is made against the existing mapping.
        get_label_ids(
            label_file=labels_file,
            is_training=False,
            pad_label=dataset_cfg.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        val_loader = self._setup_dataloader_from_config(cfg=val_data_config)
        self._validation_dl = val_loader

    def setup_test_data(self, test_data_config: Optional[DictConfig] = None):
        """
        Run the test labels through `get_label_ids` and build the test loader.

        Args:
            test_data_config: test data config; defaults to `self._cfg.test_ds` when None.
        """
        if test_data_config is None:
            test_data_config = self._cfg.test_ds

        dataset_cfg = self._cfg.dataset
        labels_file = os.path.join(dataset_cfg.data_dir, test_data_config.labels_file)
        # The return value is not used here; the call is made against the existing mapping.
        get_label_ids(
            label_file=labels_file,
            is_training=False,
            pad_label=dataset_cfg.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        test_loader = self._setup_dataloader_from_config(cfg=test_data_config)
        self._test_dl = test_loader

    def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader:
        """
        Build a DataLoader for the text/labels file pair described by `cfg`.

        Args:
            cfg: config for the dataloader, providing `text_file`, `labels_file`,
                `num_samples`, `batch_size` and `shuffle`. Worker, pinning and
                drop-last settings come from the model's dataset config instead.
        Return:
            Pytorch Dataloader
        Raises:
            FileNotFoundError: if the data directory, the text file or the labels file
                does not exist.
        """
        dataset_cfg = self._cfg.dataset
        data_dir = dataset_cfg.data_dir

        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory is not found at: {data_dir}.")

        text_file = os.path.join(data_dir, cfg.text_file)
        labels_file = os.path.join(data_dir, cfg.labels_file)

        both_files_present = os.path.exists(text_file) and os.path.exists(labels_file)
        if not both_files_present:
            raise FileNotFoundError(
                f'{text_file} or {labels_file} not found. The data should be split into 2 files: text.txt and \
                labels.txt. Each line of the text.txt file contains text sequences, where words are separated with \
                spaces. The labels.txt file contains corresponding labels for each word in text.txt, the labels are \
                separated with spaces. Each line of the files should follow the format:  \
                   [WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and \
                   [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).')

        dataset = BertTokenClassificationDataset(
            text_file=text_file,
            label_file=labels_file,
            max_seq_length=dataset_cfg.max_seq_length,
            tokenizer=self.tokenizer,
            num_samples=cfg.num_samples,
            pad_label=dataset_cfg.pad_label,
            label_ids=self._cfg.label_ids,
            ignore_extra_tokens=dataset_cfg.ignore_extra_tokens,
            ignore_start_end=dataset_cfg.ignore_start_end,
            use_cache=dataset_cfg.use_cache,
        )

        loader_options = {
            'batch_size': cfg.batch_size,
            'shuffle': cfg.shuffle,
            'num_workers': dataset_cfg.num_workers,
            'pin_memory': dataset_cfg.pin_memory,
            'drop_last': dataset_cfg.drop_last,
        }
        return DataLoader(dataset=dataset, collate_fn=dataset.collate_fn, **loader_options)

    def _setup_infer_dataloader(
            self, queries: List[str],
            batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Build the data loader used for inference on raw queries.

        Args:
            queries: text
            batch_size: batch size to use during inference

        Returns:
            A pytorch DataLoader that keeps the order of the queries (no shuffling) and
            never drops the last batch.
        """
        # max_seq_length=-1 is passed through to the inference dataset as is.
        infer_dataset = BertTokenClassificationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=-1
        )
        dataset_cfg = self._cfg.dataset

        return torch.utils.data.DataLoader(
            dataset=infer_dataset,
            collate_fn=infer_dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=dataset_cfg.num_workers,
            pin_memory=dataset_cfg.pin_memory,
            drop_last=False,
        )

    @torch.no_grad()
    def _infer(self, queries: List[str], batch_size: int = None) -> List[int]:
        """
        Run the model over the queries and collect the predicted label ids.

        The model is switched to evaluation mode and moved to CUDA when available (CPU otherwise) for the
        duration of the call; the original train/eval mode is restored afterwards, even if inference fails.

        Args:
            queries: text sequences
            batch_size: batch size to use during inference.
        Returns:
            A single flat list with the argmax label id of every position selected by the subtokens mask,
            concatenated over all queries in order.
        """
        # remember the current mode so it can be restored in the `finally` clause
        was_training = self.training
        predictions = []
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.eval()
            self.to(device)
            loader = self._setup_infer_dataloader(queries, batch_size)

            for input_ids, input_type_ids, input_mask, subtokens_mask in loader:
                logits = self.forward(
                    input_ids=input_ids.to(device),
                    token_type_ids=input_type_ids.to(device),
                    attention_mask=input_mask.to(device),
                )

                # keep only the positions flagged by the mask
                selected = subtokens_mask > 0.5
                predicted_ids = torch.argmax(logits, axis=-1)
                predictions.extend(tensor2list(predicted_ids[selected]))
        finally:
            # put the model back into whatever mode it was in before inference
            self.train(mode=was_training)
        return predictions

    def add_predictions(self,
                        queries: Union[List[str], str],
                        batch_size: int = 32) -> List[str]:
        """
        Add predicted token labels to the queries. Use this method for debugging and prototyping.

        Every word whose predicted label differs from the pad label gets ``[LABEL]`` appended to it. A single
        trailing non-alphanumeric character (e.g. a punctuation mark) is kept after the tag, so
        ``Paris.`` becomes ``Paris[B-LOC].``.

        Args:
            queries: text; either a list of queries or a single query string
            batch_size: batch size to use during inference.
        Returns:
            result: text with added entities, one string per query (empty list for ``None`` / empty input)
        Raises:
            ValueError: if the number of predictions does not match the number of whitespace-separated words
        """
        if queries is None or len(queries) == 0:
            return []

        # a bare string is one query; iterating it directly would treat every character as a query
        if isinstance(queries, str):
            queries = [queries]

        all_preds = self._infer(queries, batch_size)

        tokenized_queries = [q.strip().split() for q in queries]
        if sum(len(words) for words in tokenized_queries) != len(all_preds):
            raise ValueError('Pred and words must have the same length')

        ids_to_labels = {v: k for k, v in self._cfg.label_ids.items()}
        pad_label = self._cfg.dataset.pad_label

        result = []
        start_idx = 0
        for words in tokenized_queries:
            # extract predictions for the current query from the flat list of all predictions
            end_idx = start_idx + len(words)
            preds = all_preds[start_idx:end_idx]
            start_idx = end_idx

            tagged_words = []
            for word, pred in zip(words, preds):
                # strip out the punctuation to attach the entity tag to the word not to a punctuation mark
                # that follows the word; digits count as part of the word ("1990" must not lose its "0")
                if word[-1].isalnum():
                    punct = ''
                else:
                    word, punct = word[:-1], word[-1]

                label = ids_to_labels[pred]
                tag = '' if label == pad_label else '[' + label + ']'
                tagged_words.append(word + tag + punct)
            result.append(' '.join(tagged_words))
        return result

    def evaluate_from_file(
        self,
        output_dir: str,
        text_file: str,
        labels_file: Optional[str] = None,
        add_confusion_matrix: Optional[bool] = False,
        normalize_confusion_matrix: Optional[bool] = True,
        batch_size: int = 1,
    ) -> None:
        """
        Run inference on data from a file, plot confusion matrix and calculate classification report.
        Use this method for final evaluation.

        Predictions (and the reference labels, when given) are written to
        ``<output_dir>/infer_<basename of text_file>``. The confusion matrix and the classification report are
        only produced when both ``labels_file`` is given and ``add_confusion_matrix`` is True.

        Args:
            output_dir: path to output directory to store model predictions, confusion matrix plot (if set to True)
            text_file: path to file with text. Each line of the text.txt file contains text sequences, where words
                are separated with spaces: [WORD] [SPACE] [WORD] [SPACE] [WORD]
            labels_file (Optional): path to file with labels. Each line of the labels_file should contain
                labels corresponding to each word in the text_file, the labels are separated with spaces:
                [LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).'
            add_confusion_matrix: whether to generate confusion matrix
            normalize_confusion_matrix: whether to normalize confusion matrix
            batch_size: batch size to use during inference.

        Raises:
            Exception: any error raised while writing the output or evaluating is logged with a hint about
                unseen labels and then re-raised unchanged.
        """
        output_dir = os.path.abspath(output_dir)

        with open(text_file, 'r') as f:
            queries = f.readlines()

        all_preds = self._infer(queries, batch_size)
        with_labels = labels_file is not None
        if with_labels:
            with open(labels_file, 'r') as f:
                # flatten the per-line labels into one space-separated string, matching the flat prediction list
                all_labels_str = ' '.join(line.strip() for line in f)

        # write the labels (if provided) and the predictions to a file in output_dir
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir,
                                'infer_' + os.path.basename(text_file))
        try:
            with open(filename, 'w') as f:
                if with_labels:
                    f.write('labels\t' + all_labels_str + '\n')
                    logging.info(f'Labels save to {filename}')

                # convert predictions from label ids to string labels
                ids_to_labels = {v: k for k, v in self._cfg.label_ids.items()}
                all_preds_str = [ids_to_labels[pred] for pred in all_preds]
                f.write('preds\t' + ' '.join(all_preds_str) + '\n')
                logging.info(f'Predictions saved to {filename}')

            if with_labels and add_confusion_matrix:
                # convert reference labels from string labels to ids; an unknown label raises KeyError here
                label_ids = self._cfg.label_ids
                all_labels = [label_ids[label] for label in all_labels_str.split()]

                plot_confusion_matrix(all_labels,
                                      all_preds,
                                      output_dir,
                                      label_ids=label_ids,
                                      normalize=normalize_confusion_matrix)
                logging.info(
                    get_classification_report(all_labels, all_preds,
                                              label_ids))
        except Exception:
            # the trailing space keeps the two message fragments from being glued together ("wereseen")
            logging.error(
                f'When providing a file with labels, check that all labels in {labels_file} were '
                f'seen during training.')
            raise

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
        Describe the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            A list with a single PretrainedModelInfo entry for the "NERModel" checkpoint.
        """
        ner_model_info = PretrainedModelInfo(
            pretrained_model_name="NERModel",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemonlpmodels/versions/1.0.0a5/files/NERModel.nemo",
            description="The model was trained on GMB (Groningen Meaning Bank) corpus for entity recognition and achieves 74.61 F1 Macro score.",
        )
        return [ner_model_info]

    def _prepare_for_export(self):
        return self.bert_model._prepare_for_export()

    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Export the model to ONNX as three files: the BERT part, the classifier part and the combined model.

        ``self.bert_model`` and ``self.classifier`` are exported separately next to ``output`` (file names
        prefixed with ``bert_`` and ``classifier_``), then joined with ``attach_onnx_to_onnx`` and the result
        is saved to ``output``.

        Args:
            output: path of the combined ONNX file; the two partial exports are written to the same directory
            input_example: ignored (a warning is logged when given); both sub-exports receive ``None``
            output_example: ignored (a warning is logged when given); both sub-exports receive ``None``
            verbose, export_params, do_constant_folding, keep_initializers_as_inputs, onnx_opset_version,
            try_script, set_eval, check_trace, use_dynamic_axes: forwarded unchanged, positionally and in this
                order, to the ``export`` method of both sub-modules

        Returns:
            A tuple ``([output, bert_path, classifier_path], [descriptions])`` with one description per path.
        """
        if not (input_example is None and output_example is None):
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " TokenClassificationModel consists of two separate models with different"
                " inputs and outputs.")

        qual_name = f'{self.__module__}.{self.__class__.__qualname__}'
        out_dir, out_name = os.path.split(output)

        # positional arguments shared by both sub-exports; the two leading None values stand for
        # input_example / output_example, which each sub-module computes on its own
        shared_args = (
            None,
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        bert_path = os.path.join(out_dir, 'bert_' + out_name)
        bert_descr = qual_name + ' BERT exported to ONNX'
        bert_onnx = self.bert_model.export(bert_path, *shared_args)

        classifier_path = os.path.join(out_dir, 'classifier_' + out_name)
        classifier_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(classifier_path, *shared_args)

        # join the two exported graphs into one model and store it at the requested location
        combined_model = attach_onnx_to_onnx(bert_onnx, classifier_onnx, "TKCL")
        combined_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(combined_model, output)

        exported_paths = [output, bert_path, classifier_path]
        descriptions = [combined_descr, bert_descr, classifier_descr]
        return (exported_paths, descriptions)