def _preprocess_fn(data):
  """The preprocessing function that is returned.

  Args:
    data: A dictionary mapping feature names to tensors.

  Returns:
    The dictionary after every op in `pp_pipeline` has been applied in
    order and (if `remove_tpu_dtypes` is set) TPU-incompatible entries
    have been removed.

  Raises:
    ValueError: If `data` is not a dictionary.
  """
  # Validate input
  if not isinstance(data, dict):
    raise ValueError("Argument `data` must be a dictionary, "
                     "not %s" % str(type(data)))

  # Apply all the individual steps in sequence.
  tf.logging.info("Data before pre-processing:\n%s", data)
  for lookup_string in pp_pipeline.split("|"):
    # These calls are purely positional, so no need for kwargs.
    name, args, _ = registry.parse_name_and_kwargs(lookup_string)
    cls = ops.get(name)
    data = cls.apply(*args)(data)

  if remove_tpu_dtypes:
    # Remove data that are TPU-incompatible (e.g. filename of type tf.string).
    for key in list(data.keys()):
      if data[key].dtype not in TPU_SUPPORTED_DTYPES:
        # NOTE(review): the original message wrapped its string literals as
        # `"... not in " " the supported ..."`, emitting a double space; fixed.
        tf.logging.warning(
            "Removing key '{}' from data dict because its dtype {} is not in "
            "the supported dtypes: {}".format(key, data[key].dtype,
                                              TPU_SUPPORTED_DTYPES))
        data = get_delete_field(key)(data)
  tf.logging.info("Data after pre-processing:\n%s", data)
  return data
def add_measurement(self, dataset_spec, metric_name, metric_results):
  """Records one metric measurement for an ImageNet-family dataset.

  Args:
    dataset_spec: Dataset spec string, e.g.
      "imagenet_c(corruption_type='gaussian_blur',severity=3)".
    metric_name: Name of the metric, e.g. "accuracy" or "timing".
    metric_results: Dict of results. For "timing" the value is read from
      the "mean" key; every other metric is expected to hold exactly one
      entry keyed by the metric name.
  """

  def _scalar_value(results):
    # "timing" reports several statistics; we keep the mean. All remaining
    # metrics are a dictionary with a single value, with the key equal to
    # the metric name.
    if metric_name == "timing":
      return results["mean"]
    value, = list(results.values())
    return value

  dataset_name, _, dataset_kwargs = registry.parse_name_and_kwargs(
      dataset_spec)
  if dataset_name == "imagenet_c":
    value = _scalar_value(metric_results)
    self._corruption_metrics[f"imagenet_c/{metric_name}"].append(value)
    corruption_type = dataset_kwargs["corruption_type"]
    if metric_name == "accuracy":
      self._accuracy_per_corruption[corruption_type].append(value)
  elif dataset_name == "imagenet_v2":
    variant = dataset_kwargs["variant"]
    if variant == "MATCHED_FREQUENCY":
      self._results[f"imagenet_v2/{metric_name}"] = _scalar_value(
          metric_results)
    else:
      # Other v2 variants are deliberately skipped; `metric_results` is not
      # read at all in this case (matching the original control flow).
      logging.info("Ignoring v2 variant %r", variant)
  else:
    self._results[f"{dataset_name}/{metric_name}"] = _scalar_value(
        metric_results)
def _get_full_metric_key(self, dataset_name, metric_name, metric_key):
  """Builds the "<dataset>/<key>" result key for a diversity metric.

  When the metric was configured with `normalize_disagreement=True`, the
  "disagreement" entry is reported as "normalized_disagreement".
  """
  _, _, kwargs = registry.parse_name_and_kwargs(metric_name)
  normalize = kwargs["normalize_disagreement"]
  prefix = "normalized_" if (metric_key == "disagreement" and normalize) else ""
  return f"{dataset_name}/{prefix}{metric_key}"
