def test_get_balanced_sampling_probability(self):
     """Checks per-point balanced sampling probabilities against fixed values.

     Every row of the expected matrix sums to 1. The repeated values suggest
     each instance in a row receives an equal share of probability mass that
     is then split evenly across its points — presumably; the instance layout
     comes from `_get_instance_id_example`, which is not shown here.
     """
     ids = self._get_instance_id_example()
     # Row 0: three points at 1/12, then alternating 1/8 and 1/4 pairs.
     row_0 = [1.0 / 12.0] * 3 + [1.0 / 8.0, 1.0 / 4.0] * 2
     # Row 1: two points at 1/4, followed by five points at 0.1.
     row_1 = [1.0 / 4.0] * 2 + [0.1] * 5
     want = np.array([row_0, row_1], dtype=np.float32)
     got = sampling_utils.get_balanced_sampling_probability(instance_ids=ids)
     self.assertAllClose(got.numpy(), want)
 def test_get_balanced_sampling_probability_with_valid_mask(self):
     """Checks balanced sampling probabilities when a valid mask is supplied.

     The call passes `valid_mask` and `max_instance_id=4`. Entries expected to
     be 0.0 are presumably the points excluded by `_get_valid_mask` (mask
     contents are not visible here — confirm). Each expected row sums to 1.
     """
     ids = self._get_instance_id_example()
     mask = self._get_valid_mask()
     # Row 0: three points at 1/9, two at 1/3, two with zero probability.
     row_0 = [1.0 / 9.0] * 3 + [1.0 / 3.0] * 2 + [0.0] * 2
     # Row 1: 0.5, a zero, three points at 1/6, then two zeros.
     row_1 = [0.5, 0.0] + [1.0 / 6.0] * 3 + [0.0] * 2
     want = np.array([row_0, row_1], dtype=np.float32)
     got = sampling_utils.get_balanced_sampling_probability(
         instance_ids=ids,
         valid_mask=mask,
         max_instance_id=4)
     self.assertAllClose(got.numpy(), want)