Example 1
    def test_requested_slots_f1(self):
        # (1) Test on oracle frame.
        requested_slots_f1_oracle = metrics.get_requested_slots_f1(
                self.frame_ref, self.frame_ref)
        # Ground truth values for oracle prediction are all 1.0.
        self._assert_dicts_almost_equal(
                {k: 1.0 for k in self.known_metrics[metrics.REQUESTED_SLOTS_F1]},
                requested_slots_f1_oracle._asdict())

        # (2) Test on a previously known frame.
        requested_slots_f1_hyp = metrics.get_requested_slots_f1(
                self.frame_ref, self.frame_hyp)
        self._assert_dicts_almost_equal(
                self.known_metrics[metrics.REQUESTED_SLOTS_F1],
                requested_slots_f1_hyp._asdict())
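
The `_assert_dicts_almost_equal` helper is referenced but not defined in this excerpt; a minimal sketch, assuming it lives on the same unittest.TestCase subclass and compares float-valued metric dicts with a tolerance:

    def _assert_dicts_almost_equal(self, dict_ref, dict_other):
        # Both dicts must contain exactly the same metric keys.
        self.assertEqual(set(dict_ref), set(dict_other))
        for key, ref_value in dict_ref.items():
            # Floating point metric values are compared with a tolerance.
            self.assertAlmostEqual(ref_value, dict_other[key])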
Example 2
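The `get_metrics` snippet below omits its module preamble; a minimal sketch of the imports and collection-name constants it relies on (the module path and constant values are assumptions, not confirmed by this excerpt):

import collections

import numpy as np
import tensorflow as tf

import metrics  # Assumed: the DSTC8 metrics module providing the per-frame scorers.

# Collection names used for aggregation; the string values are assumptions.
ALL_SERVICES = "#ALL_SERVICES"
SEEN_SERVICES = "#SEEN_SERVICES"
UNSEEN_SERVICES = "#UNSEEN_SERVICES"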
def get_metrics(dataset_ref, dataset_hyp, service_schemas, in_domain_services):
    """Calculate the DSTC8 metrics.

  Args:
    dataset_ref: The ground truth dataset represented as a dict mapping dialogue
      id to the corresponding dialogue.
    dataset_hyp: The predictions in the same format as `dataset_ref`.
    service_schemas: A dict mapping service name to the schema for the service.
    in_domain_services: The set of services which are present in the training
      set.

  Returns:
    A dict mapping a metric collection name to a dict containing the values
    for various metrics. Each metric collection aggregates the metrics across
    a specific set of frames in the dialogues.
  """
    # Metrics can be aggregated in various ways, e.g. over all dialogues, only for
    # dialogues containing unseen services or for dialogues corresponding to a
    # single service. This aggregation is done through metric_collections, which
    # is a dict mapping a collection name to a dict, which maps a metric to a list
    # of values for that metric. Each value in this list is the value taken by
    # the metric on a frame.
    metric_collections = collections.defaultdict(
        lambda: collections.defaultdict(list))
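    # Illustrative shape (hypothetical values): after a few frames,
    # metric_collections[ALL_SERVICES][metrics.ACTIVE_INTENT_ACCURACY] might
    # hold [1.0, 0.0, 1.0], one entry per frame.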

    # Ensure the dialogues in dataset_hyp also occur in dataset_ref.
    assert set(dataset_hyp.keys()).issubset(set(dataset_ref.keys()))
    tf.logging.info("len(dataset_hyp)=%d, len(dataset_ref)=%d",
                    len(dataset_hyp), len(dataset_ref))

    # Store metrics for every frame for debugging.
    per_frame_metric = {}
    for dial_id, dial_hyp in dataset_hyp.items():
        dial_ref = dataset_ref[dial_id]

        if set(dial_ref["services"]) != set(dial_hyp["services"]):
            raise ValueError(
                "Set of services present in ground truth and predictions don't match "
                "for dialogue with id {}".format(dial_id))

        for turn_id, (turn_ref, turn_hyp) in enumerate(
                zip(dial_ref["turns"], dial_hyp["turns"])):
            if turn_ref["speaker"] != turn_hyp["speaker"]:
                raise ValueError(
                    "Speakers don't match in dialogue with id {}".format(
                        dial_id))

            # Skip system turns because metrics are only computed for user turns.
            if turn_ref["speaker"] != "USER":
                continue

            if turn_ref["utterance"] != turn_hyp["utterance"]:
                tf.logging.info("Ref utt: %s", turn_ref["utterance"])
                tf.logging.info("Hyp utt: %s", turn_hyp["utterance"])
                raise ValueError(
                    "Utterances don't match for dialogue with id {}".format(
                        dial_id))

            hyp_frames_by_service = {
                frame["service"]: frame
                for frame in turn_hyp["frames"]
            }

            # Calculate metrics for each frame in each user turn.
            for frame_ref in turn_ref["frames"]:
                service_name = frame_ref["service"]
                if service_name not in hyp_frames_by_service:
                    raise ValueError(
                        "Frame for service {} not found in dialogue with id {}"
                        .format(service_name, dial_id))
                service = service_schemas[service_name]
                frame_hyp = hyp_frames_by_service[service_name]

                active_intent_acc = metrics.get_active_intent_accuracy(
                    frame_ref, frame_hyp)
                slot_tagging_f1_scores = metrics.get_slot_tagging_f1(
                    frame_ref, frame_hyp, turn_ref["utterance"], service)
                requested_slots_f1_scores = metrics.get_requested_slots_f1(
                    frame_ref, frame_hyp)
                goal_accuracy_dict = metrics.get_average_and_joint_goal_accuracy(
                    frame_ref, frame_hyp, service)

                frame_metric = {
                    metrics.ACTIVE_INTENT_ACCURACY: active_intent_acc,
                    metrics.REQUESTED_SLOTS_F1: requested_slots_f1_scores.f1,
                    metrics.REQUESTED_SLOTS_PRECISION:
                        requested_slots_f1_scores.precision,
                    metrics.REQUESTED_SLOTS_RECALL:
                        requested_slots_f1_scores.recall,
                }
                if slot_tagging_f1_scores is not None:
                    frame_metric[metrics.SLOT_TAGGING_F1] = (
                        slot_tagging_f1_scores.f1)
                    frame_metric[metrics.SLOT_TAGGING_PRECISION] = (
                        slot_tagging_f1_scores.precision)
                    frame_metric[metrics.SLOT_TAGGING_RECALL] = (
                        slot_tagging_f1_scores.recall)
                frame_metric.update(goal_accuracy_dict)

                frame_id = "{:s}-{:03d}-{:s}".format(dial_id, turn_id,
                                                     frame_hyp["service"])
                per_frame_metric[frame_id] = frame_metric
                # Add the frame-level metric result back to dialogues.
                frame_hyp["metrics"] = frame_metric

                # Get the domain name of the service.
                domain_name = frame_hyp["service"].split("_")[0]
                domain_keys = [ALL_SERVICES, frame_hyp["service"], domain_name]
                if frame_hyp["service"] in in_domain_services:
                    domain_keys.append(SEEN_SERVICES)
                else:
                    domain_keys.append(UNSEEN_SERVICES)
                for domain_key in domain_keys:
                    for metric_key, metric_value in frame_metric.items():
                        if metric_value != metrics.NAN_VAL:
                            metric_collections[domain_key][metric_key].append(
                                metric_value)

    all_metric_aggregate = {}
    for domain_key, domain_metric_vals in metric_collections.items():
        domain_metric_aggregate = {}
        for metric_key, value_list in domain_metric_vals.items():
            if value_list:
                # Metrics are macro-averaged across all frames.
                domain_metric_aggregate[metric_key] = float(
                    np.mean(value_list))
            else:
                domain_metric_aggregate[metric_key] = metrics.NAN_VAL
        all_metric_aggregate[domain_key] = domain_metric_aggregate
    return all_metric_aggregate, per_frame_metric
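
A hypothetical driver for the function above, assuming SGD-style JSON files where each dialogue carries a "dialogue_id" field and each schema a "service_name" (the file names and the in-domain service set are illustrative):

import json

with open("dialogues_ref.json") as f:
    dataset_ref = {d["dialogue_id"]: d for d in json.load(f)}
with open("dialogues_hyp.json") as f:
    dataset_hyp = {d["dialogue_id"]: d for d in json.load(f)}
with open("schema.json") as f:
    service_schemas = {s["service_name"]: s for s in json.load(f)}

all_metric_aggregate, per_frame_metric = get_metrics(
    dataset_ref, dataset_hyp, service_schemas,
    in_domain_services={"Restaurants_1"})
print(all_metric_aggregate[ALL_SERVICES])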
Example 3
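This variant additionally uses tqdm, pandas, absl-style flags, and two slot collections defined elsewhere; a hedged sketch of the extra preamble (the flag names follow their uses in the code below, everything else is an assumption):

import collections

import numpy as np
import pandas as pd
import tensorflow as tf
from absl import flags
from tqdm import tqdm

import metrics  # Assumed to expose the same scorers plus the JOINT_* constants.

FLAGS = flags.FLAGS  # Expects use_fuzzy_match and joint_acc_across_turn flags.

# Placeholders: the tracked slots and the "mapping" slots are defined elsewhere.
ORDERED_TRACK_SLOTS = []
MAPPING_SLOT = set()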
def get_metrics(dataset_ref, dataset_hyp, service_schemas, in_domain_services):
    """Calculate the DSTC8 metrics.

    Args:
      dataset_ref: The ground truth dataset represented as a dict mapping dialogue
        id to the corresponding dialogue.
      dataset_hyp: The predictions in the same format as `dataset_ref`.
      service_schemas: A dict mapping service name to the schema for the service.
      in_domain_services: The set of services which are present in the training
        set.

    Returns:
      A dict mapping a metric collection name to a dict containing the values
      for various metrics. Each metric collection aggregates the metrics across
      a specific set of frames in the dialogues.
    """
    # Metrics can be aggregated in various ways, e.g. over all dialogues, only for
    # dialogues containing unseen services or for dialogues corresponding to a
    # single service. This aggregation is done through metric_collections, which
    # is a dict mapping a collection name to a dict, which maps a metric to a list
    # of values for that metric. Each value in this list is the value taken by
    # the metric on a frame.
    metric_collections = collections.defaultdict(lambda: collections.defaultdict(list))

    # Ensure the dialogues in dataset_hyp also occur in dataset_ref.
    assert set(dataset_hyp.keys()).issubset(set(dataset_ref.keys()))
    tf.logging.info(
        "len(dataset_hyp)=%d, len(dataset_ref)=%d", len(dataset_hyp), len(dataset_ref)
    )

    # Per-slot confusion counts (TP/TN/FP/FN), updated inside
    # metrics.get_average_and_joint_goal_accuracy and summarized after the
    # main loop.
    slot_acc = {}
    for service_name, schema in service_schemas.items():
        for slot in schema["slots"]:
            slot_name = slot["name"]
            slot_acc[slot_name + "_TP"] = 0
            slot_acc[slot_name + "_TN"] = 0
            slot_acc[slot_name + "_FP"] = 0
            slot_acc[slot_name + "_FN"] = 0

    # Store metrics for every frame for debugging.
    per_frame_metric = {}
    for dial_id, dial_hyp in tqdm(dataset_hyp.items()):
        dial_ref = dataset_ref[dial_id]
        # Reset the fallback store for each dialogue so hypothesis frames
        # never leak across dialogues.
        prev_hyp_frames_by_service = {}

        if set(dial_ref["services"]) != set(dial_hyp["services"]):
            raise ValueError(
                "Set of services present in ground truth and predictions don't match "
                "for dialogue with id {}".format(dial_id)
            )
        joint_metrics = [
            metrics.JOINT_GOAL_ACCURACY,
            metrics.JOINT_CAT_ACCURACY,
            metrics.JOINT_NONCAT_ACCURACY,
            metrics.JOINT_MAP_GOAL_ACCURACY,
            metrics.JOINT_MAP_CAT_ACCURACY,
            metrics.JOINT_MAP_NONCAT_ACCURACY,
            metrics.JOINT_NONMAP_GOAL_ACCURACY,
            metrics.JOINT_NONMAP_CAT_ACCURACY,
            metrics.JOINT_NONMAP_NONCAT_ACCURACY,
        ]
        for turn_id, (turn_ref, turn_hyp) in enumerate(
            zip(dial_ref["turns"], dial_hyp["turns"])
        ):
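            # Joint metrics are multiplied across the frames of a turn, so the
            # per-turn accumulator defaults to 1.0 (the multiplicative
            # identity); any frame scored 0.0 zeroes the whole turn.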
            metric_collections_per_turn = collections.defaultdict(
                lambda: collections.defaultdict(lambda: 1.0)
            )
            if turn_ref["speaker"] != turn_hyp["speaker"]:
                raise ValueError(
                    "Speakers don't match in dialogue with id {}".format(dial_id)
                )

            # Skip system turns because metrics are only computed for user turns.
            if turn_ref["speaker"] != "USER":
                continue

            if turn_ref["utterance"] != turn_hyp["utterance"]:
                tf.logging.info("Ref utt: %s", turn_ref["utterance"])
                tf.logging.info("Hyp utt: %s", turn_hyp["utterance"])
                raise ValueError(
                    "Utterances don't match for dialogue with id {}".format(dial_id)
                )

            hyp_frames_by_service = {
                frame["service"]: frame for frame in turn_hyp["frames"]
            }

            # Calculate metrics for each frame in each user turn.
            for frame_ref in turn_ref["frames"]:
                service_name = frame_ref["service"]
                if service_name not in hyp_frames_by_service:
                    # Fall back to this service's hypothesis frame from the
                    # previous user turn instead of raising; this assumes the
                    # service appeared in an earlier turn of the dialogue.
                    frame_hyp = prev_hyp_frames_by_service[service_name]
                else:
                    frame_hyp = hyp_frames_by_service[service_name]
                service = service_schemas[service_name]

                active_intent_acc = metrics.get_active_intent_accuracy(
                    frame_ref, frame_hyp
                )
                slot_tagging_f1_scores = metrics.get_slot_tagging_f1(
                    frame_ref, frame_hyp, turn_ref["utterance"], service
                )
                requested_slots_f1_scores = metrics.get_requested_slots_f1(
                    frame_ref, frame_hyp
                )
                (
                    goal_accuracy_dict,
                    slot_acc,
                ) = metrics.get_average_and_joint_goal_accuracy(
                    frame_ref, frame_hyp, service, FLAGS.use_fuzzy_match, slot_acc
                )

                frame_metric = {
                    metrics.ACTIVE_INTENT_ACCURACY: active_intent_acc,
                    metrics.REQUESTED_SLOTS_F1: requested_slots_f1_scores.f1,
                    metrics.REQUESTED_SLOTS_PRECISION: requested_slots_f1_scores.precision,
                    metrics.REQUESTED_SLOTS_RECALL: requested_slots_f1_scores.recall,
                }
                if slot_tagging_f1_scores is not None:
                    frame_metric[metrics.SLOT_TAGGING_F1] = slot_tagging_f1_scores.f1
                    frame_metric[
                        metrics.SLOT_TAGGING_PRECISION
                    ] = slot_tagging_f1_scores.precision
                    frame_metric[
                        metrics.SLOT_TAGGING_RECALL
                    ] = slot_tagging_f1_scores.recall
                frame_metric.update(goal_accuracy_dict)

                frame_id = "{:s}-{:03d}-{:s}".format(
                    dial_id, turn_id, frame_hyp["service"]
                )
                per_frame_metric[frame_id] = frame_metric
                # Add the frame-level metric result back to dialogues.
                frame_hyp["metrics"] = frame_metric

                # Get the domain name of the service.
                domain_name = frame_hyp["service"].split("_")[0]
                domain_keys = [ALL_SERVICES, frame_hyp["service"], domain_name]
                if frame_hyp["service"] in in_domain_services:
                    domain_keys.append(SEEN_SERVICES)
                else:
                    domain_keys.append(UNSEEN_SERVICES)
                for domain_key in domain_keys:
                    for metric_key, metric_value in frame_metric.items():
                        if metric_value != metrics.NAN_VAL:
                            if (
                                FLAGS.joint_acc_across_turn
                                and metric_key in joint_metrics
                            ):
                                metric_collections_per_turn[domain_key][
                                    metric_key
                                ] *= metric_value
                            else:
                                metric_collections[domain_key][metric_key].append(
                                    metric_value
                                )
            if FLAGS.joint_acc_across_turn:
                # Conduct multiwoz style evaluation that computes joint goal accuracy
                # across all the slot values of all the domains for each turn.
                for domain_key in metric_collections_per_turn:
                    for metric_key, metric_value in metric_collections_per_turn[
                        domain_key
                    ].items():
                        metric_collections[domain_key][metric_key].append(metric_value)
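            # Remember this turn's hypothesis frames so a later turn can fall
            # back to them when a predicted frame is missing.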
            prev_hyp_frames_by_service = hyp_frames_by_service
    slot_acc_dict = {}
    rows = []  # One row per tracked slot for the summary DataFrame below.
    for service_name, schema in service_schemas.items():
        for slot in schema["slots"]:
            slot_name = slot["name"]
            if slot_name not in ORDERED_TRACK_SLOTS:
                continue
            TP = slot_acc[slot_name + "_TP"]
            FP = slot_acc[slot_name + "_FP"]
            TN = slot_acc[slot_name + "_TN"]
            FN = slot_acc[slot_name + "_FN"]
            if TP == 0:
                # With no true positives, precision, recall, and F1 are all
                # zero; TP == 0 also covers TP + FP == 0 and TP + FN == 0, so
                # the divisions in the else branch are safe.
                precision = 0
                recall = 0
                f1 = 0
            else:
                precision = TP / (TP + FP)
                recall = TP / (TP + FN)
                f1 = 2 * precision * recall / (precision + recall)
            acc = (TP + TN) / (TP + TN + FP + FN) if TP + TN + FP + FN else 0

            slot_acc_dict[slot_name + "_precision"] = precision
            slot_acc_dict[slot_name + "_recall"] = recall
            slot_acc_dict[slot_name + "_f1"] = f1
            slot_acc_dict[slot_name + "_acc"] = acc
            rows.append(
                [
                    service_name,
                    slot_name,
                    "Mapping" if slot_name in MAPPING_SLOT else "Non-mapping",
                    "Categorical" if slot["is_categorical"] else "Non-categorical",
                    acc,
                    f1,
                    precision,
                    recall,
                ]
            )
    df = pd.DataFrame(
        rows,
        columns=[
            "Service",
            "Slot",
            "Mapping",
            "Categorical",
            "Accuracy",
            "F1",
            "Precision",
            "Recall",
        ],
    )

    all_metric_aggregate = {}
    for domain_key, domain_metric_vals in metric_collections.items():
        domain_metric_aggregate = {}
        for metric_key, value_list in domain_metric_vals.items():
            if value_list:
                # Metrics are macro-averaged across all frames.
                domain_metric_aggregate[metric_key] = float(np.mean(value_list))
            else:
                domain_metric_aggregate[metric_key] = metrics.NAN_VAL
        all_metric_aggregate[domain_key] = domain_metric_aggregate
    return all_metric_aggregate, per_frame_metric, slot_acc_dict, df
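
The extended variant returns two extra artifacts; a hypothetical way to consume them (everything beyond the names in the code above is an assumption):

all_metric_aggregate, per_frame_metric, slot_acc_dict, df = get_metrics(
    dataset_ref, dataset_hyp, service_schemas, in_domain_services)
# Per-slot scores keyed as "<slot>_precision", "<slot>_recall", etc.
print(slot_acc_dict)
# Tabular per-slot summary; e.g. sort by F1 to surface the weakest slots.
print(df.sort_values("F1").head())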