Example #1
import pytest
import torch

# Project-specific imports; the module paths below are assumptions and may
# need to be adjusted to the actual repository layout.
from config import Config
from dataset_factory import DatasetFactory
from markers import gpu_only


class TestDatasetFactory:

    # DatasetFactory config
    config = Config.Dataset()
    config.name = 'svhn'
    config.imbalance_ratio = 10
    config.validation_split = 0.3
    config.batch_size = 100
    config.classes = [0, 1]
    gan_model = Config.Dataset.Model()
    gan_model.name = 'wgan'
    gan_model.saved_model = ''
    gan_model.output_size = 32
    gan_model.input_channels = 3
    config.gan_model = gan_model

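    # Class-scoped fixtures: both reuse the shared config above and differ only
    # in the oversampling mode they set ('none' vs. 'gan') before building the
    # train/validation/test loaders.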
    @pytest.fixture(scope="class")
    def dataset_factory(self):
        self.config.oversampling = 'none'
        train_loader, valid_loader, test_loader = DatasetFactory.create(
            dataset_config=self.config)
        yield train_loader, valid_loader, test_loader

    @pytest.fixture(scope="class")
    def dataset_factory_balanced(self):
        self.config.oversampling = 'gan'
        train_loader, valid_loader, test_loader = DatasetFactory.create(
            dataset_config=self.config)
        yield train_loader, valid_loader, test_loader

    def test_dataset_has_proper_number_of_examples(self, dataset_factory):
        import random
        import torch
        import numpy as np

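        # Seed every RNG the dataset pipeline might use (torch, CUDA, numpy,
        # random) so the loaders behave deterministically in this test.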
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
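        # The class-scoped fixture yields a single tuple of loaders; both sets
        # of names below therefore refer to the same DataLoader objects.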
        seeded_1_train_loader, seeded_1_valid_loader, seeded_1_test_loader = dataset_factory
        seeded_2_train_loader, seeded_2_valid_loader, seeded_2_test_loader = dataset_factory

        assert len(seeded_1_train_loader.dataset) == len(
            seeded_2_train_loader.dataset)
        assert len(seeded_1_valid_loader.dataset) == len(
            seeded_2_valid_loader.dataset)
        assert len(seeded_1_test_loader.dataset) == len(
            seeded_2_test_loader.dataset)

        s1_dataloader_iter = iter(seeded_1_valid_loader)
        s2_dataloader_iter = iter(seeded_2_valid_loader)
        while True:
            try:
                _, s1_valid_labels = next(s1_dataloader_iter)
                _, s2_valid_labels = next(s2_dataloader_iter)
            except StopIteration:
                break
        _, s1_last_batch_counts = np.unique(s1_valid_labels,
                                            return_counts=True)
        _, s2_last_batch_counts = np.unique(s2_valid_labels,
                                            return_counts=True)

        # Last validation batches match in size and per-class counts
        assert len(s1_valid_labels) == len(s2_valid_labels)
        assert np.array_equal(s1_last_batch_counts, s2_last_batch_counts)

    def test_dataset_has_proper_shape(self, dataset_factory):
        train_loader, valid_loader, test_loader = dataset_factory
        train_data, train_labels = next(iter(train_loader))
        valid_data, valid_labels = next(iter(valid_loader))
        test_data, test_labels = next(iter(test_loader))

        assert train_data.shape == torch.Size(
            [self.config.batch_size, 3, 32, 32])
        assert valid_data.shape == torch.Size(
            [self.config.batch_size, 3, 32, 32])
        assert test_data.shape == torch.Size(
            [self.config.batch_size, 3, 32, 32])
        assert train_labels.shape == torch.Size([self.config.batch_size])
        assert valid_labels.shape == torch.Size([self.config.batch_size])
        assert test_labels.shape == torch.Size([self.config.batch_size])

    def test_dataset_has_proper_classes(self, dataset_factory):
        train_loader, valid_loader, test_loader = dataset_factory

        # Check that only the right classes are selected
        for loader in (train_loader, valid_loader, test_loader):
            for _, labels in loader:
                for label in labels:
                    assert label in self.config.classes

    def test_dataset_has_proper_imbalance(self, dataset_factory):
        train_loader, valid_loader, _ = dataset_factory

        label_count = {self.config.classes[0]: 0, self.config.classes[1]: 0}
        for loader in (train_loader, valid_loader):
            for _, labels in loader:
                for label in labels:
                    label_count[label.item()] += 1

        observed_imbalance_ratio = label_count[
            self.config.classes[0]] / label_count[self.config.classes[1]]
        assert observed_imbalance_ratio == pytest.approx(
            self.config.imbalance_ratio, abs=1e-1)

    def test_test_dataset_has_no_imbalance(self, dataset_factory):
        _, _, test_loader = dataset_factory

        label_count = {self.config.classes[0]: 0, self.config.classes[1]: 0}
        for _, labels in test_loader:
            for label in labels:
                label_count[label.item()] += 1

        observed_imbalance_ratio = label_count[
            self.config.classes[0]] / label_count[self.config.classes[1]]
        assert observed_imbalance_ratio == pytest.approx(1, abs=1e-1)

    @gpu_only
    def test_oversampling_gan_is_balanced(self, dataset_factory_balanced):
        train_loader, valid_loader, _ = dataset_factory_balanced

        for loader in (train_loader, valid_loader):
            label_count = {
                self.config.classes[0]: 0,
                self.config.classes[1]: 0
            }
            for _, labels in loader:
                for label in labels:
                    label_count[label.item()] += 1

            observed_imbalance_ratio = label_count[
                self.config.classes[0]] / label_count[self.config.classes[1]]
            assert observed_imbalance_ratio == pytest.approx(1, abs=1e-1)
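

# The four label-counting loops above share the same pattern; a small helper
# like the sketch below (hypothetical, not part of the original suite) could
# replace them. It assumes each loader yields (data, integer-label-tensor)
# batches, as the tests above do.
def count_labels(loader, classes):
    """Count how many samples of each configured class a DataLoader yields."""
    counts = {c: 0 for c in classes}
    for _, labels in loader:
        for label in labels:
            counts[label.item()] += 1
    return counts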