Example no. 1
0
    def setUpClass(cls):
        """Build shared test fixtures: config, RDN model, sample file trees,
        and a Predictor with a stubbed logger.

        FIX: removed the unused inner function `nullifier` (dead code).
        """
        # Silence all log output for the duration of the test run.
        logging.disable(logging.CRITICAL)
        cls.setup = yaml.load(Path('tests/data/config.yml').read_text(),
                              Loader=yaml.FullLoader)
        cls.RDN = RDN(arch_params=cls.setup['rdn'],
                      patch_size=cls.setup['patch_size'])

        # Folder containing file names the Predictor should accept.
        cls.temp_data = Path('tests/temporary_test_data')
        cls.valid_files = cls.temp_data / 'valid_files'
        cls.valid_files.mkdir(parents=True, exist_ok=True)
        for item in ['data2.gif', 'data1.png', 'data0.jpeg']:
            (cls.valid_files / item).touch()

        # Folder containing file names the Predictor should reject.
        cls.invalid_files = cls.temp_data / 'invalid_files'
        cls.invalid_files.mkdir(parents=True, exist_ok=True)
        for item in ['data2.gif', 'data.data', 'data02']:
            (cls.invalid_files / item).touch()

        cls.out_dir = cls.temp_data / 'out_dir'
        cls.predictor = Predictor(input_dir=str(cls.valid_files),
                                  output_dir=str(cls.out_dir))
        # Stub the logger so predictor logging is a no-op during tests.
        cls.predictor.logger = Mock(return_value=True)
Example no. 2
0
 def test_no_valid_images(self):
     """Predictor must raise a ValueError mentioning 'image' when the
     input folder contains no valid image files.

     FIX: replaced the try/except/else + assertTrue(False) pattern with
     the idiomatic assertRaises context manager, and assertTrue(... in ...)
     with assertIn for a clearer failure message.
     """
     with self.assertRaises(ValueError) as ctx:
         Predictor(input_dir=str(self.invalid_files),
                   output_dir=str(self.out_dir))
     self.assertIn('image', str(ctx.exception))
Example no. 3
0
    def test_no_valid_images(self):
        """Predictor must raise a ValueError mentioning 'image' when the
        input folder holds no supported image files.

        BUG FIX: the original assigned to `cls.predictor`, but `cls` is
        undefined inside an instance method, so the try block raised
        NameError instead of exercising the ValueError path. Also switched
        to the idiomatic assertRaises context manager.
        """
        def invalid_folder(kind):
            # Stand-in for os.listdir: only unsupported names/extensions.
            return ['data2.gif', 'data1.extension', 'data0']

        with patch('os.listdir', side_effect=invalid_folder):
            with patch('os.mkdir', return_value=True):
                with self.assertRaises(ValueError) as ctx:
                    Predictor(input_dir='lr', output_dir='sr')
                self.assertIn('image', str(ctx.exception))
Example no. 4
0
    def setUpClass(cls):
        """Build shared test fixtures: config, RDN model, and a Predictor
        constructed against patched filesystem calls.

        FIXES:
        - `yaml.load(open(...))` leaked the file handle and omitted the
          `Loader` argument, which is deprecated (unsafe on untrusted
          input) and a TypeError on PyYAML >= 6. Now uses a context
          manager and an explicit FullLoader.
        - Removed the unused inner function `nullifier` (dead code).
        """
        # Silence all log output for the duration of the test run.
        logging.disable(logging.CRITICAL)
        with open(os.path.join('tests', 'data', 'config.yml'), 'r') as f:
            cls.setup = yaml.load(f, Loader=yaml.FullLoader)
        cls.RDN = RDN(arch_params=cls.setup['rdn'],
                      patch_size=cls.setup['patch_size'])

        def fake_folders(kind):
            # Stand-in for os.listdir: only supported image extensions.
            return ['data2.gif', 'data1.png', 'data0.jpeg']

        with patch('os.listdir', side_effect=fake_folders):
            with patch('os.mkdir', return_value=True):
                cls.predictor = Predictor(input_dir='dataname',
                                          output_dir='out_dir')
                # Stub the logger so predictor logging is a no-op in tests.
                cls.predictor.logger = Mock(return_value=True)
def run(config_file, default=False, training=False, prediction=False):
    """Entry point: execute a prediction or training session from a config.

    Args:
        config_file: path to the YAML session configuration.
        default: if True, accept default config values without prompting.
        training: if True, force a training session.
        prediction: if True, force a prediction session.
    """
    # Silence TensorFlow's C++ logging before the models are built.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    logger = get_logger(__name__)
    session_type, generator, conf, dataset = setup(config_file, default, training, prediction)

    patch_size = conf['session'][session_type]['patch_size']
    upscale_factor = conf['generators'][generator]['x']

    # Build the generator network described by the configuration.
    model_module = _get_module(generator)
    generator_net = model_module.make_model(conf['generators'][generator], patch_size)

    if session_type == 'prediction':
        # Imported lazily so training-only runs avoid the dependency.
        from ISR.predict.predictor import Predictor

        predictor = Predictor(input_dir=conf['test_sets'][dataset])
        predictor.get_predictions(generator_net, conf['weights_paths']['generator'])
    elif session_type == 'training':
        from ISR.train.trainer import Trainer

        hr_patch_size = patch_size * upscale_factor

        # Optional VGG19 feature extractor for perceptual loss.
        f_ext = None
        if conf['default']['feat_ext']:
            from ISR.models.cut_vgg19 import Cut_VGG19

            out_layers = conf['feat_extr']['vgg19']['layers_to_extract']
            f_ext = Cut_VGG19(patch_size=hr_patch_size, layers_to_extract=out_layers)

        # Optional discriminator for adversarial (GAN) training.
        discr = None
        if conf['default']['discriminator']:
            from ISR.models.discriminator import Discriminator

            discr = Discriminator(patch_size=hr_patch_size, kernel_size=3)

        session_conf = conf['session'][session_type]
        training_set = conf['training_sets'][dataset]
        trainer = Trainer(
            generator=generator_net,
            discriminator=discr,
            feature_extractor=f_ext,
            lr_train_dir=training_set['lr_train_dir'],
            hr_train_dir=training_set['hr_train_dir'],
            lr_valid_dir=training_set['lr_valid_dir'],
            hr_valid_dir=training_set['hr_valid_dir'],
            loss_weights=conf['loss_weights'],
            dataname=training_set['data_name'],
            logs_dir=conf['dirs']['logs'],
            weights_dir=conf['dirs']['weights'],
            weights_generator=conf['weights_paths']['generator'],
            weights_discriminator=conf['weights_paths']['discriminator'],
            n_validation=session_conf['n_validation_samples'],
            lr_decay_frequency=session_conf['lr_decay_frequency'],
            lr_decay_factor=session_conf['lr_decay_factor'],
            T=0.01,
        )
        trainer.train(
            epochs=session_conf['epochs'],
            steps_per_epoch=session_conf['steps_per_epoch'],
            batch_size=session_conf['batch_size'],
        )
    else:
        logger.error('Invalid choice.')