コード例 #1
0
    def test_simple_datasets(self):
        """Check a dataset that only defines a train and a test split.

        A minimal ``BaseDataset`` subclass is built with one train dataset and
        one test dataset, its dataloaders are created, and the test verifies
        that every transform attribute is ``None``, that both splits are set,
        and that test loaders exist while no validation loader does.
        """
        opt = Options()
        opt.dataset_name = os.path.join(os.getcwd(), "test")
        opt.dataroot = os.path.join(os.getcwd(), "test")

        class SimpleDataset(BaseDataset):
            """Dataset exposing a train and a test split, but no validation split."""

            def __init__(self, dataset_opt):
                super(SimpleDataset, self).__init__(dataset_opt)

                self.train_dataset = CustomMockDataset(10, 1, 3, 10)
                self.test_dataset = CustomMockDataset(10, 1, 3, 10)

        dataset = SimpleDataset(opt)

        model_config = MockModelConfig()
        model_config.conv_type = "dense"
        model = MockModel(model_config)
        # NOTE(review): the positional arguments are presumably batch size,
        # shuffle, number of workers and a precompute flag -- confirm against
        # BaseDataset.create_dataloaders.
        dataset.create_dataloaders(model, 5, True, 0, False)

        # Identity assertions state the intent ("is None") directly and give
        # clearer failure messages than an equality comparison with None.
        self.assertIsNone(dataset.pre_transform)
        self.assertIsNone(dataset.test_transform)
        self.assertIsNone(dataset.train_transform)
        self.assertIsNone(dataset.val_transform)
        self.assertIsNotNone(dataset.train_dataset)
        self.assertIsNotNone(dataset.test_dataset)
        self.assertTrue(dataset.has_test_loaders)
        self.assertFalse(dataset.has_val_loader)
コード例 #2
0
    def test_multiple_test_datasets(self):
        """Check a dataset whose test split is a list of two datasets.

        A ``BaseDataset`` subclass is built with a train split, a validation
        split and two test datasets. After creating the dataloaders the test
        verifies the number and sizes of the test loaders, several dataset
        properties, the static batch helpers ``get_num_samples`` and
        ``get_sample``, the per-split batch counts, and the exact ``repr``
        output.
        """
        opt = Options()
        opt.dataset_name = os.path.join(os.getcwd(), "test")
        opt.dataroot = os.path.join(os.getcwd(), "test")

        class MultiTestDataset(BaseDataset):
            """Dataset with train and val splits plus two test datasets."""

            def __init__(self, dataset_opt):
                super(MultiTestDataset, self).__init__(dataset_opt)

                self.train_dataset = CustomMockDataset(10, 1, 3, 10)
                self.val_dataset = CustomMockDataset(10, 1, 3, 10)
                # NOTE(review): the last argument is presumably the dataset
                # length (asserted below as 10 and 20) -- confirm against
                # CustomMockDataset.
                self.test_dataset = [
                    CustomMockDataset(10, 1, 3, 10),
                    CustomMockDataset(10, 1, 3, 20),
                ]

        dataset = MultiTestDataset(opt)

        model_config = MockModelConfig()
        model_config.conv_type = "dense"
        model = MockModel(model_config)
        # NOTE(review): 5 is presumably the batch size (the expected repr
        # below reports "Batch size = 5") -- confirm against
        # BaseDataset.create_dataloaders.
        dataset.create_dataloaders(model, 5, True, 0, False)

        loaders = dataset.test_dataloaders
        self.assertEqual(len(loaders), 2)
        self.assertEqual(len(loaders[0].dataset), 10)
        self.assertEqual(len(loaders[1].dataset), 20)
        self.assertEqual(dataset.num_classes, 3)
        self.assertEqual(dataset.is_hierarchical, False)
        self.assertEqual(dataset.has_fixed_points_transform, False)
        self.assertEqual(dataset.has_val_loader, True)
        self.assertIsNone(dataset.class_to_segments)
        self.assertEqual(dataset.feature_dimension, 1)

        batch = next(iter(loaders[0]))
        num_samples = BaseDataset.get_num_samples(batch, "dense")
        self.assertEqual(num_samples, 5)

        sample = BaseDataset.get_sample(batch, "pos", 1, "dense")
        self.assertEqual(sample.shape, (10, 3))
        sample = BaseDataset.get_sample(batch, "x", 1, "dense")
        self.assertEqual(sample.shape, (10, 1))
        self.assertEqual(
            dataset.num_batches,
            {
                "train": 2,
                "val": 2,
                "test_0": 2,
                "test_1": 4,
            },
        )

        # One literal per output line; adjacent literals are concatenated by
        # the parser, so the expected text is unchanged. "\x1b[0;95m" and
        # "\x1b[0m" are the ANSI colour escape sequences embedded in the
        # dataset's repr.
        expected_repr = (
            "Dataset: MultiTestDataset \n"
            "\x1b[0;95mpre_transform \x1b[0m= None\n"
            "\x1b[0;95mtest_transform \x1b[0m= None\n"
            "\x1b[0;95mtrain_transform \x1b[0m= None\n"
            "\x1b[0;95mval_transform \x1b[0m= None\n"
            "\x1b[0;95minference_transform \x1b[0m= None\n"
            "Size of \x1b[0;95mtrain_dataset \x1b[0m= 10\n"
            "Size of \x1b[0;95mtest_dataset \x1b[0m= 10, 20\n"
            "Size of \x1b[0;95mval_dataset \x1b[0m= 10\n"
            "\x1b[0;95mBatch size =\x1b[0m 5"
        )
        # Use the builtin rather than calling the dunder method directly; the
        # local is no longer named `repr`, so the builtin is not shadowed.
        self.assertEqual(repr(dataset), expected_repr)
コード例 #3
0
    def test_normal(self):
        """Check that a single test dataset yields exactly one test dataloader."""
        config = MockDatasetConfig()
        config.dataroot = os.path.join(DIR, "temp_dataset")

        base_dataset = MockBaseDataset(config)
        base_dataset.test_dataset = MockDataset()

        dense_config = MockModelConfig()
        dense_config.conv_type = "dense"
        dense_model = MockModel(dense_config)

        base_dataset.create_dataloaders(dense_model, 2, True, 0, False)

        # A lone (non-list) test dataset is expected to give one loader.
        test_loaders = base_dataset.test_dataloaders
        self.assertEqual(len(test_loaders), 1)