예제 #1
0
 def test_getdataloaders_dataloader_constructor_dataset_arg_is_correct_type(self):
     """Each DataLoader built by getdataloaders must wrap an ImageFolder.

     The mocked DataLoader constructor checks its dataset argument, which may
     be passed positionally or as the ``dataset`` keyword.
     """
     def _dataloader_side_effect(*a, **kw):
         # Accept the dataset positionally or by keyword so the test does not
         # depend on how getdataloaders spells the DataLoader call.
         dataset = a[0] if a else kw['dataset']
         self.assertIsInstance(dataset, self.imagefolder_class)
     self.mock_dataloader.side_effect = _dataloader_side_effect
     train.getdataloaders('test_datadir', batch_size=1)
     # Without this check the test would pass vacuously if DataLoader were
     # never constructed (the side-effect assertion would simply never run).
     self.assertTrue(self.mock_dataloader.called)
예제 #2
0
 def test_imagefolder_is_instantiated_with_tranforms(self):
     """ImageFolder must receive a ``transform`` keyword that is a Compose."""
     def _instantiate_side_effect(*a, **kw):
         self.assertIn('transform', kw)
         # isinstance-based check instead of comparing type() against a list;
         # reports the actual object on failure.
         self.assertIsInstance(
             kw['transform'], torchvision.transforms.Compose
         )
     self.imagefolder_mock.side_effect = _instantiate_side_effect
     train.getdataloaders(self.test_datadir, batch_size=self.test_batch_size)
     # Guard against a vacuous pass if ImageFolder is never instantiated.
     self.assertTrue(self.imagefolder_mock.called)
예제 #3
0
    def test_getdataloaders_dataloader_constructor_dataset_shuffle_arg(self):
        """At least one DataLoader must be constructed with a truthy ``shuffle``."""
        shuffle_flags = []

        def _dataloader_side_effect(*a, **kw):
            # A missing ``shuffle`` keyword is recorded as "not shuffled".
            shuffle_flags.append(bool(kw.get('shuffle')))

        self.mock_dataloader.side_effect = _dataloader_side_effect

        train.getdataloaders('test_datadir', batch_size=1)
        self.assertTrue(any(shuffle_flags))
예제 #4
0
 def test_imagefolder_is_instantiated_checking_path_arg(self):
     """ImageFolder must be built for the test, valid and train subdirectories."""
     observed_paths = set()
     def _instantiate_side_effect(*a, **kw):
         # The root directory may be passed positionally or as the ``root``
         # keyword; accept either so the test is not tied to the call spelling.
         path = a[0] if a else kw['root']
         observed_paths.add(path)
     self.imagefolder_mock.side_effect = _instantiate_side_effect

     train.getdataloaders(self.test_datadir, batch_size=self.test_batch_size)
     expected_paths = {
         self.test_datadir + '/' + x for x in ('test', 'valid', 'train')
     }
     self.assertEqual(expected_paths, observed_paths)
예제 #5
0
    def test_getdataloaders_dataloader_constructor_dataset_num_workers(self):
        """Every DataLoader must get ``num_workers=train._DATALOADER_NUM_WORKERS``."""
        num_workers = []

        def _dataloader_side_effect(*a, **kw):
            self.assertIn('num_workers', kw)
            num_workers.append(kw['num_workers'])

        self.mock_dataloader.side_effect = _dataloader_side_effect

        train.getdataloaders('test_datadir', batch_size=1)
        self.assertEqual(self.mock_dataloader.call_count, 3)
        # Comparing the whole list reports the offending values on failure,
        # unlike assertTrue(all([...])), which only says "False is not true".
        self.assertEqual(
            num_workers,
            [train._DATALOADER_NUM_WORKERS] * 3,
        )
예제 #6
0
    def test_getdataloaders_dataloader_constructor_batch_size_arg(self):
        """The ``batch_size`` given to getdataloaders must reach every DataLoader."""
        batch_sizes = []

        def _dataloader_side_effect(*a, **kw):
            self.assertIn('batch_size', kw)
            batch_sizes.append(kw['batch_size'])

        self.mock_dataloader.side_effect = _dataloader_side_effect

        # A sentinel compares equal only to itself, so it proves the exact
        # object was forwarded rather than some coincidentally equal value.
        test_batch_size = unittest.mock.sentinel.batch_size
        train.getdataloaders('test_datadir', batch_size=test_batch_size)
        self.assertEqual(self.mock_dataloader.call_count, 3)

        # Comparing the whole list reports the offending values on failure.
        self.assertEqual(batch_sizes, [test_batch_size] * 3)
예제 #7
0
 def test_imagefolder_is_instantiated(self):
     """getdataloaders must construct exactly three ImageFolder datasets."""
     train.getdataloaders(
         self.test_datadir,
         batch_size=self.test_batch_size,
     )
     instantiations = self.imagefolder_mock.call_count
     self.assertEqual(instantiations, 3)
예제 #8
0
 def test_getdataloaders_retun_type(self):
     """getdataloaders must return three objects of the DataLoader class."""
     # Unpacking into exactly three names also fails (ValueError) if the
     # result does not have exactly three elements.
     train_ret, val_ret, test_ret = train.getdataloaders('test_datadir', batch_size=1)
     for r in (train_ret, val_ret, test_ret):
         self.assertIsInstance(r, self.dataloader_class)