def test_getdataloaders_dataloader_constructor_dataset_arg_is_correct_type(self):
    """Each DataLoader must be constructed with an ImageFolder dataset.

    The dataset is read from the first positional argument, or from the
    ``dataset`` keyword if the code under test passes it by name.
    """
    train.getdataloaders('test_datadir', batch_size=1)

    calls = self.mock_dataloader.call_args_list
    # Guard against a vacuous pass: with no DataLoader calls the loop below
    # would check nothing.
    self.assertTrue(calls, 'DataLoader was never constructed')
    for args, kwargs in calls:
        dataset = args[0] if args else kwargs.get('dataset')
        self.assertIsInstance(dataset, self.imagefolder_class)
def test_imagefolder_is_instantiated_with_tranforms(self):
    """Every ImageFolder must receive a ``transform`` keyword that is a Compose."""
    train.getdataloaders(self.test_datadir, batch_size=self.test_batch_size)

    calls = self.imagefolder_mock.call_args_list
    # Guard against a vacuous pass when ImageFolder is never instantiated.
    self.assertTrue(calls, 'ImageFolder was never instantiated')
    for _args, kwargs in calls:
        self.assertIn('transform', kwargs)
        # isinstance (rather than an exact type match) also accepts
        # subclasses of Compose.
        self.assertIsInstance(
            kwargs['transform'], torchvision.transforms.Compose
        )
def test_getdataloaders_dataloader_constructor_dataset_shuffle_arg(self):
    """At least one DataLoader must be constructed with a truthy ``shuffle``."""
    shuffle_flags = []

    def _record_shuffle(*args, **kwargs):
        # A missing or falsy ``shuffle`` keyword counts as "not shuffled".
        shuffle_flags.append(bool(kwargs.get('shuffle')))

    self.mock_dataloader.side_effect = _record_shuffle
    train.getdataloaders('test_datadir', batch_size=1)
    self.assertTrue(any(shuffle_flags))
def test_imagefolder_is_instantiated_checking_path_arg(self):
    """ImageFolder must be built exactly once for each of test/valid/train."""
    train.getdataloaders(self.test_datadir, batch_size=self.test_batch_size)

    observed_paths = [
        # Accept the directory positionally or as the ``root`` keyword.
        args[0] if args else kwargs.get('root')
        for args, kwargs in self.imagefolder_mock.call_args_list
    ]
    expected_paths = [
        self.test_datadir + '/' + split for split in ('test', 'valid', 'train')
    ]
    # assertCountEqual ignores order but, unlike comparing sets, also fails
    # if one split is loaded twice while another is skipped.
    self.assertCountEqual(expected_paths, observed_paths)
def test_getdataloaders_dataloader_constructor_dataset_num_workers(self):
    """Every DataLoader must use ``train._DATALOADER_NUM_WORKERS`` workers."""
    train.getdataloaders('test_datadir', batch_size=1)

    self.assertEqual(self.mock_dataloader.call_count, 3)
    for index, (_args, kwargs) in enumerate(self.mock_dataloader.call_args_list):
        # subTest pinpoints which constructor call was wrong and shows the
        # offending value, instead of a bare "False is not true".
        with self.subTest(call=index):
            self.assertIn('num_workers', kwargs)
            self.assertEqual(
                kwargs['num_workers'], train._DATALOADER_NUM_WORKERS
            )
def test_getdataloaders_dataloader_constructor_batch_size_arg(self):
    """The caller's ``batch_size`` must be forwarded to every DataLoader."""
    # A sentinel has no custom equality, so matching it proves the exact
    # object passed in was forwarded, not a default or a copy.
    test_batch_size = unittest.mock.sentinel.batch_size
    train.getdataloaders('test_datadir', batch_size=test_batch_size)

    self.assertEqual(self.mock_dataloader.call_count, 3)
    for index, (_args, kwargs) in enumerate(self.mock_dataloader.call_args_list):
        with self.subTest(call=index):
            self.assertIn('batch_size', kwargs)
            self.assertIs(kwargs['batch_size'], test_batch_size)
def test_imagefolder_is_instantiated(self):
    """getdataloaders must construct exactly three ImageFolder datasets."""
    expected_calls = 3
    train.getdataloaders(
        self.test_datadir,
        batch_size=self.test_batch_size,
    )
    self.assertEqual(self.imagefolder_mock.call_count, expected_calls)
def test_getdataloaders_retun_type(self):
    """getdataloaders must return three DataLoader instances."""
    # The unpacking itself fails the test unless exactly three values come back.
    train_ret, val_ret, test_ret = train.getdataloaders(
        'test_datadir', batch_size=1
    )
    labelled = (('train', train_ret), ('val', val_ret), ('test', test_ret))
    for label, loader in labelled:
        # subTest labels which returned value has the wrong type.
        with self.subTest(loader=label):
            self.assertIsInstance(loader, self.dataloader_class)