Example
class SGDQAModel(NLPModel):
    """Dialogue State Tracking Model SGD-QA"""

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.decoder.output_types

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        self.data_prepared = False
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file', cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file),
        )

        self.encoder = SGDEncoder(hidden_size=self.bert_model.config.hidden_size, dropout=self._cfg.encoder.dropout)
        self.decoder = SGDDecoder(embedding_dim=self.bert_model.config.hidden_size)
        self.loss = SGDDialogueStateLoss(reduction="mean")
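        # Rough dataflow (inferred from forward() below, not stated in the source): the
        # BERT backbone contextualizes the schema-augmented utterance, SGDEncoder derives
        # a single utterance representation plus per-token embeddings, and SGDDecoder
        # turns those into one set of logits per SGD task.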

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        token_embeddings = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        encoded_utterance, token_embeddings = self.encoder(hidden_states=token_embeddings)
        (
            logit_intent_status,
            logit_req_slot_status,
            logit_cat_slot_status,
            logit_cat_slot_value_status,
            logit_noncat_slot_status,
            logit_spans,
        ) = self.decoder(
            encoded_utterance=encoded_utterance, token_embeddings=token_embeddings, utterance_mask=attention_mask
        )
        return (
            logit_intent_status,
            logit_req_slot_status,
            logit_cat_slot_status,
            logit_cat_slot_value_status,
            logit_noncat_slot_status,
            logit_spans,
        )

    def training_step(self, batch, batch_idx):
        (
            example_id_num,
            service_id,
            utterance_ids,
            token_type_ids,
            attention_mask,
            intent_status,
            requested_slot_status,
            categorical_slot_status,
            categorical_slot_value_status,
            noncategorical_slot_status,
            noncategorical_slot_value_start,
            noncategorical_slot_value_end,
            start_char_idx,
            end_char_idx,
            task_mask,
        ) = batch
        (
            logit_intent_status,
            logit_req_slot_status,
            logit_cat_slot_status,
            logit_cat_slot_value_status,
            logit_noncat_slot_status,
            logit_spans,
        ) = self(input_ids=utterance_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        loss = self.loss(
            logit_intent_status=logit_intent_status,
            intent_status=intent_status,
            logit_req_slot_status=logit_req_slot_status,
            requested_slot_status=requested_slot_status,
            logit_cat_slot_status=logit_cat_slot_status,
            categorical_slot_status=categorical_slot_status,
            logit_cat_slot_value_status=logit_cat_slot_value_status,
            categorical_slot_value_status=categorical_slot_value_status,
            logit_noncat_slot_status=logit_noncat_slot_status,
            noncategorical_slot_status=noncategorical_slot_status,
            logit_spans=logit_spans,
            noncategorical_slot_value_start=noncategorical_slot_value_start,
            noncategorical_slot_value_end=noncategorical_slot_value_end,
            task_mask=task_mask,
        )
        lr = self._optimizer.param_groups[0]['lr']

        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)

        return {
            'loss': loss,
            'lr': lr,
        }

    def validation_step(self, batch: List[torch.Tensor], batch_idx: int, dataloader_idx: int = 0) -> dict:
        """
        Called at every validation step to aggregate and postprocess outputs on each GPU
        Args:
            batch: input batch at validation step
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        loss, tensors = self.eval_step_helper(batch=batch)
        self.log('val_loss', loss)
        return {'val_loss': loss, 'tensors': tensors}

    def test_step(self, batch: List[torch.Tensor], batch_idx: int, dataloader_idx: int = 0) -> dict:
        """
        Called at every test step to aggregate and postprocess outputs on each GPU
        Args:
            batch: input batch at test step
            batch_idx: batch index
            dataloader_idx: dataloader index
        """
        loss, tensors = self.eval_step_helper(batch=batch)
        return {'test_loss': loss, 'tensors': tensors}

    def eval_step_helper(self, batch: List[torch.Tensor]):
        """
        Helper called at every validation/test step to aggregate and postprocess outputs on each GPU
        Args:
            batch: input batch at step
        Returns:
            loss: averaged batch loss
            tensors: collection of aggregated output tensors across all GPU workers
        """
        (
            example_id_num,
            service_id,
            utterance_ids,
            token_type_ids,
            attention_mask,
            intent_status,
            requested_slot_status,
            categorical_slot_status,
            categorical_slot_value_status,
            noncategorical_slot_status,
            noncategorical_slot_value_start,
            noncategorical_slot_value_end,
            start_char_idx,
            end_char_idx,
            task_mask,
        ) = batch
        (
            logit_intent_status,
            logit_req_slot_status,
            logit_cat_slot_status,
            logit_cat_slot_value_status,
            logit_noncat_slot_status,
            logit_spans,
        ) = self(input_ids=utterance_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        loss = self.loss(
            logit_intent_status=logit_intent_status,
            intent_status=intent_status,
            logit_req_slot_status=logit_req_slot_status,
            requested_slot_status=requested_slot_status,
            logit_cat_slot_status=logit_cat_slot_status,
            categorical_slot_status=categorical_slot_status,
            logit_cat_slot_value_status=logit_cat_slot_value_status,
            categorical_slot_value_status=categorical_slot_value_status,
            logit_noncat_slot_status=logit_noncat_slot_status,
            noncategorical_slot_status=noncategorical_slot_status,
            logit_spans=logit_spans,
            noncategorical_slot_value_start=noncategorical_slot_value_start,
            noncategorical_slot_value_end=noncategorical_slot_value_end,
            task_mask=task_mask,
        )

        all_example_id_num = []
        all_service_id = []
        all_logit_intent_status = []
        all_logit_req_slot_status = []
        all_logit_cat_slot_status = []
        all_logit_cat_slot_value_status = []
        all_logit_noncat_slot_status = []
        all_logit_spans = []
        all_start_char_idx = []
        all_end_char_idx = []

        if self.trainer.gpus and self.trainer.world_size > 1:
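            # Under DDP each rank only sees its own shard of the eval data, so the
            # per-rank tensors are gathered from every worker before corpus-level
            # metrics are computed; torch.distributed.all_gather fills a pre-allocated
            # list of world_size receive buffers.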
            world_size = self.trainer.world_size
            for ind in range(world_size):
                all_example_id_num.append(torch.empty_like(example_id_num))
                all_service_id.append(torch.empty_like(service_id))
                all_logit_intent_status.append(torch.empty_like(logit_intent_status))
                all_logit_req_slot_status.append(torch.empty_like(logit_req_slot_status))
                all_logit_cat_slot_status.append(torch.empty_like(logit_cat_slot_status))
                all_logit_cat_slot_value_status.append(torch.empty_like(logit_cat_slot_value_status))
                all_logit_noncat_slot_status.append(torch.empty_like(logit_noncat_slot_status))
                all_logit_spans.append(torch.empty_like(logit_spans))
                all_start_char_idx.append(torch.empty_like(start_char_idx))
                all_end_char_idx.append(torch.empty_like(end_char_idx))

            torch.distributed.all_gather(all_example_id_num, example_id_num)
            torch.distributed.all_gather(all_service_id, service_id)
            torch.distributed.all_gather(all_logit_intent_status, logit_intent_status)
            torch.distributed.all_gather(all_logit_req_slot_status, logit_req_slot_status)
            torch.distributed.all_gather(all_logit_cat_slot_status, logit_cat_slot_status)
            torch.distributed.all_gather(all_logit_cat_slot_value_status, logit_cat_slot_value_status)
            torch.distributed.all_gather(all_logit_noncat_slot_status, logit_noncat_slot_status)
            torch.distributed.all_gather(all_logit_spans, logit_spans)
            torch.distributed.all_gather(all_start_char_idx, start_char_idx)
            torch.distributed.all_gather(all_end_char_idx, end_char_idx)
        else:
            all_example_id_num.append(example_id_num)
            all_service_id.append(service_id)
            all_logit_intent_status.append(logit_intent_status)
            all_logit_req_slot_status.append(logit_req_slot_status)
            all_logit_cat_slot_status.append(logit_cat_slot_status)
            all_logit_cat_slot_value_status.append(logit_cat_slot_value_status)
            all_logit_noncat_slot_status.append(logit_noncat_slot_status)
            all_logit_spans.append(logit_spans)
            all_start_char_idx.append(start_char_idx)
            all_end_char_idx.append(end_char_idx)

        # After this, each all_* list holds world_size tensors, concatenated below.
        example_id_num = torch.cat(all_example_id_num)
        service_id = torch.cat(all_service_id)
        logit_intent_status = torch.cat(all_logit_intent_status)
        logit_req_slot_status = torch.cat(all_logit_req_slot_status)
        logit_cat_slot_status = torch.cat(all_logit_cat_slot_status)
        logit_cat_slot_value_status = torch.cat(all_logit_cat_slot_value_status)
        logit_noncat_slot_status = torch.cat(all_logit_noncat_slot_status)
        logit_spans = torch.cat(all_logit_spans)
        start_char_idx = torch.cat(all_start_char_idx)
        end_char_idx = torch.cat(all_end_char_idx)

        intent_status = torch.sigmoid(logit_intent_status)

        # Scores are output for each requested slot.
        req_slot_status = torch.sigmoid(logit_req_slot_status)

        # For categorical slots, the status of each slot and the predicted value are output.
        cat_slot_status_dist = torch.softmax(logit_cat_slot_status, dim=-1)
        cat_slot_status = torch.argmax(logit_cat_slot_status, dim=-1)
        cat_slot_status_p = torch.max(cat_slot_status_dist, dim=-1)[0]
        cat_slot_value_status = torch.sigmoid(logit_cat_slot_value_status)

        # For non-categorical slots, the status of each slot and the indices for spans are output.
        noncat_slot_status_dist = torch.softmax(logit_noncat_slot_status, dim=-1)
        noncat_slot_status = torch.argmax(logit_noncat_slot_status, dim=-1)
        noncat_slot_status_p = torch.max(noncat_slot_status_dist, dim=-1)[0]

        scores = torch.softmax(logit_spans, dim=1)
        start_scores, end_scores = torch.unbind(scores, dim=-1)
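        # logit_spans is assumed to be [batch, max_num_tokens, 2]: softmax normalizes
        # over the token dimension (dim=1), and unbind splits the trailing dim into
        # per-token start and end scores.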

        batch_size, max_num_tokens = end_scores.size()
        # Find the span with the maximum sum of scores for start and end indices.
        total_scores = torch.unsqueeze(start_scores, dim=2) + torch.unsqueeze(end_scores, dim=1)
        start_idx = torch.arange(max_num_tokens, device=total_scores.device).view(1, -1, 1)
        end_idx = torch.arange(max_num_tokens, device=total_scores.device).view(1, 1, -1)
        invalid_index_mask = (start_idx > end_idx).repeat(batch_size, 1, 1)
        total_scores = torch.where(invalid_index_mask, torch.zeros_like(total_scores), total_scores)
        max_span_index = torch.argmax(total_scores.view(-1, max_num_tokens ** 2), dim=-1)
        max_span_p = torch.max(total_scores.view(-1, max_num_tokens ** 2), dim=-1)[0]

        span_start_index = torch.div(max_span_index, max_num_tokens, rounding_mode='floor')
        span_end_index = torch.fmod(max_span_index, max_num_tokens)
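        # Worked example of the decode above: with max_num_tokens = 4, a flattened argmax
        # of 6 over the 4 x 4 (start, end) score grid gives start = 6 // 4 = 1 and
        # end = 6 % 4 = 2, i.e. the span covering tokens 1 through 2; spans with
        # start > end were zeroed out by invalid_index_mask beforehand.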

        tensors = {
            'example_id_num': example_id_num,
            'service_id': service_id,
            'intent_status': intent_status,
            'req_slot_status': req_slot_status,
            'cat_slot_status': cat_slot_status,
            'cat_slot_status_p': cat_slot_status_p,
            'cat_slot_value_status': cat_slot_value_status,
            'noncat_slot_status': noncat_slot_status,
            'noncat_slot_status_p': noncat_slot_status_p,
            'noncat_slot_p': max_span_p,
            'noncat_slot_start': span_start_index,
            'noncat_slot_end': span_end_index,
            'noncat_alignment_start': start_char_idx,
            'noncat_alignment_end': end_char_idx,
        }
        return loss, tensors

    def multi_validation_epoch_end(self, outputs: List[dict], dataloader_idx: int = 0):
        """
        Called at the end of validation to postprocess outputs into human-readable format
        Args:
            outputs: list of individual outputs of each validation step
            dataloader_idx: dataloader index
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        split = self._validation_names[dataloader_idx][:-1]
        dataloader = self._validation_dl[dataloader_idx]
        metrics = self.multi_eval_epoch_end_helper(outputs=outputs, split=split, dataloader=dataloader)

        for k, v in metrics.items():
            self.log(f'{split}_{k}', v, rank_zero_only=True)

        self.log('val_loss', avg_loss, prog_bar=True, rank_zero_only=True)

    def multi_test_epoch_end(self, outputs: List[dict], dataloader_idx: int = 0):
        """
        Called at the end of test to postprocess outputs into human-readable format
        Args:
            outputs: list of individual outputs of each test step
            dataloader_idx: dataloader index
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        split = self._test_names[dataloader_idx][:-1]
        dataloader = self._test_dl[dataloader_idx]
        metrics = self.multi_eval_epoch_end_helper(outputs=outputs, split=split, dataloader=dataloader)

        for k, v in metrics.items():
            self.log(f'{split}_{k}', v, rank_zero_only=True)

        self.log('test_loss', avg_loss, prog_bar=True, rank_zero_only=True)

    def multi_eval_epoch_end_helper(
        self, outputs: List[dict], split: str, dataloader: torch.utils.data.DataLoader
    ) -> dict:
        """
        Helper called at the end of evaluation to postprocess outputs into human-readable format
        Args:
            outputs: list of individual outputs of each step
            split: data split
            dataloader: dataloader
        Returns:
            metrics: metrics collection
        """

        def get_str_example_id(split: str, ids_to_service_names_dict: dict, example_id_num: torch.Tensor) -> List[str]:
            """
            Constructs string representations of example IDs
            Args:
                split: evaluation data split
                ids_to_service_names_dict: id to service name mapping
                example_id_num: tensor example id
            """

            def format_turn_id(ex_id_num):
                dialog_id_1, dialog_id_2, turn_id, service_id, model_task_id, slot_intent_id, value_id = ex_id_num
                return "{}-{}_{:05d}-{:02d}-{}-{}-{}-{}".format(
                    split,
                    dialog_id_1,
                    dialog_id_2,
                    turn_id,
                    ids_to_service_names_dict[service_id],
                    model_task_id,
                    slot_intent_id,
                    value_id,
                )

            return list(map(format_turn_id, tensor2list(example_id_num)))
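            # For example, ex_id_num = [12, 34, 7, 3, 0, 5, 1] on the "dev" split, with
            # service id 3 mapping to the (hypothetical) service "Restaurants_1",
            # formats to "dev-12_00034-07-Restaurants_1-0-5-1".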

        def combine_predictions_in_example(predictions: dict, batch_size: int):
            """
            Combines the flat prediction tensors into one dict per example.
            Args:
                predictions: predictions ordered by keys then batch
                batch_size: batch size
            Returns:
                examples_preds: predictions ordered by batch then key
            """
            examples_preds = [{} for _ in range(batch_size)]
            for k, v in predictions.items():
                if k != 'example_id':
                    v = torch.chunk(v, batch_size)

                for i in range(batch_size):
                    if k == 'example_id':
                        examples_preds[i][k] = v[i]
                    else:
                        examples_preds[i][k] = v[i].view(-1)
            return examples_preds
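            # Example: each tensor in predictions has leading dimension batch_size, so
            # chunking along dim 0 yields one slice per example, flattened with view(-1);
            # the 'example_id' entry is already a per-example list of strings and is
            # indexed directly.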

        example_id_num = torch.cat([x['tensors']['example_id_num'] for x in outputs])
        service_id = torch.cat([x['tensors']['service_id'] for x in outputs])
        intent_status = torch.cat([x['tensors']['intent_status'] for x in outputs])
        req_slot_status = torch.cat([x['tensors']['req_slot_status'] for x in outputs])
        cat_slot_status = torch.cat([x['tensors']['cat_slot_status'] for x in outputs])
        cat_slot_status_p = torch.cat([x['tensors']['cat_slot_status_p'] for x in outputs])
        cat_slot_value_status = torch.cat([x['tensors']['cat_slot_value_status'] for x in outputs])
        noncat_slot_status = torch.cat([x['tensors']['noncat_slot_status'] for x in outputs])
        noncat_slot_status_p = torch.cat([x['tensors']['noncat_slot_status_p'] for x in outputs])
        noncat_slot_p = torch.cat([x['tensors']['noncat_slot_p'] for x in outputs])
        noncat_slot_start = torch.cat([x['tensors']['noncat_slot_start'] for x in outputs])
        noncat_slot_end = torch.cat([x['tensors']['noncat_slot_end'] for x in outputs])
        noncat_alignment_start = torch.cat([x['tensors']['noncat_alignment_start'] for x in outputs])
        noncat_alignment_end = torch.cat([x['tensors']['noncat_alignment_end'] for x in outputs])

        ids_to_service_names_dict = self.dialogues_processor.schemas._services_id_to_vocab
        example_id = get_str_example_id(split, ids_to_service_names_dict, example_id_num)

        metrics = {}
        try:
            prediction_dir = self.trainer.log_dir if self.trainer.log_dir is not None else ""
        except Exception:
            prediction_dir = ""

        if self.trainer.global_rank == 0:
            prediction_dir = os.path.join(
                prediction_dir, 'predictions', 'pred_res_{}_{}'.format(split, self._cfg.dataset.task_name)
            )
            os.makedirs(prediction_dir, exist_ok=True)

            input_json_files = SGDDataProcessor.get_dialogue_files(
                self._cfg.dataset.data_dir, split, self._cfg.dataset.task_name
            )

            predictions = {}
            predictions['example_id'] = example_id
            predictions['service_id'] = service_id
            predictions['intent_status'] = intent_status
            predictions['req_slot_status'] = req_slot_status
            predictions['cat_slot_status'] = cat_slot_status
            predictions['cat_slot_status_p'] = cat_slot_status_p
            predictions['cat_slot_value_status'] = cat_slot_value_status
            predictions['noncat_slot_status'] = noncat_slot_status
            predictions['noncat_slot_status_p'] = noncat_slot_status_p
            predictions['noncat_slot_p'] = noncat_slot_p
            predictions['noncat_slot_start'] = noncat_slot_start
            predictions['noncat_slot_end'] = noncat_slot_end
            predictions['noncat_alignment_start'] = noncat_alignment_start
            predictions['noncat_alignment_end'] = noncat_alignment_end

            in_domain_services = get_in_domain_services(
                os.path.join(self._cfg.dataset.data_dir, split, "schema.json"),
                self.dialogues_processor.get_seen_services("train"),
            )
            predictions = combine_predictions_in_example(predictions, service_id.shape[0])

            # Write predictions to file in DSTC8/SGD format.
            write_predictions_to_file(
                predictions,
                input_json_files,
                output_dir=prediction_dir,
                schemas=self.dialogues_processor.schemas,
                state_tracker=self._cfg.dataset.state_tracker,
                eval_debug=False,
                in_domain_services=in_domain_services,
            )
            metrics = evaluate(
                prediction_dir,
                self._cfg.dataset.data_dir,
                split,
                in_domain_services,
                joint_acc_across_turn=self._cfg.dataset.joint_acc_across_turn,
                use_fuzzy_match=self._cfg.dataset.use_fuzzy_match,
            )

        return metrics

    def prepare_data(self):
        """
        Preprocesses schemas and dialogues and caches the results
        """
        if self.data_prepared:
            return
        schema_config = {
            "MAX_NUM_CAT_SLOT": self._cfg.dataset.max_num_cat_slot,
            "MAX_NUM_NONCAT_SLOT": self._cfg.dataset.max_num_noncat_slot,
            "MAX_NUM_VALUE_PER_CAT_SLOT": self._cfg.dataset.max_value_per_cat_slot,
            "MAX_NUM_INTENT": self._cfg.dataset.max_num_intent,
            "NUM_TASKS": NUM_TASKS,
            "MAX_SEQ_LENGTH": self._cfg.dataset.max_seq_length,
        }
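        # Illustrative values only (common defaults for the SGD dataset, taken from the
        # usual config rather than fixed by this code): MAX_NUM_CAT_SLOT=6,
        # MAX_NUM_NONCAT_SLOT=12, MAX_NUM_VALUE_PER_CAT_SLOT=11, MAX_NUM_INTENT=4.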
        all_schema_json_paths = []
        for dataset_split in ['train', 'test', 'dev']:
            all_schema_json_paths.append(os.path.join(self._cfg.dataset.data_dir, dataset_split, "schema.json"))
        schemas = Schema(all_schema_json_paths)

        self.dialogues_processor = SGDDataProcessor(
            task_name=self._cfg.dataset.task_name,
            data_dir=self._cfg.dataset.data_dir,
            dialogues_example_dir=self._cfg.dataset.dialogues_example_dir,
            tokenizer=self.tokenizer,
            schemas=schemas,
            schema_config=schema_config,
            subsample=self._cfg.dataset.subsample,
        )

        if is_global_rank_zero():
            overwrite_dial_files = not self._cfg.dataset.use_cache
            self.dialogues_processor.save_dialog_examples(overwrite_dial_files=overwrite_dial_files)

        self.data_prepared = True

    def update_data_dirs(self, data_dir: str, dialogues_example_dir: str):
        """
        Update data directories

        Args:
            data_dir: path to data directory
            dialogues_example_dir: path to the preprocessed dialogue examples directory; created if it does not exist.
        """
        if not os.path.exists(data_dir):
            raise ValueError(f"{data_dir} is not found")
        self._cfg.dataset.data_dir = data_dir
        self._cfg.dataset.dialogues_example_dir = dialogues_example_dir
        logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
        logging.info(f'Setting model.dataset.dialogues_example_dir to {dialogues_example_dir}.')

    def setup_training_data(self, train_data_config: Optional[DictConfig] = None):
        self.prepare_data()
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, split=train_data_config.ds_item)

    def setup_validation_data(self, val_data_config: Optional[DictConfig] = None):
        self.prepare_data()
        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, split=val_data_config.ds_item)

    def setup_test_data(self, test_data_config: Optional[DictConfig] = None):
        self.prepare_data()
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config, split=test_data_config.ds_item)

    def _setup_dataloader_from_config(self, cfg: DictConfig, split: str) -> DataLoader:
        dataset_cfg = self._cfg.dataset
        data_dir = dataset_cfg.data_dir

        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Data directory is not found at: {data_dir}.")

        dataset = SGDDataset(dataset_split=split, dialogues_processor=self.dialogues_processor)

        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.drop_last,
            shuffle=cfg.shuffle,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
        return dl

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []

        result.append(
            PretrainedModelInfo(
                pretrained_model_name="sgdqa_bertbasecased",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/sgdqa_bertbasecased/versions/1.0.0/files/sgdqa_bertbasecased.nemo",
                description="Dialogue State Tracking model finetuned from NeMo BERT Base Cased on Google SGD dataset which has a joint goal accuracy of 59.72% on dev set and 45.85% on test set.",
            )
        )
        return result
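

# Minimal usage sketch (not part of the source): the YAML path and the cfg.model /
# cfg.trainer layout below follow the usual NeMo experiment-script conventions and are
# assumptions here, not names taken from this file.
if __name__ == '__main__':
    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('sgdqa_config.yaml')  # hypothetical config file
    trainer = pl.Trainer(**cfg.trainer)
    model = SGDQAModel(cfg.model, trainer=trainer)  # tokenizer + BERT backbone set up here
    trainer.fit(model)  # drives the training_step / validation_step defined above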