def add_measurement(self, dataset_spec, metric_name, metric_results):
  """Records measurements, special-casing the pairwise-diversity metrics.

  Non-diversity metrics are delegated to the parent class; diversity
  results are stored per key, with CIFAR-10-C entries accumulated per
  corruption.
  """
  diversity_metric_names = (
      "average_pairwise_diversity(normalize_disagreement=True)",
      "average_pairwise_diversity(normalize_disagreement=False)",
  )
  if metric_name not in diversity_metric_names:
    super().add_measurement(dataset_spec, metric_name, metric_results)
    return
  dataset_name, _, _ = registry.parse_name_and_kwargs(dataset_spec)
  for metric_key, metric_value in metric_results.items():
    full_key = self._get_full_metric_key(dataset_name, metric_name, metric_key)
    if dataset_name == "cifar10_c":
      self._corruption_metrics[full_key].append(metric_value)
    else:
      self._results[full_key] = metric_value
def add_measurement(self, dataset_spec, metric_name, metric_results):
  """Records one metric measurement for a CIFAR-family dataset.

  Args:
    dataset_spec: Dataset spec string, e.g. "cifar10_c(...)".
    metric_name: Name of the metric, e.g. "accuracy" or "timing".
    metric_results: Dict of results. For "timing" the value is read from
      the "mean" key; every other metric is expected to hold exactly one
      entry keyed by the metric name.
  """

  def _scalar_value(results):
    # "timing" reports several statistics; we keep the mean. All remaining
    # metrics are a dictionary with a single value, with the key equal to
    # the metric name.
    if metric_name == "timing":
      return results["mean"]
    value, = list(results.values())
    return value

  dataset_name, _, _ = registry.parse_name_and_kwargs(dataset_spec)
  # Both branches extract the value identically, so it is hoisted here.
  value = _scalar_value(metric_results)
  if dataset_name == "cifar10_c":
    self._corruption_metrics[f"cifar10_c/{metric_name}"].append(value)
  else:
    self._results[f"{dataset_name}/{metric_name}"] = value
def _preprocess_fn(data):
  """The preprocessing function that is returned."""
  # Only dict-shaped inputs are supported.
  if not isinstance(data, dict):
    raise ValueError(
        "Argument `data` must be a dictionary, not %s" % str(type(data)))

  tf.logging.info("Data before pre-processing:\n%s", data)
  # Run each "name(args)" stage of the pipeline in order; the calls are
  # purely positional, so any parsed kwargs are discarded.
  for stage_spec in pp_pipeline.split("|"):
    op_name, op_args, _ = registry.parse_name_and_kwargs(stage_spec)
    data = ops.get(op_name).apply(*op_args)(data)

  if remove_tpu_dtypes:
    data = keep_only_tpu_types(data)
  tf.logging.info("Data after pre-processing:\n%s", data)
  return data
def add_measurement(self, dataset_spec, metric_name, metric_results):
  """Records measurements, special-casing the pairwise-diversity metrics.

  Non-diversity metrics are delegated to the parent class. For diversity
  metrics, ImageNet-V2 variants other than MATCHED_FREQUENCY are skipped,
  ImageNet-C results are accumulated per corruption, and everything else
  is stored as a single result.
  """
  diversity_metric_names = (
      "average_pairwise_diversity(normalize_disagreement=True)",
      "average_pairwise_diversity(normalize_disagreement=False)",
  )
  if metric_name not in diversity_metric_names:
    super().add_measurement(dataset_spec, metric_name, metric_results)
    return
  dataset_name, _, dataset_kwargs = registry.parse_name_and_kwargs(
      dataset_spec)
  if dataset_name == "imagenet_v2":
    if dataset_kwargs["variant"] != "MATCHED_FREQUENCY":
      return
  for metric_key, metric_value in metric_results.items():
    full_key = self._get_full_metric_key(dataset_name, metric_name, metric_key)
    if dataset_name == "imagenet_c":
      self._corruption_metrics[full_key].append(metric_value)
    else:
      self._results[full_key] = metric_value