def test_get_validation_batches_valid_request(self):
    """get_validation_batches returns one batch dict per validation sample.

    With ``n_validation_samples`` validation images and a given ``batch_size``,
    the result must be a list of that many dicts.  Each dict holds an 'hr'
    array of shape ``(batch_size, patch_size * scale, patch_size * scale, 3)``
    and an 'lr' array of shape ``(batch_size, patch_size, patch_size, 3)``.
    """
    patch_size = 3
    scale = 2
    batch_size = 12
    n_validation_samples = 2

    # Dataset consistency checks and directory listing are faked so that the
    # handler can be built without touching the filesystem.
    with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
        with patch('os.listdir', side_effect=self.fake_folders):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': True},
                hr_dir={'res': 'hr', 'matching': True},
                patch_size=patch_size,
                scale=scale,
                n_validation_samples=n_validation_samples,
            )

    # Image reads and path joins are faked so batches are built from the
    # in-memory images provided by the test fixtures.
    with patch('imageio.imread', side_effect=self.image_getter):
        with patch('os.path.join', side_effect=self.path_giver):
            batch = DH.get_validation_batches(batch_size=batch_size)

    # assertEqual / assertIsInstance instead of `len(...) is 2` and
    # `type(...) is list`: `is` on int literals depends on interpreter
    # caching (SyntaxWarning since Python 3.8), and assertTrue gives no
    # useful diagnostics on failure.
    self.assertIsInstance(batch, list)
    self.assertEqual(len(batch), n_validation_samples)

    expected_hr_shape = (batch_size, patch_size * scale, patch_size * scale, 3)
    expected_lr_shape = (batch_size, patch_size, patch_size, 3)
    for item in batch:
        self.assertIsInstance(item, dict)
        self.assertEqual(item['hr'].shape, expected_hr_shape)
        self.assertEqual(item['lr'].shape, expected_lr_shape)
def test_get_validation_batches_invalid_number_of_samples(self):
    """get_validation_batches must raise when no validation set size is set.

    The handler is built with ``n_validation_samples=None``; requesting
    validation batches from it is expected to fail with an exception.
    """
    patch_size = 3

    # Directory listing and dataset checks are faked so that the handler can
    # be built without touching the filesystem.
    with patch('os.listdir', side_effect=self.fake_folders):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': True},
                hr_dir={'res': 'hr', 'matching': True},
                patch_size=patch_size,
                scale=2,
                n_validation_samples=None,
            )

    with patch('imageio.imread', side_effect=self.image_getter):
        with patch('os.path.join', side_effect=self.path_giver):
            # assertRaises scopes the expectation to this single call.  A bare
            # try/except around it would also "pass" on unrelated errors (for
            # example a malformed patch target such as patch('raise', None),
            # which raises TypeError before the call is ever made).
            # NOTE(review): narrow Exception to the concrete type
            # get_validation_batches raises (likely ValueError) — confirm
            # against DataHandler.
            with self.assertRaises(Exception):
                DH.get_validation_batches(batch_size=5)