Ejemplo n.º 1
0
    def test_split_for_classification_gives_error(self):
        """Check that ClassificationSplit reports invalid inputs.

        Sub-cases covered:
          * items that carry no label annotation;
          * items that carry more than one label annotation;
          * split ratios rejected with an "in the range" message
            (-0.5 and 1.5) and with a "Sum of ratios" message
            (three ratios of 0.5);
          * a split spec rejected with a "Subset name" message ("train_").
        """
        valid_splits = [("train", 0.7), ("test", 0.3)]

        with self.subTest("no label"):
            source = Dataset.from_iterable([
                DatasetItem(1, annotations=[]),
                DatasetItem(2, annotations=[]),
            ], categories=["a", "b", "c"])

            with self.assertRaisesRegex(Exception, "exactly one is expected"):
                actual = splitter.ClassificationSplit(source, valid_splits)
                # NOTE(review): the subset length is requested as well,
                # presumably because the split may be evaluated lazily.
                len(actual.get_subset("train"))

        with self.subTest("multi label"):
            source = Dataset.from_iterable([
                DatasetItem(1, annotations=[Label(0), Label(1)]),
                DatasetItem(2, annotations=[Label(0), Label(2)]),
            ], categories=["a", "b", "c"])

            with self.assertRaisesRegex(Exception, "exactly one is expected"):
                # The result must be bound here: otherwise the next line
                # would inspect the splitter left over from the "no label"
                # sub-case and the error could come from the wrong dataset.
                actual = splitter.ClassificationSplit(source, valid_splits)
                len(actual.get_subset("train"))

        # A well-formed source, so only the split specs below are at fault.
        source = Dataset.from_iterable([
            DatasetItem(1, annotations=[Label(0)]),
            DatasetItem(2, annotations=[Label(1)]),
        ], categories=["a", "b", "c"])

        with self.subTest("wrong ratio"):
            with self.assertRaisesRegex(Exception, "in the range"):
                splits = [("train", -0.5), ("test", 1.5)]
                splitter.ClassificationSplit(source, splits)

            with self.assertRaisesRegex(Exception, "Sum of ratios"):
                splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)]
                splitter.ClassificationSplit(source, splits)

        with self.subTest("wrong subset name"):
            with self.assertRaisesRegex(Exception, "Subset name"):
                splits = [("train_", 0.5), ("val", 0.2), ("test", 0.3)]
                splitter.ClassificationSplit(source, splits)
Ejemplo n.º 2
0
    def test_split_for_classification_zero_ratio(self):
        """Check subset sizes when one of the split ratios is 0.0.

        The generated dataset is configured with a single label and a count
        of 5; the expected sizes are 1 ("train"), 4 ("val") and 0 ("test").
        """
        source = self._generate_dataset(
            {"label1": {"attrs": None, "counts": 5}})
        splits = [("train", 0.1), ("val", 0.9), ("test", 0.0)]

        actual = splitter.ClassificationSplit(source, splits)

        expected_sizes = (("train", 1), ("val", 4), ("test", 0))
        for subset_name, size in expected_sizes:
            self.assertEqual(size, len(actual.get_subset(subset_name)))
Ejemplo n.º 3
0
    def test_split_for_classification_single_class_single_attr(self):
        """Check a 0.7 / 0.3 split of one label with one attribute.

        The dataset config gives the attribute counts of 10, 20 and 30 for
        the keys 0, 1 and 2. The assertions expect 42 / 18 items overall and
        a 7:3 division of every attribute value between "train" and "test".
        """
        value_counts = {0: 10, 1: 20, 2: 30}
        source = self._generate_dataset(
            {"label": {"attrs": ["attr"], "counts": value_counts}})

        actual = splitter.ClassificationSplit(
            source, [("train", 0.7), ("test", 0.3)])

        self.assertEqual(42, len(actual.get_subset("train")))
        self.assertEqual(18, len(actual.get_subset("test")))

        # Expected number of items per attribute value, for each subset.
        expected_by_subset = (
            ("train", (("0", 7), ("1", 14), ("2", 21))),
            ("test", (("0", 3), ("1", 6), ("2", 9))),
        )
        for subset_name, expected in expected_by_subset:
            stats = compute_ann_statistics(actual.get_subset(subset_name))
            attributes = stats["annotations"]["labels"]["attributes"]
            for value, wanted in expected:
                self.assertEqual(
                    wanted, attributes["attr"]["distribution"][value][0])
