Пример #1
0
    def _get_metric_object(self, metric, y_t, y_p):
        """Turns a user-specified metric into a `Metric` instance.

        The aliases 'accuracy'/'acc' and 'crossentropy'/'ce' are resolved to
        their binary, sparse categorical, or categorical variant by inspecting
        the static shapes of `y_t` and `y_p`; any other value is handed to
        `metrics_mod.get`. Results that are not already a `Metric` are wrapped
        in a `metrics_mod.MeanMetricWrapper`.

        Args:
          metric: A string, function, or `Metric` object (or None).
          y_t: Sample of label.
          y_p: Sample of output.

        Returns:
          A `Metric` object, or None when `metric` is None.

        Raises:
          ValueError: If the resolved metric is not a `Metric`, `metric` is not
            a string, and `get_custom_object_name` yields no name for it.
        """
        # An output may legitimately have no metric attached.
        if metric is None:
            return None

        if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
            # Convenience aliases: choose binary / sparse categorical /
            # categorical from the label and prediction shapes.
            label_dims = y_t.shape.as_list()
            pred_dims = y_p.shape.as_list()
            label_last, pred_last = label_dims[-1], pred_dims[-1]

            binary = pred_last == 1
            # Labels of lower rank, or a single label column against a wider
            # prediction, are treated as integer class ids.
            sparse = (len(label_dims) < len(pred_dims)
                      or label_last == 1 and pred_last > 1)
            want_accuracy = metric in ('accuracy', 'acc')

            if binary:
                metric_obj = (metrics_mod.binary_accuracy if want_accuracy
                              else metrics_mod.binary_crossentropy)
            elif sparse:
                metric_obj = (metrics_mod.sparse_categorical_accuracy
                              if want_accuracy
                              else metrics_mod.sparse_categorical_crossentropy)
            else:
                metric_obj = (metrics_mod.categorical_accuracy if want_accuracy
                              else metrics_mod.categorical_crossentropy)
        else:
            metric_obj = metrics_mod.get(metric)

        if isinstance(metric_obj, losses_mod.Loss):
            metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

        # Already a stateful Metric: use it as-is.
        if isinstance(metric_obj, metrics_mod.Metric):
            return metric_obj

        # Plain callables are wrapped; they need a display name.
        if isinstance(metric, six.string_types):
            metric_name = metric
        else:
            metric_name = get_custom_object_name(metric)
            if metric_name is None:
                raise ValueError(
                    'Metric should be a callable, found: {}'.format(metric))
        return metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)
Пример #2
0
    def test_metric_wrappers_autograph(self):
        """Checks metric wrappers around a fn written with Python control flow.

        `metric_fn` loops and branches over tensor elements using plain Python
        `for` / `if` / `and`. It is wrapped by both `MeanMetricWrapper` and
        `SumOverBatchSizeMetricWrapper`, and each wrapper is invoked inside a
        `tf.function`. Per the test name, this exercises AutoGraph conversion
        of the wrapped function (presumably performed by the wrappers —
        confirm against the metrics module).
        """
        def metric_fn(y_true, y_pred):
            # Counts positions where label == prediction and the label is
            # positive, returning one scalar for the whole batch. The Python
            # loops and `and` over tensors are deliberate (this is what the
            # AutoGraph test covers) — do not vectorize.
            x = tf.constant(0.0)
            for i in range(len(y_true)):
                for j in range(len(y_true[i])):
                    if (tf.equal(y_true[i][j], y_pred[i][j])
                            and y_true[i][j] > 0):
                        x += 1.0
            return x

        mean_metric = metrics.MeanMetricWrapper(metric_fn)
        sum_metric = metrics.SumOverBatchSizeMetricWrapper(metric_fn)
        # Initialize both wrappers' state variables via the TF1-compat API.
        self.evaluate(tf.compat.v1.variables_initializer(
            mean_metric.variables))
        self.evaluate(tf.compat.v1.variables_initializer(sum_metric.variables))

        # Matching positive entries per row: 1 + 3 + 2 + 4 = 10.
        y_true = tf.constant([[0, 0, 0, 1, 0], [0, 0, 1, 1, 1],
                              [1, 1, 1, 1, 0], [1, 1, 1, 0, 1]])
        y_pred = tf.constant([[0, 0, 1, 1, 0], [1, 1, 1, 1, 1],
                              [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]])

        @tf.function
        def tf_functioned_metric_fn(metric, y_true, y_pred):
            # Calls the metric from inside a traced tf.function.
            return metric(y_true, y_pred)

        # metric_fn yields a single scalar (10) per call, so both the mean and
        # sum-over-batch-size wrappers are expected to report 10 (presumably
        # because reducing one scalar leaves it unchanged); tolerance 1e-2.
        metric_result = tf_functioned_metric_fn(mean_metric, y_true, y_pred)
        self.assertAllClose(self.evaluate(metric_result), 10, 1e-2)
        metric_result = tf_functioned_metric_fn(sum_metric, y_true, y_pred)
        self.assertAllClose(self.evaluate(metric_result), 10, 1e-2)
