def test_evaluate_invalid_argument_exception_raised(
      self, error_msg, ground_truth, predictions, grid_size):
    """Checks that `intersection_over_union.evaluate` rejects invalid inputs.

    Args:
      error_msg: Regular expression that must be found in the message of the
        raised exception.
      ground_truth: Ground-truth labels passed through to `evaluate`.
      predictions: Predicted labels passed through to `evaluate`.
      grid_size: Grid dimensionality passed through to `evaluate`.
    """
    # `assertRaisesRegexp` is a long-deprecated alias that was removed from
    # `unittest.TestCase` in Python 3.12; `assertRaisesRegex` is the canonical
    # name and behaves identically.
    # Either exception type is accepted. Presumably some input checks fail
    # eagerly with ValueError and others with InvalidArgumentError when the op
    # executes; confirm against `evaluate`.
    with self.assertRaisesRegex(
        (tf.errors.InvalidArgumentError, ValueError), error_msg):
      self.evaluate(
          intersection_over_union.evaluate(ground_truth, predictions,
                                           grid_size))
  def test_evaluate_preset(self, ground_truth, predictions, expected_iou):
    """Compares `evaluate` with a preset IoU after tiling over random dims.

    The preset `ground_truth` and `predictions` grids are replicated along a
    randomly drawn set of leading dimensions. Every replica is identical, so
    the result is expected to equal `expected_iou` tiled to that same shape.
    """
    leading_shape = random_tensor_shape()
    # The grid rank comes from the preset itself and is forwarded to
    # `evaluate` as its grid_size argument.
    grid_rank = np.array(ground_truth).ndim
    # The reps list is longer than the preset's ndim, so np.tile prepends new
    # axes sized by `leading_shape`. The trailing 1s leave each grid axis
    # untouched.
    tile_reps = leading_shape + [1] * grid_rank

    tiled_truth = np.tile(ground_truth, tile_reps)
    tiled_prediction = np.tile(predictions, tile_reps)
    want = np.tile(expected_iou, leading_shape)

    self.assertAllClose(
        want,
        intersection_over_union.evaluate(tiled_truth, tiled_prediction,
                                         grid_rank))