# Example no. 1
def eval_metrics(
    model_output: Dict[str, dict], data_path: str, top_k: int
) -> Dict[str, float]:
    """Compute `turn_inform`, `turn_request` and `joint_goal` accuracies.

    Args:
        model_output: reformatted results (preds and ground truth) keyed by
                      dialogue id and turn id
        data_path: directory where bad cases are dumped as JSON
        top_k: number of top prediction labels to keep per turn

    Returns:
        mapping from metric name to its accuracy, rounded to 3 decimals
    """
    inform_hits = []
    request_hits = []
    joint_hits = []

    bad_cases = {}
    for dia_id, dialogue in model_output.items():
        per_turn = defaultdict(dict)
        for turn_id, turn in dialogue.items():
            ranked_preds = rank_values(turn["logits"], turn["preds"], top_k)

            gold_req, gold_inf = get_inf_req(turn["labels"])
            pred_req, pred_inf = get_inf_req(ranked_preds)

            record = per_turn[turn_id]
            record["pred_inform"] = [list(dsv) for dsv in pred_inf]
            record["gold_inform"] = [list(dsv) for dsv in gold_inf]
            record["pred_request"] = [list(dsv) for dsv in pred_req]
            record["gold_request"] = [list(dsv) for dsv in gold_req]

            request_hits.append(gold_req == pred_req)
            inform_hits.append(gold_inf == pred_inf)

            # Keep only inform intents; drop general intents
            # (triples whose slot and value are both "none").
            recovered_pred = {
                (d, s, v) for d, s, v in pred_inf if not s == v == "none"
            }
            joint_hits.append(recovered_pred == set(turn["belief_state"]))

        bad_cases[dia_id] = per_turn

    bad_case_path = os.path.join(data_path, "bad_cases.json")
    with open(bad_case_path, "w", encoding="utf8") as f:
        json.dump(bad_cases, f, indent=2, ensure_ascii=False)

    return {
        "turn_inform": round(float(np.mean(inform_hits)), 3),
        "turn_request": round(float(np.mean(request_hits)), 3),
        "joint_goal": round(float(np.mean(joint_hits)), 3),
    }
# Example no. 2
    def forward(self, sys_uttr: str, usr_utter: str) -> List[tuple]:
        """Run the BERT model and return ranked triple labels.

        Args:
            sys_uttr: response of the previous system turn
            usr_utter: user utterance of the previous turn

        Returns:
            a list of triple labels, (domain, slot, value)
        """
        accepted_triples = []
        accepted_logits = []
        loader = self.preprocess(sys_uttr, usr_utter)
        progress = tqdm(enumerate(loader), total=len(loader), desc="Inferring")

        with torch.no_grad():
            for _, batch in progress:
                # Only the first three batch entries are model inputs;
                # the rest ("domains"/"slots"/"values") carry label metadata.
                model_inputs = {
                    key: tensor.to(self.config["device"])
                    for key, tensor in list(batch.items())[:3]
                }
                logits = self.model(**model_inputs)[0]

                best = logits.max(dim=1)
                max_logits = best.values.cpu().tolist()
                preds = best.indices.cpu().tolist()

                # Keep every positive prediction together with its logit,
                # so the triples can be ranked afterwards.
                for idx, (pred, logit) in enumerate(zip(preds, max_logits)):
                    if pred == 1:
                        accepted_logits.append(logit)
                        accepted_triples.append((
                            batch["domains"][idx],
                            batch["slots"][idx],
                            batch["values"][idx],
                        ))

        return rank_values(accepted_logits, accepted_triples,
                           self.config["top_k"])