コード例 #1
0
  def _preprocess_fn(data):
    """Applies the configured preprocessing pipeline to a feature dict.

    Each `|`-separated entry of `pp_pipeline` is parsed with
    `registry.parse_name_and_kwargs`, resolved through `ops.get`, and applied
    to `data` in order. When `remove_tpu_dtypes` is set, entries whose dtype
    is not in `TPU_SUPPORTED_DTYPES` are dropped afterwards.

    Args:
      data: Dictionary of features to preprocess.

    Returns:
      The preprocessed dictionary.

    Raises:
      ValueError: If `data` is not a dictionary.
    """
    # Validate input. The exception type is left as ValueError so existing
    # callers that catch it keep working.
    if not isinstance(data, dict):
      raise ValueError("Argument `data` must be a dictionary, "
                       "not %s" % str(type(data)))

    # Apply all the individual steps in sequence.
    tf.logging.info("Data before pre-processing:\n%s", data)
    for lookup_string in pp_pipeline.split("|"):
      # Only positional arguments are forwarded; parsed kwargs are ignored.
      name, args, _ = registry.parse_name_and_kwargs(lookup_string)
      cls = ops.get(name)
      data = cls.apply(*args)(data)

    if remove_tpu_dtypes:
      # Remove data that are TPU-incompatible (e.g. filename of type tf.string).
      # Iterate over a snapshot of the keys because `data` is rebound inside
      # the loop.
      for key in list(data.keys()):
        dtype = data[key].dtype
        if dtype not in TPU_SUPPORTED_DTYPES:
          # Lazy %-style args; the message previously had a double space.
          tf.logging.warning(
              "Removing key '%s' from data dict because its dtype %s is not "
              "in the supported dtypes: %s", key, dtype, TPU_SUPPORTED_DTYPES)
          data = get_delete_field(key)(data)
    tf.logging.info("Data after pre-processing:\n%s", data)
    return data
コード例 #2
0
 def add_measurement(self, dataset_spec, metric_name, metric_results):
   """Records one metric result for the dataset described by `dataset_spec`.

   - "imagenet_c": the value is appended to
     `self._corruption_metrics["imagenet_c/<metric_name>"]`. Accuracy values
     are also appended to `self._accuracy_per_corruption[corruption_type]`.
   - "imagenet_v2": the value is stored only for the "MATCHED_FREQUENCY"
     variant; other variants are logged and ignored.
   - any other dataset: stored in `self._results["<dataset>/<metric_name>"]`.

   Args:
     dataset_spec: Dataset spec string parsed by
       `registry.parse_name_and_kwargs`.
     metric_name: Name of the metric that produced `metric_results`.
     metric_results: Mapping of results. For "timing" the "mean" entry is
       used; every other metric must contain exactly one value.

   Raises:
     ValueError: If a non-timing metric does not contain exactly one value.
   """
   dataset_name, _, dataset_kwargs = registry.parse_name_and_kwargs(
       dataset_spec)

   def _scalar_value():
     # Called lazily so that ignored imagenet_v2 variants are never inspected.
     if metric_name == "timing":
       return metric_results["mean"]
     # All remaining metrics are a dictionary with a single value, with the
     # key equal to the metric name.
     values = list(metric_results.values())
     if len(values) != 1:
       raise ValueError(
           f"Expected exactly one value for metric {metric_name!r}, got "
           f"{len(values)}: {metric_results!r}")
     return values[0]

   if dataset_name == "imagenet_c":
     value = _scalar_value()
     self._corruption_metrics[f"imagenet_c/{metric_name}"].append(value)
     # Looked up for every imagenet_c metric, not only accuracy.
     corruption_type = dataset_kwargs["corruption_type"]
     if metric_name == "accuracy":
       self._accuracy_per_corruption[corruption_type].append(value)
   elif dataset_name == "imagenet_v2":
     variant = dataset_kwargs["variant"]
     if variant == "MATCHED_FREQUENCY":
       self._results[f"imagenet_v2/{metric_name}"] = _scalar_value()
     else:
       logging.info("Ignoring v2 variant %r", variant)
   else:
     self._results[f"{dataset_name}/{metric_name}"] = _scalar_value()
コード例 #3
0
 def _get_full_metric_key(self, dataset_name, metric_name, metric_key):
     """Builds the results key for one entry of a diversity metric.

     If `metric_name` was parsed with a truthy `normalize_disagreement`
     kwarg, the "disagreement" entry is keyed as
     "<dataset_name>/normalized_disagreement". Every other entry is keyed as
     "<dataset_name>/<metric_key>".
     """
     _, _, parsed_kwargs = registry.parse_name_and_kwargs(metric_name)
     # Read unconditionally: a missing kwarg fails for every metric_key.
     normalize = parsed_kwargs["normalize_disagreement"]
     if not (metric_key == "disagreement" and normalize):
         return f"{dataset_name}/{metric_key}"
     return f"{dataset_name}/normalized_{metric_key}"
