  def test_one_dimensional(self):
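    # _filter_top_k is expected to keep the k largest entries along the last
    # axis and replace everything else with NEG_INF, as the assertions below
    # verify for k = 1, 2 and 3.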
    x = constant_op.constant([.3, .1, .2, -.5, 42.])
    top_1 = self.evaluate(metrics_utils._filter_top_k(x=x, k=1))
    top_2 = self.evaluate(metrics_utils._filter_top_k(x=x, k=2))
    top_3 = self.evaluate(metrics_utils._filter_top_k(x=x, k=3))

    self.assertAllClose(top_1, [
        metrics_utils.NEG_INF, metrics_utils.NEG_INF, metrics_utils.NEG_INF,
        metrics_utils.NEG_INF, 42.
    ])
    self.assertAllClose(top_2, [
        .3, metrics_utils.NEG_INF, metrics_utils.NEG_INF, metrics_utils.NEG_INF,
        42.
    ])
    self.assertAllClose(
        top_3, [.3, metrics_utils.NEG_INF, .2, metrics_utils.NEG_INF, 42.])

  def test_three_dimensional(self):
    x = constant_op.constant([[[.3, .1, .2], [-.3, -.2, -.1]],
                              [[5., .2, 42.], [-.3, -.6, -.99]]])
    top_2 = self.evaluate(metrics_utils._filter_top_k(x=x, k=2))

    self.assertAllClose(
        top_2,
        [[[.3, metrics_utils.NEG_INF, .2], [metrics_utils.NEG_INF, -.2, -.1]],
         [[5., metrics_utils.NEG_INF, 42.], [-.3, -.6, metrics_utils.NEG_INF]]])
    # Helper that appears to belong to a dynamic-shape test (see the sketch
    # below); it is not called within this excerpt.
    def _filter_top_k(x):
      # Routing the tensor through numpy_function loses the static shape.
      x = script_ops.numpy_function(_identity, (x,), dtypes.float32)

      return metrics_utils._filter_top_k(x=x, k=2)
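
  # A minimal sketch (not part of the original excerpt) of a dynamic-shape
  # test built around a helper like the one above. `_identity` is assumed to
  # be a simple pass-through used only to erase the static shape; the method
  # name and expected values are illustrative.
  def test_handles_dynamic_shape(self):

    def _identity(x):
      return x

    x = constant_op.constant([.3, .1, .2, -.5, 42.])
    # Routing through numpy_function erases the static shape.
    x = script_ops.numpy_function(_identity, (x,), dtypes.float32)
    top_2 = self.evaluate(metrics_utils._filter_top_k(x=x, k=2))
    self.assertAllClose(top_2, [
        .3, metrics_utils.NEG_INF, metrics_utils.NEG_INF,
        metrics_utils.NEG_INF, 42.
    ])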
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Cast inputs
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, dtype=y_pred.dtype)

        # Flatten ragged inputs (if any) to their flat values
        [y_pred, y_true], _ = (
            metrics_utils.ragged_assert_compatible_and_get_flat_values(
                [y_pred, y_true], sample_weight))

        # Get the number of thresholds
        if hasattr(self.thresholds, '__len__'):
            num_thresholds = len(self.thresholds)
        else:
            # A single scalar threshold.
            num_thresholds = 1

        # Check input values and adjust shapes
        with ops.control_dependencies([
                check_ops.assert_greater_equal(
                    y_pred,
                    tf.cast(0.0, dtype=y_pred.dtype),
                    message='predictions must be >= 0'),
                check_ops.assert_less_equal(y_pred,
                                            tf.cast(1.0, dtype=y_pred.dtype),
                                            message='predictions must be <= 1')
        ]):

            if sample_weight is None:
                y_pred, y_true = tf_losses_utils.squeeze_or_expand_dimensions(
                    y_pred, y_true)
            else:
                y_pred, y_true, sample_weight = (
                    tf_losses_utils.squeeze_or_expand_dimensions(
                        y_pred, y_true, sample_weight=sample_weight))

        # Check shape compatibility
        y_pred.shape.assert_is_compatible_with(y_true.shape)

        # Check that the last dimension of y_pred matches num_classes
        if self.average != 'micro':
            tf.debugging.assert_shapes(
                shapes=[(y_pred, (..., self.num_classes))],
                data=y_pred,
                summarize=10,
                message='num_classes must correspond to the prediction')

        # Filter top k
        if self.top_k is not None:
            y_pred = metrics_utils._filter_top_k(y_pred, self.top_k)

        # Select class id
        if self.class_id is not None:
            y_true = y_true[..., self.class_id]
            y_pred = y_pred[..., self.class_id]

        # Get prediction shape
        pred_shape = tf.shape(y_pred)
        num_predictions = pred_shape[0]

        # Number of labels (entries per example) in y_pred
        if y_pred.shape.ndims == 1:
            num_labels = 1
        else:
            num_labels = K.prod(pred_shape[1:], axis=0)

        # Flatten predictions and labels into a single row each
        predictions_extra_dim = tf.reshape(y_pred, [1, -1])
        labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

        # Tile the thresholds for every prediction
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

        thresh_tiled = tf.tile(
            tf.reshape(tf.constant(self.thresholds, dtype=y_pred.dtype),
                       thresh_pretile_shape), tf.stack(thresh_tiles))

        # Tile the predictions for every threshold
        preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

        # Compare predictions and threshold
        pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

        # Tile labels by number of thresholds
        label_is_pos = tf.tile(labels_extra_dim, data_tiles)
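        # `thresh_tiled`, `preds_tiled` and `label_is_pos` now all have shape
        # [num_thresholds, num_predictions * num_labels], so every
        # prediction/label pair is compared against every threshold at once.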

        # Set sample weights
        if sample_weight is not None:
            sample_weight = weights_broadcast_ops.broadcast_weights(
                tf.cast(sample_weight, dtype=y_pred.dtype), y_pred)
            weights_tiled = tf.tile(tf.reshape(sample_weight, thresh_tiles),
                                    data_tiles)
        else:
            weights_tiled = None

        def _weighted_assign_add(label, pred, weights, var):
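            # Count, for every threshold, the (optionally weighted) positions
            # where both `label` and `pred` are True, and accumulate the sums
            # into `var`. For non-micro averaging the counts are kept per
            # class; self.num_classes and self.axis are assumed to be set in
            # __init__, which is not part of this excerpt.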
            label_and_pred = tf.cast(tf.logical_and(label, pred),
                                     dtype=y_pred.dtype)

            if weights is not None:
                label_and_pred *= weights

            if self.average != 'micro':
                label_and_pred = tf.reshape(label_and_pred,
                                            shape=[-1, self.num_classes])

            return var.assign_add(tf.reduce_sum(label_and_pred, self.axis))

        # Collect the variable update ops
        update_ops = []

        # Update true positives
        update_ops.append(
            _weighted_assign_add(label_is_pos, pred_is_pos, weights_tiled,
                                 self.true_positives))

        # Update false negatives
        pred_is_neg = tf.logical_not(pred_is_pos)
        update_ops.append(
            _weighted_assign_add(label_is_pos, pred_is_neg, weights_tiled,
                                 self.false_negatives))

        # Update false positives
        label_is_neg = tf.logical_not(label_is_pos)
        update_ops.append(
            _weighted_assign_add(label_is_neg, pred_is_pos, weights_tiled,
                                 self.false_positives))

        return tf.group(update_ops)
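
    # A minimal sketch (not part of the original excerpt) of how a matching
    # `result()` could turn the accumulated true/false positive/negative
    # counts into an F1-style score. The surrounding class (__init__, the
    # exact averaging and beta handling) is not shown, so this is only
    # illustrative.
    def result(self):
        precision = tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_positives)
        recall = tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_negatives)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'macro':
            # Average the per-class scores; 'micro' already pools all classes
            # inside update_state, so no further averaging is needed there.
            return tf.reduce_mean(f1)
        return f1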