def test_complains_on_zero_remaining(self):
    """Expect a ValueError when one class has zero remaining examples."""
    images_per_class = np.array([10, 10, 10, 10, 10])
    # The second entry is 0: that is the condition under test.
    remaining_per_class = np.array([5, 0, 5, 5, 5])
    requested_size = 5
    log_weight_bounds = dict(
        min_log_weight=test_utils.MIN_LOG_WEIGHT,
        max_log_weight=test_utils.MAX_LOG_WEIGHT,
    )

    # Callable form of assertRaises: the sampler is invoked by the test
    # framework with the arguments below and must raise ValueError.
    self.assertRaises(
        ValueError,
        sampling.sample_num_support_per_class,
        images_per_class,
        remaining_per_class,
        requested_size,
        **log_weight_bounds
    )
def test_at_least_one_example_per_class(self):
    """Check that every class is assigned at least one support example.

    Five classes share a support set of size 5. The sampler is called
    10 times and each result must have a strictly positive count for
    every class.
    """
    num_images_per_class = np.array([10, 10, 10, 10, 10])
    num_remaining_per_class = np.array([5, 5, 5, 5, 5])
    support_set_size = 5

    for _ in range(10):
        num_support_per_class = sampling.sample_num_support_per_class(
            num_images_per_class,
            num_remaining_per_class,
            support_set_size,
            min_log_weight=test_utils.MIN_LOG_WEIGHT,
            max_log_weight=test_utils.MAX_LOG_WEIGHT)
        # Use .all(), not .any(): .any() would pass even when only a
        # single class got an example, which does not match what this
        # test's name promises to verify.
        self.assertTrue((num_support_per_class > 0).all())
def test_support_set_size_respected(self):
    """Check that per-class support counts never sum above the requested size."""
    images_per_class = np.array([50, 40, 30, 20])
    remaining_per_class = np.array([40, 30, 20, 10])
    size_cap = 50

    def draw_counts():
        # One call to the sampler under test with the fixed inputs above.
        return sampling.sample_num_support_per_class(
            images_per_class,
            remaining_per_class,
            size_cap,
            min_log_weight=test_utils.MIN_LOG_WEIGHT,
            max_log_weight=test_utils.MAX_LOG_WEIGHT,
        )

    # Repeat the check on 10 independent calls.
    for _ in range(10):
        counts = draw_counts()
        total_support = counts.sum()
        self.assertLessEqual(total_support, size_cap)