Exemplo n.º 1
0
    def test_unbalanced_sub_sampling(self):
        # Unbalanced sampling only needs the dataset size: the label values
        # below are never inspected, only their count is passed on.
        labels = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 0])
        dataset_size = len(labels)

        def draw_and_validate(num_skipped):
            # Draw 8 indices and check both the size and the uniqueness of
            # the draw before handing it back to the caller.
            drawn = unbalanced_sub_sampling(
                dataset_size, num_samples=8, skip_samples=num_skipped
            )
            self.assertEqual(8, len(drawn))
            self.assertEqual(len(drawn), len(set(drawn)), "indices must be unique")
            return drawn

        no_skip = draw_and_validate(0)
        skip_two = draw_and_validate(2)

        # Skipping 2 samples must shift the selection by 2 positions: the tail
        # of the unskipped draw lines up with the head of the skipped one.
        window_slid = np.array_equal(no_skip[2:], skip_two[:-2])
        self.assertTrue(window_slid, "skipping samples should slide the window")
Exemplo n.º 2
0
    def _init_image_and_label_subset(self):
        """
        If DATA_LIMIT = K >= 0, we reduce the size of the dataset from N to K.

        This function will create a mapping from [0, K) to [0, N), using the
        parameters specified in the DATA_LIMIT_SAMPLING configuration. This
        mapping is then cached and used for all __getitem__ calls to map
        the external indices from [0, K) to the internal [0, N) indices.

        This function makes the assumption that there is one data source only
        or that all data sources have the same length (same as __getitem__).

        Side effects:
            Sets ``self.image_and_label_subset`` to the sampled indices and
            ``self._subset_initialized`` to True.

        Raises:
            ValueError: if balanced sampling is requested
                (``DATA_LIMIT_SAMPLING.IS_BALANCED`` is truthy) but
                ``self.label_objs`` is empty.
        """
        sampling = self.data_limit_sampling
        # Sampling parameters shared by both strategies.
        common_kwargs = {
            "num_samples": self.data_limit,
            "skip_samples": sampling.SKIP_NUM_SAMPLES,
            "seed": sampling.SEED,
        }

        # Use one of the two random sampling strategies:
        # - unbalanced: random sampling is agnostic to labels
        # - balanced: makes sure all labels are equally represented
        if not sampling.IS_BALANCED:
            # Only the first data source's length is used, per the
            # single-source / equal-length assumption in the docstring.
            self.image_and_label_subset = unbalanced_sub_sampling(
                total_num_samples=len(self.data_objs[0]),
                **common_kwargs,
            )
        else:
            # Explicit check rather than `assert`: asserts are stripped when
            # Python runs with -O, which would turn this into an opaque
            # IndexError on `self.label_objs[0]` below.
            if len(self.label_objs) == 0:
                raise ValueError("Balanced sampling requires labels")
            self.image_and_label_subset = balanced_sub_sampling(
                labels=self.label_objs[0],
                **common_kwargs,
            )
        self._subset_initialized = True