Пример #3
0
    def _get_metric_object(self, metric, y_t, y_p):
        """Converts user-supplied metric to a `Metric` object.

        The string aliases "accuracy"/"acc" and "crossentropy"/"ce" (matched
        case-insensitively) are resolved to their binary, sparse categorical,
        or categorical variant from the static shapes of `y_t` and `y_p`.
        Every other value — including non-string callables and `Metric`
        objects, whatever their `str()` looks like — is passed to
        `metrics_mod.get`. Results that are not already a `Metric` are wrapped
        in a `metrics_mod.MeanMetricWrapper`.

        Args:
          metric: A string, function, or `Metric` object (or None).
          y_t: Sample of label.
          y_p: Sample of output.

        Returns:
          A `Metric` object, or None if `metric` is None.

        Raises:
          ValueError: If the resolved metric is not a `Metric`, `metric` is not
            a string, and `get_custom_object_name` yields no name for it.
        """
        if metric is None:
            return None  # Ok to have no metric for an output.

        # Only real strings may select a convenience alias. Stringifying
        # arbitrary objects would let a user's callable/Metric whose `str()`
        # happens to be e.g. "acc" be silently replaced by a built-in metric.
        alias = metric.lower() if isinstance(metric, str) else None

        # Convenience feature for selecting b/t binary, categorical,
        # and sparse categorical.
        if alias not in ("accuracy", "acc", "crossentropy", "ce"):
            metric_obj = metrics_mod.get(metric)
        else:
            y_t_shape = y_t.shape.as_list()
            y_p_shape = y_p.shape.as_list()
            y_t_rank = len(y_t_shape)
            y_p_rank = len(y_p_shape)
            y_t_last_dim = y_t_shape[-1]
            y_p_last_dim = y_p_shape[-1]

            is_binary = y_p_last_dim == 1
            # Lower-rank labels, or a single label column against a wider
            # prediction, are treated as integer class ids.
            is_sparse_categorical = (y_t_rank < y_p_rank
                                     or y_t_last_dim == 1 and y_p_last_dim > 1)

            if alias in ("accuracy", "acc"):
                if is_binary:
                    metric_obj = metrics_mod.binary_accuracy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_accuracy
                else:
                    metric_obj = metrics_mod.categorical_accuracy
            else:
                if is_binary:
                    metric_obj = metrics_mod.binary_crossentropy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_crossentropy
                else:
                    metric_obj = metrics_mod.categorical_crossentropy

        if isinstance(metric_obj, losses_mod.Loss):
            # NOTE(review): presumably allows the loss's sum-over-batch-size
            # reduction when a Loss is used as a metric — confirm in losses_mod.
            metric_obj._allow_sum_over_batch_size = True

        if not isinstance(metric_obj, metrics_mod.Metric):
            # Plain callables are wrapped; they need a display name.
            if isinstance(metric, str):
                metric_name = metric
            else:
                metric_name = get_custom_object_name(metric)
                if metric_name is None:
                    raise ValueError(
                        f"Metric should be a callable, received: {metric}")

            metric_obj = metrics_mod.MeanMetricWrapper(metric_obj,
                                                       name=metric_name)

        return metric_obj