def _init_image_and_label_subset(self):
    """
    Build and cache the index subset used when the dataset size is limited.

    When DATA_LIMIT = K >= 0 the dataset is shrunk from N samples to K.
    This method computes the mapping from the external indices [0, K) to
    the internal indices [0, N), driven by the DATA_LIMIT_SAMPLING
    configuration, and stores it in ``self.image_and_label_subset`` so
    that __getitem__ can translate indices without recomputing it.

    Only the first data source (and, for balanced sampling, the first
    label source) is consulted: like __getitem__, this assumes a single
    data source or that all data sources share the same length.

    Two random sampling strategies are available:
    - unbalanced: sampling ignores the labels
    - balanced: every label is equally represented
    """
    sampling_cfg = self.data_limit_sampling

    # Select the sampler and the one argument that differs between the
    # two strategies; every other argument is shared.
    if sampling_cfg.IS_BALANCED:
        assert len(self.label_objs), "Balanced sampling requires labels"
        sampler = balanced_sub_sampling
        source_kwargs = {"labels": self.label_objs[0]}
    else:
        sampler = unbalanced_sub_sampling
        source_kwargs = {"total_num_samples": len(self.data_objs[0])}

    self.image_and_label_subset = sampler(
        **source_kwargs,
        num_samples=self.data_limit,
        skip_samples=sampling_cfg.SKIP_NUM_SAMPLES,
        seed=sampling_cfg.SEED,
    )
    # Mark the mapping as computed.
    self._subset_initialized = True
def test_balanced_sub_sampling(self):
    """
    Check that balanced sub-sampling equalizes labels and honors skipping.

    Eight samples are drawn from a 15-element label array holding four
    distinct classes, so each class must appear exactly twice. A second
    draw with ``skip_samples=4`` must share exactly four indices with the
    first one.
    """
    labels = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 0])
    expected_classes = set(labels)

    # First window: no samples skipped.
    first_window = balanced_sub_sampling(labels, num_samples=8, skip_samples=0)
    picked_classes, per_class_counts = np.unique(
        labels[first_window], return_counts=True
    )
    self.assertEqual(8, len(first_window))
    self.assertEqual(
        set(picked_classes),
        expected_classes,
        "at least one of each label should be selected",
    )
    self.assertEqual(
        2, per_class_counts.min(), "at least two of each label is selected"
    )
    self.assertEqual(
        2, per_class_counts.max(), "at most two of each label is selected"
    )

    # Second window: skip four samples and compare against the first.
    second_window = balanced_sub_sampling(labels, num_samples=8, skip_samples=4)
    self.assertEqual(8, len(second_window))
    shared_indices = set(first_window).intersection(second_window)
    self.assertEqual(
        4,
        len(shared_indices),
        "skipping samples should slide the window",
    )