def test_num_classes_respected(self):
    """Every sampled class ID must fall inside ``range(num_classes)``."""
    max_classes = test_utils.MAX_WAYS_UPPER_BOUND
    ways = test_utils.MIN_WAYS
    # The full set of legal IDs the sampler is allowed to return.
    allowed_ids = list(range(max_classes))
    # Repeat several draws; presumably the sampler is randomized, so a
    # single call would not exercise much of its output space.
    for _trial in range(10):
        sampled_ids = sampling.sample_class_ids_uniformly(ways, max_classes)
        self.assertContainsSubset(sampled_ids, allowed_ids)
def test_unique_class_ids(self):
    """A single sample must never contain the same class ID twice."""
    max_classes = test_utils.MAX_WAYS_UPPER_BOUND
    ways = test_utils.MIN_WAYS
    # Repeat several draws; presumably the sampler is randomized.
    for _trial in range(10):
        sampled_ids = sampling.sample_class_ids_uniformly(ways, max_classes)
        # Comparing the multiset of IDs to its de-duplicated set only
        # succeeds when no ID appears more than once.
        distinct_ids = set(sampled_ids)
        self.assertCountEqual(sampled_ids, distinct_ids)
def test_num_ways_respected(self):
    """The sampler must return exactly ``num_ways`` distinct class IDs.

    Both the number of returned IDs and the number of distinct IDs are
    checked against ``num_ways``.
    """
    # Qualify the constants through ``test_utils`` like the sibling tests do;
    # the bare names were not visibly bound in this scope.
    num_classes = test_utils.MAX_WAYS_UPPER_BOUND
    num_ways = test_utils.MIN_WAYS
    # Repeat several draws; presumably the sampler is randomized.
    for _ in range(10):
        class_ids = sampling.sample_class_ids_uniformly(num_ways, num_classes)
        # Distinct count catches duplicates; raw length catches extra or
        # missing entries.
        self.assertLen(set(class_ids), num_ways)
        self.assertLen(class_ids, num_ways)