Beispiel #1
0
    def test_assert_binary_exception_raised(self, dtype):
        """Verifies that assert_binary rejects tensors holding non-binary values.

        A tensor of random rank (1 to 3) and random per-axis size (1 to 9) is
        built twice: once containing a single -1 and once containing a single
        value that is neither 0 nor 1. Evaluating `asserts.assert_binary` on
        either tensor is expected to raise `tf.errors.InvalidArgumentError`.

        Args:
          dtype: The dtype the test tensors are converted to.
        """
        rank = np.random.randint(3) + 1
        dims = np.random.randint(1, 10, size=rank).tolist()
        total = np.prod(dims)

        # Case 1: every entry is one, except a randomly placed -1.
        flat_negative = np.ones(total)
        flat_negative[np.random.randint(total)] = -1
        negative_tensor = tf.convert_to_tensor(
            value=flat_negative.reshape(dims), dtype=dtype)

        # Case 2: every entry is zero, except a randomly placed 2. The value is
        # then mapped through x - div(x, 4) * 3, which is intended to leave 0.5
        # for floating dtypes and 2 for integer dtypes (where the division
        # truncates to zero), so the entry is non-binary either way.
        flat_non_binary = np.zeros(total)
        flat_non_binary[np.random.randint(total)] = 2
        non_binary_tensor = tf.convert_to_tensor(
            value=flat_non_binary.reshape(dims), dtype=dtype)
        non_binary_tensor = (
            non_binary_tensor - tf.compat.v1.div(non_binary_tensor, 4) * 3)

        with self.subTest(name="has_negative_number"):
            with self.assertRaises(tf.errors.InvalidArgumentError):
                self.evaluate(asserts.assert_binary(negative_tensor))

        with self.subTest(name="has_non_binary_number"):
            with self.assertRaises(tf.errors.InvalidArgumentError):
                self.evaluate(asserts.assert_binary(non_binary_tensor))
def evaluate(ground_truth_labels: type_alias.TensorLike,
             predicted_labels: type_alias.TensorLike,
             grid_size: int = 1,
             name: str = "intersection_over_union_evaluate") -> tf.Tensor:
    """Computes intersection-over-union between ground truth and predictions.

    Note:
      In the following, A1 to An are optional batch dimensions, which must be
      broadcast compatible, and G1 to Gm are the grid dimensions.

    Args:
      ground_truth_labels: A tensor of shape `[A1, ..., An, G1, ..., Gm]`,
        where the last m axes represent a grid of ground truth attributes.
        Each attribute can either be 0 or 1.
      predicted_labels: A tensor of shape `[A1, ..., An, G1, ..., Gm]`, where
        the last m axes represent a grid of predicted attributes. Each
        attribute can either be 0 or 1.
      grid_size: The number of grid dimensions. Defaults to 1.
      name: A name for this op. Defaults to
        "intersection_over_union_evaluate".

    Returns:
      A tensor of shape `[A1, ..., An]` holding the intersection divided by
      the union of the ground truth labels and the predictions. Entries whose
      union is zero are reported as one.

    Raises:
      ValueError: if the shape of `ground_truth_labels`, `predicted_labels`
      is not supported.
    """
    with tf.name_scope(name):
        truth = tf.convert_to_tensor(value=ground_truth_labels)
        prediction = tf.convert_to_tensor(value=predicted_labels)

        shape.compare_batch_dimensions(
            tensors=(truth, prediction),
            tensor_names=("ground_truth_labels", "predicted_labels"),
            last_axes=-grid_size,
            broadcast_compatible=True)

        truth = asserts.assert_binary(truth)
        prediction = asserts.assert_binary(prediction)

        # All reductions below collapse the trailing `grid_size` axes.
        grid_axes = list(range(-grid_size, 0))

        truth_count = tf.math.reduce_sum(input_tensor=truth, axis=grid_axes)
        prediction_count = tf.math.reduce_sum(input_tensor=prediction,
                                              axis=grid_axes)
        # With 0/1 labels the elementwise product marks cells set in both.
        intersection = tf.math.reduce_sum(input_tensor=truth * prediction,
                                          axis=grid_axes)
        union = truth_count + prediction_count - intersection

        # Where the union is zero, one is selected instead of the 0 / 0 ratio.
        union_is_zero = tf.math.equal(union, 0)
        ones = tf.ones_like(union)
        ratio = intersection / union
        return tf.where(union_is_zero, ones, ratio)
Beispiel #3
0
    def test_assert_binary_passthrough(self):
        """Checks that assert_binary hands back the very object it was given.

        The original note states this is the behavior when the flag is False;
        the flag itself is not set in this test, so presumably that is its
        default value.
        """
        original = _pick_random_vector()
        returned = asserts.assert_binary(original)

        # Identity, not equality: no new tensor or wrapper may be created.
        self.assertIs(original, returned)