Esempio n. 1
0
 def test_split_for_reidentification_randomseed(self):
     '''
     Test that ReidentificationSplit is reproducible for a fixed seed.

     Two splitters built from the same source with the same seed must
     produce identical "train" subsets; a splitter built with a different
     seed must produce a different one.
     '''
     config = {}
     for i in range(10):
         # per-label count cycles through 7, 14, 21
         config["label%d" % i] = {"attrs": None, "counts": (i % 3 + 1) * 7}
     source = self._generate_dataset(config)
     splits = [("train", 0.5), ("test", 0.5)]
     query = 0.4 / 0.7
     r1 = splitter.ReidentificationSplit(source, splits, query, seed=1234)
     r2 = splitter.ReidentificationSplit(source, splits, query, seed=1234)
     r3 = splitter.ReidentificationSplit(source, splits, query, seed=4321)

     train1 = list(r1.get_subset("train"))
     self.assertEqual(train1, list(r2.get_subset("train")))
     self.assertNotEqual(train1, list(r3.get_subset("train")))
Esempio n. 2
0
    def test_split_for_reidentification_gives_error(self):
        '''
        Test that ReidentificationSplit rejects invalid inputs.

        Covers items without exactly one label, out-of-range or
        inconsistent split ratios, an out-of-range query ratio, reserved
        subset names, and requesting the unsplit "test" subset.
        '''
        query = 0.4 / 0.7  # valid query ratio
        valid_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]

        with self.subTest("no label"):
            source = Dataset.from_iterable([
                DatasetItem(1, annotations=[]),
                DatasetItem(2, annotations=[]),
            ], categories=["a", "b", "c"])

            with self.assertRaisesRegex(Exception, "exactly one is expected"):
                actual = splitter.ReidentificationSplit(source, valid_splits,
                                                        query)
                # touch a subset too, in case the error is only raised
                # once the split is materialized
                len(actual.get_subset("train"))

        with self.subTest("multi label"):
            source = Dataset.from_iterable([
                DatasetItem(1, annotations=[Label(0), Label(1)]),
                DatasetItem(2, annotations=[Label(0), Label(2)]),
            ], categories=["a", "b", "c"])

            with self.assertRaisesRegex(Exception, "exactly one is expected"):
                actual = splitter.ReidentificationSplit(source, valid_splits,
                                                        query)
                len(actual.get_subset("train"))

        counts = {i: (i % 3 + 1) * 7 for i in range(10)}
        config = {"person": {"attrs": ["PID"], "counts": counts}}
        source = self._generate_dataset(config)

        with self.subTest("wrong ratio"):
            with self.assertRaisesRegex(Exception, "in the range"):
                splits = [("train", -0.5), ("val", 0.2), ("test", 0.3)]
                splitter.ReidentificationSplit(source, splits, query)

            with self.assertRaisesRegex(Exception, "Sum of ratios"):
                # ratios sum to 1.1
                splits = [("train", 0.6), ("val", 0.2), ("test", 0.3)]
                splitter.ReidentificationSplit(source, splits, query)

            with self.assertRaisesRegex(Exception, "in the range"):
                # negative query ratio
                splitter.ReidentificationSplit(source, valid_splits, -query)

        with self.subTest("wrong subset name"):
            with self.assertRaisesRegex(Exception, "Subset name"):
                splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)]
                splitter.ReidentificationSplit(source, splits, query)

        with self.subTest("wrong attribute name for person id"):
            # NOTE(review): despite its label, this case passes no attribute
            # name; it checks that the "test" split is not exposed as a
            # subset of its own (only "test-query"/"test-gallery" are).
            # Confirm which check was intended.
            actual = splitter.ReidentificationSplit(source, valid_splits,
                                                    query)

            with self.assertRaisesRegex(Exception, "Unknown subset"):
                actual.get_subset("test")
Esempio n. 3
0
    def test_split_for_reidentification_rebalance(self):
        '''
        Test that the rebalance step does not raise when no exchange is needed.

        100 labels are configured with a count of 7 each (assumed to be
        the per-label item count, 700 items total). These divide exactly
        into 350 train, 140 val and 210 test items; the test part is then
        split into 90 gallery and 120 query items.
        '''
        config = {
            "label%03d" % idx: {"attrs": None, "counts": 7}
            for idx in range(100)
        }
        dataset = self._generate_dataset(config)
        query_ratio = 0.4 / 0.7
        result = splitter.ReidentificationSplit(
            dataset,
            [("train", 0.5), ("val", 0.2), ("test", 0.3)],
            query_ratio,
        )

        expected_sizes = [
            ("train", 350),
            ("val", 140),
            ("test-gallery", 90),
            ("test-query", 120),
        ]
        for subset_name, expected in expected_sizes:
            self.assertEqual(expected, len(result.get_subset(subset_name)))
