def get_base_metric_name(metric, weighted=False):
  """Builds the display name used for a metric.

  The accuracy and crossentropy shorthands are mapped to fixed short names
  ('acc' / 'ce'). Any other identifier is resolved through
  `metrics_module.get`. The resolved object's `name` attribute is used
  when it has one; otherwise its `__name__` is used.

  Arguments:
      metric: Metric function name or reference.
      weighted: Boolean indicating if the metric for which we are adding
          names is weighted; if True the name gets a 'weighted_' prefix.

  Returns:
      a metric name.
  """
  prefix = 'weighted_' if weighted else ''

  # Shorthand string identifiers map to fixed short names without any lookup.
  if metric in ('accuracy', 'acc'):
    return prefix + 'acc'
  if metric in ('crossentropy', 'ce'):
    return prefix + 'ce'

  # Anything else is resolved to a metric object/function first.
  resolved = metrics_module.get(metric)
  base_name = resolved.name if hasattr(resolved, 'name') else resolved.__name__
  return prefix + base_name
def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
  """Resolves a metric identifier to a `(metric_name, metric_fn)` pair.

  The strings 'accuracy' and 'acc' are special-cased because the concrete
  accuracy function depends on the output and the loss:
    * binary accuracy if the output's last dimension is 1, or if the loss
      is `losses.binary_crossentropy`;
    * sparse categorical accuracy if the loss is
      `losses.sparse_categorical_crossentropy`;
    * categorical accuracy otherwise.
  In all of these cases the returned name is 'acc'.

  Any other identifier is resolved through `metrics_module.get`. The name
  is taken from the resolved object's `name` attribute when that is a
  string (consistent with `get_base_metric_name`). Otherwise it comes from
  `__name__`.

  Arguments:
      metric: Metric identifier (string name or reference).
      internal_output_shapes: Shape of the output the metric applies to.
          Only its last entry is read, and only for accuracy. May be None,
          in which case the choice of accuracy function is made from
          `loss_func` alone.
      loss_func: Loss function used for that output. Compared for equality
          against the binary / sparse categorical crossentropy losses.

  Returns:
      A tuple `(metric_name, metric_fn)`.
  """
  if metric in ('accuracy', 'acc'):
    output_shape = internal_output_shapes
    # The default for `internal_output_shapes` is None; guard before
    # indexing so the loss-based selection below still applies.
    single_unit_output = output_shape is not None and output_shape[-1] == 1
    if single_unit_output or loss_func == losses.binary_crossentropy:
      # case: binary accuracy
      acc_fn = metrics_module.binary_accuracy
    elif loss_func == losses.sparse_categorical_crossentropy:
      # case: categorical accuracy with sparse targets
      acc_fn = metrics_module.sparse_categorical_accuracy
    else:
      acc_fn = metrics_module.categorical_accuracy
    return 'acc', acc_fn

  metric_fn = metrics_module.get(metric)
  # Metric objects may expose `name` instead of `__name__`. Only accept a
  # string so a non-string `name` attribute is never used as the name.
  metric_name = getattr(metric_fn, 'name', None)
  if not isinstance(metric_name, str):
    metric_name = metric_fn.__name__
  return metric_name, metric_fn