Example #1
  def test_requested_slots_f1(self):
    # (1) Test on the oracle frame (the reference compared against itself).
    requested_slots_f1_oracle = metrics.get_requested_slots_f1(
        self.frame_ref, self.frame_ref)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal(
        {k: 1.0 for k in self.correct_metrics[metrics.REQUESTED_SLOTS_F1]},
        requested_slots_f1_oracle._asdict())

    # (2) Test on a hypothesis frame with known expected metrics.
    requested_slots_f1_hyp = metrics.get_requested_slots_f1(
        self.frame_ref, self.frame_hyp)
    self._assert_dicts_almost_equal(
        self.correct_metrics[metrics.REQUESTED_SLOTS_F1],
        requested_slots_f1_hyp._asdict())
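
For context, here is a minimal sketch of what the metric under test plausibly computes: set-based precision, recall, and F1 over the requested slots of a frame. It assumes the DSTC8 frame layout where requested slots live under frame["state"]["requested_slots"], and it mirrors the .f1/.precision/.recall fields and ._asdict() used by the test; the real metrics module may differ in its handling of edge cases.

import collections

F1Scores = collections.namedtuple("F1Scores", ["f1", "precision", "recall"])


def get_requested_slots_f1(frame_ref, frame_hyp):
    """Set-based precision/recall/F1 over a frame's requested slots (sketch)."""
    ref = set(frame_ref["state"]["requested_slots"])
    hyp = set(frame_hyp["state"]["requested_slots"])
    if not ref and not hyp:
        # Neither frame requests any slot: count as a perfect match, which is
        # why the oracle comparison above yields 1.0 everywhere.
        return F1Scores(f1=1.0, precision=1.0, recall=1.0)
    true_positives = len(ref & hyp)
    precision = true_positives / len(hyp) if hyp else 0.0
    recall = true_positives / len(ref) if ref else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return F1Scores(f1=f1, precision=precision, recall=recall)
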
Example #2
# Module-level dependencies assumed from the surrounding evaluation script:
# `metrics`, plus the collection-name constants ALL_SERVICES, SEEN_SERVICES,
# and UNSEEN_SERVICES.
import collections

import numpy as np
import tensorflow as tf


def get_metrics(dataset_ref, dataset_hyp, service_schemas, in_domain_services):
    """Calculates the DSTC8 metrics.

    Args:
      dataset_ref: The ground truth dataset represented as a dict mapping
        dialogue id to the corresponding dialogue.
      dataset_hyp: The predictions in the same format as `dataset_ref`.
      service_schemas: A dict mapping service name to the schema for the
        service.
      in_domain_services: The set of services which are present in the training
        set.

    Returns:
      A 2-tuple of:
      - A dict mapping a metric collection name to a dict containing the values
        for various metrics. Each metric collection aggregates the metrics
        across a specific set of frames in the dialogues.
      - A dict mapping a frame id to the metrics computed for that frame, kept
        for debugging.
    """
    # Metrics can be aggregated in various ways, e.g., over all dialogues, only
    # over dialogues containing unseen services, or over dialogues for a single
    # service. This aggregation is done through `metric_collections`, a dict
    # mapping each collection name to a dict that maps a metric name to a list
    # of values for that metric. Each value in this list is the value taken by
    # the metric on a single frame.
    metric_collections = collections.defaultdict(
        lambda: collections.defaultdict(list))
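    # For illustration (hypothetical values; the metric names stand in for the
    # constants in `metrics`), after processing one frame the structure might
    # look like:
    #   metric_collections[ALL_SERVICES]["active_intent_accuracy"] -> [1.0]
    #   metric_collections["Restaurants_1"]["requested_slots_f1"] -> [0.5]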

    # Ensure the dialogs in dataset_hyp also occur in dataset_ref.
    assert set(dataset_hyp.keys()).issubset(set(dataset_ref.keys()))
    tf.logging.info("len(dataset_hyp)=%d, len(dataset_ref)=%d",
                    len(dataset_hyp), len(dataset_ref))

    # Store metrics for every frame for debugging.
    per_frame_metric = {}
    for dial_id, dial_hyp in dataset_hyp.items():
        dial_ref = dataset_ref[dial_id]

        if set(dial_ref["services"]) != set(dial_hyp["services"]):
            raise ValueError(
                "Set of services present in ground truth and predictions don't match "
                "for dialogue with id {}".format(dial_id))

        for turn_id, (turn_ref, turn_hyp) in enumerate(
                zip(dial_ref["turns"], dial_hyp["turns"])):
            if turn_ref["speaker"] != turn_hyp["speaker"]:
                raise ValueError(
                    "Speakers don't match in dialogue with id {}".format(
                        dial_id))

            # Skip system turns because metrics are only computed for user turns.
            if turn_ref["speaker"] != "USER":
                continue

            if turn_ref["utterance"] != turn_hyp["utterance"]:
                tf.logging.info("Ref utt: %s", turn_ref["utterance"])
                tf.logging.info("Hyp utt: %s", turn_hyp["utterance"])
                raise ValueError(
                    "Utterances don't match for dialogue with id {}".format(
                        dial_id))

            hyp_frames_by_service = {
                frame["service"]: frame
                for frame in turn_hyp["frames"]
            }

            # Calculate metrics for each frame in each user turn.
            for frame_ref in turn_ref["frames"]:
                service_name = frame_ref["service"]
                if service_name not in hyp_frames_by_service:
                    raise ValueError(
                        "Frame for service {} not found in dialogue with id {}"
                        .format(service_name, dial_id))
                service = service_schemas[service_name]
                frame_hyp = hyp_frames_by_service[service_name]

                active_intent_acc = metrics.get_active_intent_accuracy(
                    frame_ref, frame_hyp)
                slot_tagging_f1_scores = metrics.get_slot_tagging_f1(
                    frame_ref, frame_hyp, turn_ref["utterance"], service)
                requested_slots_f1_scores = metrics.get_requested_slots_f1(
                    frame_ref, frame_hyp)
                goal_accuracy_dict = metrics.get_average_and_joint_goal_accuracy(
                    frame_ref, frame_hyp, service)

                frame_metric = {
                    metrics.ACTIVE_INTENT_ACCURACY: active_intent_acc,
                    metrics.REQUESTED_SLOTS_F1: requested_slots_f1_scores.f1,
                    metrics.REQUESTED_SLOTS_PRECISION:
                        requested_slots_f1_scores.precision,
                    metrics.REQUESTED_SLOTS_RECALL:
                        requested_slots_f1_scores.recall,
                }
                if slot_tagging_f1_scores is not None:
                    frame_metric[metrics.SLOT_TAGGING_F1] = (
                        slot_tagging_f1_scores.f1)
                    frame_metric[metrics.SLOT_TAGGING_PRECISION] = (
                        slot_tagging_f1_scores.precision)
                    frame_metric[metrics.SLOT_TAGGING_RECALL] = (
                        slot_tagging_f1_scores.recall)
                frame_metric.update(goal_accuracy_dict)

                frame_id = "{:s}-{:03d}-{:s}".format(dial_id, turn_id,
                                                     frame_hyp["service"])
                per_frame_metric[frame_id] = frame_metric
                # Add the frame-level metric result back to dialogues.
                frame_hyp["metrics"] = frame_metric

                # Get the domain name of the service.
                domain_name = frame_hyp["service"].split("_")[0]
                domain_keys = [ALL_SERVICES, frame_hyp["service"], domain_name]
                if frame_hyp["service"] in in_domain_services:
                    domain_keys.append(SEEN_SERVICES)
                else:
                    domain_keys.append(UNSEEN_SERVICES)
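                # For a hypothetical frame from service "Restaurants_1" that
                # was seen in training, domain_keys would be
                # [ALL_SERVICES, "Restaurants_1", "Restaurants", SEEN_SERVICES].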
                for domain_key in domain_keys:
                    for metric_key, metric_value in frame_metric.items():
                        if metric_value != metrics.NAN_VAL:
                            metric_collections[domain_key][metric_key].append(
                                metric_value)

    all_metric_aggregate = {}
    for domain_key, domain_metric_vals in metric_collections.items():
        domain_metric_aggregate = {}
        for metric_key, value_list in domain_metric_vals.items():
            if value_list:
                # Metrics are macro-averaged across all frames.
                domain_metric_aggregate[metric_key] = float(
                    np.mean(value_list))
            else:
                domain_metric_aggregate[metric_key] = metrics.NAN_VAL
        all_metric_aggregate[domain_key] = domain_metric_aggregate
    return all_metric_aggregate, per_frame_metric
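
A hypothetical usage sketch follows. The file names, the JSON layout (a list of dialogues keyed by "dialogue_id", a schema list keyed by "service_name"), and the chosen in-domain services are all assumptions for illustration; the real evaluation script wires these up from command-line flags.

import json

# Hypothetical file names and layout; adjust to the actual dataset on disk.
with open("dialogues_ref.json") as f:
    dataset_ref = {d["dialogue_id"]: d for d in json.load(f)}
with open("dialogues_hyp.json") as f:
    dataset_hyp = {d["dialogue_id"]: d for d in json.load(f)}
with open("schema.json") as f:
    service_schemas = {s["service_name"]: s for s in json.load(f)}
in_domain_services = {"Restaurants_1", "Weather_1"}  # assumed training services

all_metric_aggregate, per_frame_metric = get_metrics(
    dataset_ref, dataset_hyp, service_schemas, in_domain_services)
# Macro-averaged metrics over every frame in the corpus:
for metric_key, value in all_metric_aggregate[ALL_SERVICES].items():
    print(metric_key, value)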