def setUpClass(cls):
    """Build the fixtures shared by every test in this class.

    Disables logging up to CRITICAL, loads ``tests/data/config.yml``,
    instantiates an ``RDN`` from the loaded settings, creates two folders of
    empty placeholder files under ``tests/temporary_test_data`` and builds a
    ``Predictor`` on the ``valid_files`` folder. The predictor's logger is
    then replaced by a ``Mock``.
    """
    logging.disable(logging.CRITICAL)

    config_text = Path('tests/data/config.yml').read_text()
    cls.setup = yaml.load(config_text, Loader=yaml.FullLoader)
    cls.RDN = RDN(arch_params=cls.setup['rdn'], patch_size=cls.setup['patch_size'])

    cls.temp_data = Path('tests/temporary_test_data')

    # Folder handed to the Predictor built at the end of this fixture.
    cls.valid_files = cls.temp_data / 'valid_files'
    cls.valid_files.mkdir(parents=True, exist_ok=True)
    for item in ['data2.gif', 'data1.png', 'data0.jpeg']:
        (cls.valid_files / item).touch()

    # Second folder, kept on the class for tests that need a different listing.
    cls.invalid_files = cls.temp_data / 'invalid_files'
    cls.invalid_files.mkdir(parents=True, exist_ok=True)
    for item in ['data2.gif', 'data.data', 'data02']:
        (cls.invalid_files / item).touch()

    cls.out_dir = cls.temp_data / 'out_dir'
    cls.predictor = Predictor(input_dir=str(cls.valid_files), output_dir=str(cls.out_dir))
    cls.predictor.logger = Mock(return_value=True)
def test_no_valid_images(self):
    """Check that Predictor rejects ``self.invalid_files``.

    Constructing a ``Predictor`` on that folder must raise a ``ValueError``
    whose message contains the word 'image'.
    """
    with self.assertRaises(ValueError) as raised:
        Predictor(input_dir=str(self.invalid_files), output_dir=str(self.out_dir))
    self.assertIn('image', str(raised.exception))
def test_no_valid_images(self):
    """Check that Predictor rejects the patched directory listing below.

    ``os.listdir`` is patched to return a fixed list of file names and
    ``os.mkdir`` is patched to a mock returning ``True``. Constructing a
    ``Predictor`` under those patches must raise a ``ValueError`` whose
    message contains the word 'image'.
    """
    def invalid_folder(kind):
        return ['data2.gif', 'data1.extension', 'data0']

    with patch('os.listdir', side_effect=invalid_folder), \
            patch('os.mkdir', return_value=True):
        # The instance is not kept: only the raised error matters here, and
        # there is no `cls` in an instance method to attach it to.
        with self.assertRaises(ValueError) as raised:
            Predictor(input_dir='lr', output_dir='sr')
    self.assertIn('image', str(raised.exception))
def setUpClass(cls):
    """Build the fixtures shared by every test in this class.

    Disables logging up to CRITICAL, loads ``tests/data/config.yml``,
    instantiates an ``RDN`` from the loaded settings and builds a
    ``Predictor`` while ``os.listdir`` and ``os.mkdir`` are patched. The
    predictor's logger is then replaced by a ``Mock``.
    """
    logging.disable(logging.CRITICAL)

    config_path = os.path.join('tests', 'data', 'config.yml')
    # Close the handle deterministically and pass an explicit Loader:
    # calling yaml.load without one is deprecated in PyYAML 5.1+ and a
    # TypeError in PyYAML 6.0+.
    with open(config_path, 'r') as config_file:
        cls.setup = yaml.load(config_file, Loader=yaml.FullLoader)
    cls.RDN = RDN(arch_params=cls.setup['rdn'], patch_size=cls.setup['patch_size'])

    def fake_folders(kind):
        return ['data2.gif', 'data1.png', 'data0.jpeg']

    # While patched, os.listdir returns the fixed listing above and os.mkdir
    # is a mock returning True.
    with patch('os.listdir', side_effect=fake_folders), \
            patch('os.mkdir', return_value=True):
        cls.predictor = Predictor(input_dir='dataname', output_dir='out_dir')
        cls.predictor.logger = Mock(return_value=True)
def run(config_file, default=False, training=False, prediction=False):
    """Run the session selected by ``setup`` from a configuration file.

    All four arguments are forwarded unchanged to ``setup``, which returns the
    session type, the generator name, the configuration mapping and the
    dataset name. A generator model is built from that configuration, then:

    * for a ``'prediction'`` session a ``Predictor`` is created on the
      configured test set and asked for predictions with the generator and
      its configured weights path;
    * for a ``'training'`` session an optional feature extractor and an
      optional discriminator are created according to ``conf['default']``,
      and a ``Trainer`` is built and trained with the session settings;
    * for any other session type an error is logged.

    The function also sets ``TF_CPP_MIN_LOG_LEVEL`` to ``'3'`` in the process
    environment before doing anything else.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    logger = get_logger(__name__)
    session_type, generator, conf, dataset = setup(config_file, default, training, prediction)

    session_conf = conf['session'][session_type]
    lr_patch_size = session_conf['patch_size']
    generator_conf = conf['generators'][generator]
    scale = generator_conf['x']

    model_module = _get_module(generator)
    gen = model_module.make_model(generator_conf, lr_patch_size)

    if session_type == 'prediction':
        from ISR.predict.predictor import Predictor

        predictor = Predictor(input_dir=conf['test_sets'][dataset])
        predictor.get_predictions(gen, conf['weights_paths']['generator'])
        return

    if session_type != 'training':
        logger.error('Invalid choice.')
        return

    from ISR.train.trainer import Trainer

    # The high-resolution patch is the low-resolution one scaled up by the
    # generator's 'x' setting.
    hr_patch_size = lr_patch_size * scale

    feature_extractor = None
    if conf['default']['feat_ext']:
        from ISR.models.cut_vgg19 import Cut_VGG19

        layers = conf['feat_extr']['vgg19']['layers_to_extract']
        feature_extractor = Cut_VGG19(patch_size=hr_patch_size, layers_to_extract=layers)

    discriminator = None
    if conf['default']['discriminator']:
        from ISR.models.discriminator import Discriminator

        discriminator = Discriminator(patch_size=hr_patch_size, kernel_size=3)

    train_set = conf['training_sets'][dataset]
    trainer = Trainer(
        generator=gen,
        discriminator=discriminator,
        feature_extractor=feature_extractor,
        lr_train_dir=train_set['lr_train_dir'],
        hr_train_dir=train_set['hr_train_dir'],
        lr_valid_dir=train_set['lr_valid_dir'],
        hr_valid_dir=train_set['hr_valid_dir'],
        loss_weights=conf['loss_weights'],
        dataname=train_set['data_name'],
        logs_dir=conf['dirs']['logs'],
        weights_dir=conf['dirs']['weights'],
        weights_generator=conf['weights_paths']['generator'],
        weights_discriminator=conf['weights_paths']['discriminator'],
        n_validation=session_conf['n_validation_samples'],
        lr_decay_frequency=session_conf['lr_decay_frequency'],
        lr_decay_factor=session_conf['lr_decay_factor'],
        T=0.01,
    )
    trainer.train(
        epochs=session_conf['epochs'],
        steps_per_epoch=session_conf['steps_per_epoch'],
        batch_size=session_conf['batch_size'],
    )