def get_metric_name_value_pairs(metric: Metric, default_name: str, reset: bool = False) -> Iterable[Tuple[str, float]]:
    """
    Yield the metric's value(s) as ``(name, value)`` pairs.

    The raw value from `Metric.get_metric` may be a scalar, an iterable, or a
    mapping (whose values may themselves be scalars or iterables).  Iterables
    are flattened with an ``_{index}`` suffix appended to the name; scalars are
    yielded under their name as-is.  Non-mapping values use `default_name`.
    """

    def _flatten(name, raw):
        # Iterables fan out to one pair per element; scalars yield a single pair.
        if isinstance(raw, collections.abc.Iterable):
            for idx, element in enumerate(raw):
                yield f"{name}_{idx}", element
        else:
            yield name, raw

    value = metric.get_metric(reset)
    if isinstance(value, collections.abc.Mapping):
        for sub_name, sub_value in value.items():
            yield from _flatten(sub_name, sub_value)
    else:
        yield from _flatten(default_name, value)
def global_distributed_metric(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: Metric,
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
    number_of_runs: int = 1,
):
    """
    Feed `metric` the inputs meant for the process with rank `global_rank`,
    then check its aggregated value(s) against `desired_values`.

    # Parameters

    global_rank : `int`
        Rank of this process; selects the per-rank slice of `metric_kwargs`.
    world_size : `int`
        Total number of processes (unused here; part of the distributed-runner
        call signature).
    gpu_id : `Union[int, torch.device]`
        Device for this process (unused here; part of the distributed-runner
        call signature).
    metric : `Metric`
        The metric under test; called `number_of_runs` times before reading.
    metric_kwargs : `Dict[str, List[Any]]`
        For each metric argument, a list of per-rank values.
    desired_values : `Dict[str, Any]`
        Expected metric value(s).
    exact : `Union[bool, Tuple[float, float]]`, optional (default = `True`)
        `True` for exact equality, `False` for default loose tolerances, or an
        explicit `(rtol, atol)` pair.
    number_of_runs : `int`, optional (default = `1`)
        How many times to call the metric before reading it.
    """
    # Use the arguments meant for the process with rank `global_rank`.
    kwargs = {argname: values[global_rank] for argname, values in metric_kwargs.items()}

    for _ in range(number_of_runs):
        metric(**kwargs)

    metrics = metric.get_metric(False)
    # Wrap bare scalars so both sides are comparable dicts.  NOTE: isinstance
    # must target the builtin `dict` — using `typing.Dict` as an isinstance
    # argument is deprecated (Python 3.9+) and slated for removal.
    if not isinstance(metrics, dict) and not isinstance(desired_values, dict):
        metrics = {"metric_value": metrics}
        desired_values = {"metric_value": desired_values}

    # Resolve tolerances: exact=True means equality, exact=False means default
    # loose tolerances, and a tuple supplies (rtol, atol) directly.
    if isinstance(exact, bool):
        rtol, atol = (0.0, 0.0) if exact else (0.0001, 1e-05)
    else:
        rtol, atol = exact

    # Call `assert_metrics_values` to check if the metrics have the desired values.
    assert_metrics_values(metrics, desired_values, rtol, atol)  # type: ignore