def _get_metric_object(self, metric, y_t, y_p): """Converts user-supplied metric to a `Metric` object. Args: metric: A string, function, or `Metric` object. y_t: Sample of label. y_p: Sample of output. Returns: A `Metric` object. """ if metric is None: return None # Ok to have no metric for an output. # Convenience feature for selecting b/t binary, categorical, # and sparse categorical. if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']: metric_obj = metrics_mod.get(metric) else: y_t_rank = len(y_t.shape.as_list()) y_p_rank = len(y_p.shape.as_list()) y_t_last_dim = y_t.shape.as_list()[-1] y_p_last_dim = y_p.shape.as_list()[-1] is_binary = y_p_last_dim == 1 is_sparse_categorical = (y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1) if metric in ['accuracy', 'acc']: if is_binary: metric_obj = metrics_mod.binary_accuracy elif is_sparse_categorical: metric_obj = metrics_mod.sparse_categorical_accuracy else: metric_obj = metrics_mod.categorical_accuracy else: if is_binary: metric_obj = metrics_mod.binary_crossentropy elif is_sparse_categorical: metric_obj = metrics_mod.sparse_categorical_crossentropy else: metric_obj = metrics_mod.categorical_crossentropy if isinstance(metric_obj, losses_mod.Loss): metric_obj._allow_sum_over_batch_size = True # pylint: disable=protected-access if not isinstance(metric_obj, metrics_mod.Metric): if isinstance(metric, six.string_types): metric_name = metric else: metric_name = get_custom_object_name(metric) if metric_name is None: raise ValueError( 'Metric should be a callable, found: {}'.format( metric)) metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name) return metric_obj
def test_metric_wrappers_autograph(self):
    """Both metric wrappers evaluate a control-flow-heavy fn in a tf.function.

    `count_fn` iterates with Python `for` loops over tensor lengths and
    branches on a tensor-valued `and`, so calling the wrappers from inside a
    `tf.function` presumably relies on AutoGraph converting it. For the
    fixed labels/predictions below, 10 positions hold equal, positive values,
    and both the mean and the sum-over-batch-size wrapper must report 10.
    """

    def count_fn(labels, preds):
        # Counts (row, col) positions where labels == preds and labels > 0.
        matches = tf.constant(0.0)
        for row in range(len(labels)):
            for col in range(len(labels[row])):
                if (tf.equal(labels[row][col], preds[row][col]) and
                        labels[row][col] > 0):
                    matches += 1.0
        return matches

    wrappers = (
        metrics.MeanMetricWrapper(count_fn),
        metrics.SumOverBatchSizeMetricWrapper(count_fn),
    )
    for wrapper in wrappers:
        self.evaluate(tf.compat.v1.variables_initializer(wrapper.variables))

    y_true = tf.constant([
        [0, 0, 0, 1, 0],
        [0, 0, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 1],
    ])
    y_pred = tf.constant([
        [0, 0, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [1, 1, 1, 1, 1],
    ])

    @tf.function
    def call_in_graph(metric_obj, labels, preds):
        return metric_obj(labels, preds)

    for wrapper in wrappers:
        outcome = call_in_graph(wrapper, y_true, y_pred)
        self.assertAllClose(self.evaluate(outcome), 10, 1e-2)
def _get_metric_object(self, metric, y_t, y_p): """Converts user-supplied metric to a `Metric` object. Args: metric: A string, function, or `Metric` object. y_t: Sample of label. y_p: Sample of output. Returns: A `Metric` object. """ if metric is None: return None # Ok to have no metric for an output. # Convenience feature for selecting b/t binary, categorical, # and sparse categorical. if str(metric).lower() not in [ "accuracy", "acc", "crossentropy", "ce" ]: metric_obj = metrics_mod.get(metric) else: y_t_rank = len(y_t.shape.as_list()) y_p_rank = len(y_p.shape.as_list()) y_t_last_dim = y_t.shape.as_list()[-1] y_p_last_dim = y_p.shape.as_list()[-1] is_binary = y_p_last_dim == 1 is_sparse_categorical = (y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1) if str(metric).lower() in ["accuracy", "acc"]: if is_binary: metric_obj = metrics_mod.binary_accuracy elif is_sparse_categorical: metric_obj = metrics_mod.sparse_categorical_accuracy else: metric_obj = metrics_mod.categorical_accuracy else: if is_binary: metric_obj = metrics_mod.binary_crossentropy elif is_sparse_categorical: metric_obj = metrics_mod.sparse_categorical_crossentropy else: metric_obj = metrics_mod.categorical_crossentropy if isinstance(metric_obj, losses_mod.Loss): metric_obj._allow_sum_over_batch_size = True if not isinstance(metric_obj, metrics_mod.Metric): if isinstance(metric, str): metric_name = metric else: metric_name = get_custom_object_name(metric) if metric_name is None: raise ValueError( f"Metric should be a callable, received: {metric}") metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name) return metric_obj