コード例 #4
0
 def add_measurement(self, dataset_spec, metric_name, metric_results):
     """Records a measurement, with special handling for pairwise diversity.

     Metrics other than the two `average_pairwise_diversity` variants are
     delegated to the parent class. For those two variants, every entry of
     `metric_results` is stored under the key from
     `self._get_full_metric_key`. Entries are appended to
     `self._corruption_metrics` for "cifar10_c" and assigned into
     `self._results` for any other dataset.
     """
     diversity_metric_names = (
         "average_pairwise_diversity(normalize_disagreement=True)",
         "average_pairwise_diversity(normalize_disagreement=False)",
     )
     if metric_name in diversity_metric_names:
         dataset_name, _, _ = registry.parse_name_and_kwargs(dataset_spec)
         is_corrupted = dataset_name == "cifar10_c"
         for metric_key, metric_value in metric_results.items():
             full_key = self._get_full_metric_key(dataset_name, metric_name,
                                                  metric_key)
             if is_corrupted:
                 self._corruption_metrics[full_key].append(metric_value)
             else:
                 self._results[full_key] = metric_value
         return
     super().add_measurement(dataset_spec, metric_name, metric_results)
コード例 #5
0
 def add_measurement(self, dataset_spec, metric_name, metric_results):
     """Records one metric result for the dataset described by `dataset_spec`.

     Results on "cifar10_c" are appended to
     `self._corruption_metrics["cifar10_c/<metric_name>"]`. Results on all
     other datasets are stored in `self._results["<dataset>/<metric_name>"]`.

     Args:
       dataset_spec: Dataset spec string parsed by
         `registry.parse_name_and_kwargs`.
       metric_name: Name of the metric that produced `metric_results`.
       metric_results: Mapping of results. For "timing" the "mean" entry is
         used; every other metric must contain exactly one value.

     Raises:
       ValueError: If a non-timing metric does not contain exactly one value.
     """
     dataset_name, _, _ = registry.parse_name_and_kwargs(dataset_spec)

     # Value extraction is identical for both destinations, so do it once.
     if metric_name == "timing":
         value = metric_results["mean"]
     else:
         # All remaining metrics are a dictionary with a single value, with
         # the key equal to the metric name.
         values = list(metric_results.values())
         if len(values) != 1:
             raise ValueError(
                 f"Expected exactly one value for metric {metric_name!r}, "
                 f"got {len(values)}: {metric_results!r}")
         value = values[0]

     if dataset_name == "cifar10_c":
         self._corruption_metrics[f"cifar10_c/{metric_name}"].append(value)
     else:
         self._results[f"{dataset_name}/{metric_name}"] = value
コード例 #6
0
  def _preprocess_fn(data):
    """Runs every `|`-separated op of `pp_pipeline` over `data`, in order.

    Args:
      data: Feature dictionary to transform.

    Returns:
      The transformed dictionary. When `remove_tpu_dtypes` is set, the result
      is passed through `keep_only_tpu_types` before being returned.

    Raises:
      ValueError: If `data` is not a dict.
    """
    # Guard clause: only dictionaries are accepted.
    if not isinstance(data, dict):
      message = ("Argument `data` must be a dictionary, not %s"
                 % str(type(data)))
      raise ValueError(message)

    tf.logging.info("Data before pre-processing:\n%s", data)
    for step_spec in pp_pipeline.split("|"):
      # Only positional arguments are forwarded to the op; kwargs are dropped.
      op_name, op_args, _ = registry.parse_name_and_kwargs(step_spec)
      transform = ops.get(op_name).apply(*op_args)
      data = transform(data)

    if remove_tpu_dtypes:
      data = keep_only_tpu_types(data)
    tf.logging.info("Data after pre-processing:\n%s", data)
    return data
コード例 #7
0
    def add_measurement(self, dataset_spec, metric_name, metric_results):
        """Records a measurement, handling pairwise-diversity metrics locally.

        Metrics other than the two `average_pairwise_diversity` variants are
        forwarded to the parent implementation. Diversity results on
        "imagenet_v2" are dropped unless the variant is "MATCHED_FREQUENCY".
        Otherwise, each entry of `metric_results` is stored under the key
        returned by `self._get_full_metric_key`. Entries are appended to
        `self._corruption_metrics` for "imagenet_c" and assigned into
        `self._results` for any other dataset.
        """
        diversity_names = (
            "average_pairwise_diversity(normalize_disagreement=True)",
            "average_pairwise_diversity(normalize_disagreement=False)",
        )
        if metric_name not in diversity_names:
            super().add_measurement(dataset_spec, metric_name, metric_results)
            return

        dataset_name, _, dataset_kwargs = registry.parse_name_and_kwargs(
            dataset_spec)

        # Only the MATCHED_FREQUENCY variant of imagenet_v2 is recorded.
        if (dataset_name == "imagenet_v2"
                and dataset_kwargs["variant"] != "MATCHED_FREQUENCY"):
            return

        is_corruption_dataset = dataset_name == "imagenet_c"
        for metric_key, metric_value in metric_results.items():
            full_key = self._get_full_metric_key(dataset_name, metric_name,
                                                 metric_key)
            if is_corruption_dataset:
                self._corruption_metrics[full_key].append(metric_value)
            else:
                self._results[full_key] = metric_value