Esempio n. 4
0
    def test_split_for_reidentification(self):
        '''
        Test ReidentificationSplit with plain labels (ImageNet style) and
        with a person-id attribute ("PID") used as the identity.

        For each variant, checks:
        - the subset sizes;
        - that ids in train/val are disjoint from those in
          test-gallery/test-query;
        - that each id's items are divided by the requested ratios.
        '''
        def _get_present(stat):
            # Distribution keys whose first entry (presumably the item
            # count) is positive.
            return {label for label, dist in stat["distribution"].items()
                    if dist[0] > 0}

        for with_attr in [True, False]:
            # subTest reports which variant failed and still runs the other
            with self.subTest(with_attr=with_attr):
                if with_attr:
                    counts = {i: (i % 3 + 1) * 7 for i in range(10)}
                    config = {"person": {"attrs": ["PID"], "counts": counts}}
                    attr_for_id = "PID"
                else:
                    counts = {}
                    config = dict()
                    for i in range(10):
                        label = "label%d" % i
                        count = (i % 3 + 1) * 7
                        counts[label] = count
                        config[label] = {"attrs": None, "counts": count}
                    attr_for_id = None
                source = self._generate_dataset(config)
                splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
                query = 0.4 / 0.7
                actual = splitter.ReidentificationSplit(source, splits, query,
                                                        attr_for_id)

                stats = dict()
                for sname in ["train", "val", "test-query", "test-gallery"]:
                    subset = actual.get_subset(sname)
                    ann_stats = compute_ann_statistics(subset)["annotations"]
                    stat = ann_stats["labels"]
                    if with_attr:
                        stat = stat["attributes"]["PID"]
                    stats[sname] = stat

                # check size of subsets; the configured counts sum to
                # 133 = 65 + 26 + 18 + 24
                self.assertEqual(65, stats["train"]["count"])
                self.assertEqual(26, stats["val"]["count"])
                self.assertEqual(18, stats["test-gallery"]["count"])
                self.assertEqual(24, stats["test-query"]["count"])

                # check ID separation between test set and others
                train_ids = _get_present(stats["train"])
                test_ids = _get_present(stats["test-gallery"])
                # a TestCase assertion (not a bare assert) survives -O and
                # reports the shared ids on failure
                self.assertEqual(set(), train_ids & test_ids)
                self.assertEqual(7, len(train_ids))
                self.assertEqual(3, len(test_ids))
                self.assertEqual(train_ids, _get_present(stats["val"]))
                self.assertEqual(test_ids, _get_present(stats["test-query"]))

                # check trainval set statistics
                trainval = stats["train"]["count"] + stats["val"]["count"]
                expected_train_count = int(trainval * 0.5 / 0.7)
                expected_val_count = int(trainval * 0.2 / 0.7)
                self.assertEqual(expected_train_count, stats["train"]["count"])
                self.assertEqual(expected_val_count, stats["val"]["count"])
                dist_train = stats["train"]["distribution"]
                dist_val = stats["val"]["distribution"]
                for pid in train_ids:
                    # in the attribute variant `counts` is keyed by int, while
                    # ids from the statistics are converted with int()
                    total = counts[int(pid)] if with_attr else counts[pid]
                    self.assertEqual(int(total * 0.5 / 0.7),
                                     dist_train[pid][0])
                    self.assertEqual(int(total * 0.2 / 0.7),
                                     dist_val[pid][0])

                # check test set statistics
                dist_gallery = stats["test-gallery"]["distribution"]
                dist_query = stats["test-query"]["distribution"]
                for pid in test_ids:
                    total = counts[int(pid)] if with_attr else counts[pid]
                    self.assertEqual(int(total * 0.3 / 0.7),
                                     dist_gallery[pid][0])
                    self.assertEqual(int(total * 0.4 / 0.7),
                                     dist_query[pid][0])