Ejemplo n.º 4
0
    def test_split_for_classification_multi_class_no_attr(self):
        """Check a 0.7 / 0.3 split of three labels that have no attributes.

        The labels are configured with counts of 10, 20 and 30. The
        assertions expect 42 / 18 items overall and a 7:3 division of every
        label between "train" and "test".
        """
        label_counts = (("label1", 10), ("label2", 20), ("label3", 30))
        config = {
            label_name: {"attrs": None, "counts": count}
            for label_name, count in label_counts
        }
        source = self._generate_dataset(config)

        actual = splitter.ClassificationSplit(
            source, [("train", 0.7), ("test", 0.3)])

        self.assertEqual(42, len(actual.get_subset("train")))
        self.assertEqual(18, len(actual.get_subset("test")))

        # Expected per-label item counts, listed in label_counts order.
        expected_by_subset = (
            ("train", (7, 14, 21)),
            ("test", (3, 6, 9)),
        )
        for subset_name, expected in expected_by_subset:
            stats = compute_ann_statistics(actual.get_subset(subset_name))
            distribution = stats["annotations"]["labels"]["distribution"]
            for (label_name, _), wanted in zip(label_counts, expected):
                self.assertEqual(wanted, distribution[label_name][0])
Ejemplo n.º 5
0
    def test_split_for_classification_single_class_multi_attr(self):
        """Check a 0.7 / 0.3 split of one label with two attributes.

        Counts are configured per pair of attribute keys (120 in total).
        The assertions expect 84 / 36 items overall and a 7:3 division of
        every value of "attr1" and "attr2" between "train" and "test".
        """
        source = self._generate_dataset({
            "label": {
                "attrs": ["attr1", "attr2"],
                "counts": {
                    (0, 0): 20,
                    (0, 1): 20,
                    (0, 2): 30,
                    (1, 0): 20,
                    (1, 1): 10,
                    (1, 2): 20,
                },
            },
        })

        actual = splitter.ClassificationSplit(
            source, [("train", 0.7), ("test", 0.3)])

        self.assertEqual(84, len(actual.get_subset("train")))
        self.assertEqual(36, len(actual.get_subset("test")))

        # (attribute name, attribute value, expected count) for each subset.
        expected_by_subset = (
            ("train", (
                ("attr1", "0", 49),
                ("attr1", "1", 35),
                ("attr2", "0", 28),
                ("attr2", "1", 21),
                ("attr2", "2", 35),
            )),
            ("test", (
                ("attr1", "0", 21),
                ("attr1", "1", 15),
                ("attr2", "0", 12),
                ("attr2", "1", 9),
                ("attr2", "2", 15),
            )),
        )
        for subset_name, expected in expected_by_subset:
            stats = compute_ann_statistics(actual.get_subset(subset_name))
            attributes = stats["annotations"]["labels"]["attributes"]
            for attr_name, value, wanted in expected:
                self.assertEqual(
                    wanted, attributes[attr_name]["distribution"][value][0])
Ejemplo n.º 6
0
    def test_split_for_classification_multi_label_with_attr(self):
        """Check a 0.7 / 0.3 split of two labels that both have attributes.

        Both labels share the same counts table; "label1" uses the
        attributes "attr1" and "attr2", "label2" uses "attr1" and "attr3".
        The assertions expect 168 / 72 items overall, 84 / 36 per label,
        and a 7:3 division of every attribute value (with "attr1" counted
        for both labels). A final sub-case checks that equal seeds give
        equal "test" subsets and that a different seed gives a different
        one.
        """
        counts = {
            (0, 0): 20,
            (0, 1): 20,
            (0, 2): 30,
            (1, 0): 20,
            (1, 1): 10,
            (1, 2): 20,
        }
        source = self._generate_dataset({
            "label1": {"attrs": ["attr1", "attr2"], "counts": counts},
            "label2": {"attrs": ["attr1", "attr3"], "counts": counts},
        })

        splits = [("train", 0.7), ("test", 0.3)]
        actual = splitter.ClassificationSplit(source, splits)

        train = actual.get_subset("train")
        test = actual.get_subset("test")
        self.assertEqual(168, len(train))
        self.assertEqual(72, len(test))

        def check_stats(subset, per_label, attr1_expected, other_expected):
            # Verify label counts first, then every attribute value count.
            labels = compute_ann_statistics(subset)["annotations"]["labels"]

            distribution = labels["distribution"]
            for label_name in ("label1", "label2"):
                self.assertEqual(per_label, distribution[label_name][0])

            attributes = labels["attributes"]
            for value, wanted in zip(("0", "1"), attr1_expected):
                self.assertEqual(
                    wanted, attributes["attr1"]["distribution"][value][0])
            for attr_name in ("attr2", "attr3"):
                for value, wanted in zip(("0", "1", "2"), other_expected):
                    self.assertEqual(
                        wanted,
                        attributes[attr_name]["distribution"][value][0])

        # "attr1" belongs to both labels, hence the doubled expectations.
        check_stats(train, 84, (49 * 2, 35 * 2), (28, 21, 35))
        check_stats(test, 36, (21 * 2, 15 * 2), (12, 9, 15))

        with self.subTest("random seed test"):
            first, repeat, other = [
                splitter.ClassificationSplit(source, splits, seed=seed)
                for seed in (1234, 1234, 4321)
            ]
            self.assertEqual(list(first.get_subset("test")),
                             list(repeat.get_subset("test")))
            self.assertNotEqual(list(first.get_subset("test")),
                                list(other.get_subset("test")))