Exemple #1
0
def eval_metrics(gold_pred_sys_das, output_path=None):
    """Score predicted system dialogue acts against the gold acts.

    Every turn is counted once in the joint-accuracy denominator. Act-level
    true/false positives and false negatives feed precision, recall and F1.
    Turns whose gold and predicted act sets differ are written to a JSON
    "bad case" file for inspection.

    Args:
        gold_pred_sys_das: nested mapping ``{dialogue_id: {turn_id: turn}}``.
            Each ``turn`` holds ``'gold_sys_act'`` and ``'pred_sys_act'``,
            which are sequences of dialogue acts. Per the original comment,
            ``act[0]`` is the intent and ``act[2]`` is the slot.
        output_path: file the bad cases are dumped to. Defaults to
            ``crosswoz/policy_rule_single_domain_data/bad_case.json`` under
            ``get_data_path()``.

    Returns:
        Tuple ``(f1, precision, recall, joint_acc)``. Each value is 0 when
        its denominator is 0, including for an empty ``gold_pred_sys_das``.
    """
    tp, fp, fn = 0, 0, 0
    joint_acc = total = 0
    bad_case = {}

    for dia_id, sess in gold_pred_sys_das.items():
        for turn_id, turn in sess.items():
            gold_acts = turn['gold_sys_act']
            pred_acts = turn['pred_sys_act']
            if not gold_acts and not pred_acts:
                joint_acc += 1
            elif not pred_acts:
                fn += len(gold_acts)
            elif not gold_acts:
                fp += len(pred_acts)
            # When the intent is Recommend, or the slot starts with '周边'
            # ("nearby ..."), the number of values given in the dataset
            # follows no pattern. Such turns are therefore always counted as
            # correct: the prediction essentially contains the gold result.
            elif ((gold_acts[0][0] == pred_acts[0][0] == 'Recommend')
                  or gold_acts[0][2].startswith('周边')):
                joint_acc += 1
                tp += len(gold_acts)
            else:
                # Acts loaded from JSON are lists, which are unhashable.
                # Convert each act to a tuple so the sets can be compared.
                gold = {tuple(act) for act in gold_acts}
                pred = {tuple(act) for act in pred_acts}

                if gold != pred:
                    bad_case.setdefault(dia_id, {})[str(turn_id)] = {
                        'gold_sys_act': gold_acts,
                        'pred_sys_act': pred_acts,
                    }
                else:
                    joint_acc += 1

                tp += len(gold & pred)
                fn += len(gold - pred)
                fp += len(pred - gold)

            total += 1

    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) != 0 else 0)
    # An empty input previously raised ZeroDivisionError here.
    joint_acc = joint_acc / total if total else 0

    if output_path is None:
        output_path = os.path.join(
            get_data_path(),
            'crosswoz/policy_rule_single_domain_data/bad_case.json')
    dump_json(bad_case, output_path)

    return f1, precision, recall, joint_acc
Exemple #2
0
def get_single_domain_examples(input_file_path, output_file_path):
    """Filter a zipped dialogue dataset down to its single-domain dialogues.

    The zip member read is the archive's base name with its last extension
    removed (e.g. ``train.json.zip`` -> ``train.json``). Dialogues whose
    ``'type'`` equals '单领域' ("single domain") are kept. Their count is
    printed and they are dumped as JSON to ``output_file_path``.
    """
    member = os.path.basename(input_file_path).rsplit('.', 1)[0]
    dataset = read_zipped_json(input_file_path, member)

    selected = {}
    for dialogue_id, dialogue in dataset.items():
        if dialogue['type'] == '单领域':
            selected[dialogue_id] = dialogue

    print(f'{member} total has {len(selected)} single domain examples')

    dump_json(selected, output_file_path)
Exemple #3
0
    def eval_test(self) -> None:
        """Evaluate the test dataloader, reloading the best checkpoint first.

        When ``self.best_model_path`` is set, a fresh
        ``BertForSequenceClassification`` is loaded from it. If the current
        model exposes a ``module`` attribute, the new model replaces that
        attribute; otherwise it replaces ``self.model`` itself. The model is
        then moved to ``self.config["device"]``. The prediction results from
        ``self.evaluation`` are dumped to ``prediction.json`` in
        ``self.config["data_path"]``.
        """
        checkpoint = self.best_model_path
        if checkpoint is not None:
            has_inner_module = hasattr(self.model, "module")
            restored = BertForSequenceClassification.from_pretrained(
                checkpoint)
            if has_inner_module:
                self.model.module = restored
            else:
                self.model = restored
            self.model.to(self.config["device"])

        # NOTE(review): the mode string is "tests", not "test". Confirm that
        # ``self.evaluation`` really expects this value.
        _, predictions = self.evaluation(self.test_dataloader, mode="tests")
        target_file = os.path.join(self.config["data_path"], "prediction.json")
        dump_json(predictions, target_file)
Exemple #4
0
    def eval_test(self) -> None:
        """Reload the best checkpoint, if any, and dump test-set predictions.

        If ``self.best_model_path`` is not None, a
        ``BertForSequenceClassification`` is loaded from that path. It is
        placed into ``self.model.module`` when the current model has a
        ``module`` attribute, or replaces ``self.model`` otherwise, and is
        moved to ``self.config['device']``. The prediction results returned
        by ``self.evaluation(..., mode='test')`` are then written with
        ``dump_json`` to ``prediction.json`` in ``self.config['data_path']``.
        """
        best_path = self.best_model_path
        if best_path is not None:
            wrapped = hasattr(self.model, 'module')
            best_model = BertForSequenceClassification.from_pretrained(
                best_path)
            if wrapped:
                self.model.module = best_model
            else:
                self.model = best_model
            self.model.to(self.config['device'])

        _, test_predictions = self.evaluation(self.test_dataloader,
                                              mode='test')
        save_to = os.path.join(self.config['data_path'], 'prediction.json')
        dump_json(test_predictions, save_to)
Exemple #5
0
    def get_act_ontology(raw_data_path: str, output_path: str) -> None:
        """Generate action ontology from raw train data.

        Each dialogue act of a system turn (``turn["role"] == "sys"``) in the
        zipped ``train.json`` is flattened to ``"da[1]-da[0]-da[2]"``. The
        distinct strings are saved as a JSON list. The list is sorted so
        that regenerating the file always yields the same contents and order.

        Args:
            raw_data_path: raw train data path (zip archive containing
                ``train.json``)
            output_path: save path of action ontology file
        """
        raw_data = read_zipped_json(raw_data_path, "train.json")

        act_ontology = set()
        for dial in tqdm(raw_data.values(),
                         desc="Generate action ontology ..."):
            for turn in dial["messages"]:
                if turn["role"] == "sys" and turn["dialog_act"]:
                    for da in turn["dialog_act"]:
                        act_ontology.add("-".join([da[1], da[0], da[2]]))

        # Iteration order of a set of str varies between interpreter runs
        # (hash randomization). Sort so the saved ontology is reproducible,
        # which matters if positions in the list are used as label ids
        # (presumably; confirm with consumers).
        dump_json(sorted(act_ontology), output_path)
Exemple #6
0
def cache_data(data_type: str, examples: List[Dict[str, list]],
               output_path: str) -> None:
    """Save processed data as JSON.

    Args:
        data_type: label of the data being saved; used only in the progress
            message.
        examples: preprocessed examples to serialize with ``dump_json``.
        output_path: destination file path.
    """
    # Print before writing: the "..." message announces work in progress,
    # and it still identifies the target file if ``dump_json`` fails.
    print(f"Saving preprocessed {data_type} into {output_path} ...")
    dump_json(examples, output_path)