def test_setup_default_prediction(self):
        """Check what ``utils.setup`` returns for a default prediction run.

        ``yaml.load`` is patched to hand back ``base_conf``. With ``default``
        and ``prediction`` set, ``setup`` is expected to report a
        ``'prediction'`` session, the default generator, a configuration equal
        to ``base_conf`` and the default test set.
        """
        base_conf = {
            'default': {
                'generator': 'rdn',
                'feature_extractor': False,
                'discriminator': False,
                'training_set': 'div2k-x4',
                'test_set': 'dummy',
            },
            # NOTE(review): every generator parameter is left as None here;
            # presumably setup fills them in from the weights file name below
            # -- confirm against utils.setup.
            'generators': {
                'rdn': {
                    'C': None,
                    'D': None,
                    'G': None,
                    'G0': None,
                    'x': None,
                },
            },
            'weights_paths': {
                'generator': os.path.join('a', 'path', 'to', 'rdn-C3-D1-G7-G05-x2'),
            },
        }
        training = False
        prediction = True
        default = True

        with patch('yaml.load', return_value=base_conf):
            session_type, generator, conf, dataset = utils.setup(
                'tests/data/config.yml', default, training, prediction)

        # assertEqual reports both operands on failure, unlike assertTrue(a == b).
        self.assertEqual(session_type, 'prediction')
        self.assertEqual(generator, 'rdn')
        self.assertEqual(conf, base_conf)
        self.assertEqual(dataset, 'dummy')
    def test_setup_default_training(self):
        """Check what ``utils.setup`` returns for a default training run.

        ``yaml.load`` is patched to hand back ``base_conf``, which only holds
        a ``'default'`` section. With ``default`` and ``training`` set,
        ``setup`` is expected to report a ``'training'`` session, the default
        generator, a configuration equal to ``base_conf`` and the default
        training set.
        """
        base_conf = {
            'default': {
                'generator': 'rrdn',
                'feature_extractor': False,
                'discriminator': False,
                'training_set': 'div2k-x4',
                'test_set': 'dummy',
            },
        }
        training = True
        prediction = False
        default = True

        # The mock object itself is never inspected, so it is not bound to a name.
        with patch('yaml.load', return_value=base_conf):
            session_type, generator, conf, dataset = utils.setup(
                'tests/data/config.yml', default, training, prediction)

        # assertEqual reports both operands on failure, unlike assertTrue(a == b).
        self.assertEqual(session_type, 'training')
        self.assertEqual(generator, 'rrdn')
        self.assertEqual(conf, base_conf)
        self.assertEqual(dataset, 'div2k-x4')
def run(config_file, default=False, training=False, prediction=False):
    """Build the configured generator and start a prediction or training session.

    The session type, generator name, configuration mapping and dataset name
    all come from ``setup``; every other value used here is read from that
    configuration mapping.

    Args:
        config_file: forwarded unchanged to ``setup``.
        default: forwarded unchanged to ``setup``.
        training: forwarded unchanged to ``setup``.
        prediction: forwarded unchanged to ``setup``.

    Returns:
        None. A session type other than ``'prediction'`` or ``'training'`` is
        only reported through the logger.
    """
    # NOTE(review): presumably silences TensorFlow's native logging before the
    # ISR modules are imported further down -- confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    logger = get_logger(__name__)
    session_type, generator, conf, dataset = setup(config_file, default, training, prediction)

    session_conf = conf['session'][session_type]
    lr_patch_size = session_conf['patch_size']
    generator_conf = conf['generators'][generator]
    scale = generator_conf['x']

    # The generator is built up front because both session types need it.
    model = _get_module(generator).make_model(generator_conf, lr_patch_size)

    if session_type == 'prediction':
        from ISR.predict.predictor import Predictor

        predictor = Predictor(input_dir=conf['test_sets'][dataset])
        predictor.get_predictions(model, conf['weights_paths']['generator'])
        return

    if session_type != 'training':
        logger.error('Invalid choice.')
        return

    from ISR.train.trainer import Trainer

    # High-resolution patches are the low-resolution size times the generator's 'x' value.
    hr_patch_size = lr_patch_size * scale

    # Optional feature extractor, enabled through the 'default' section.
    feature_extractor = None
    if conf['default']['feat_ext']:
        from ISR.models.cut_vgg19 import Cut_VGG19

        layers = conf['feat_extr']['vgg19']['layers_to_extract']
        feature_extractor = Cut_VGG19(patch_size=hr_patch_size, layers_to_extract=layers)

    # Optional discriminator, enabled through the 'default' section.
    discriminator = None
    if conf['default']['discriminator']:
        from ISR.models.discriminator import Discriminator

        discriminator = Discriminator(patch_size=hr_patch_size, kernel_size=3)

    training_set = conf['training_sets'][dataset]
    trainer = Trainer(
        generator=model,
        discriminator=discriminator,
        feature_extractor=feature_extractor,
        lr_train_dir=training_set['lr_train_dir'],
        hr_train_dir=training_set['hr_train_dir'],
        lr_valid_dir=training_set['lr_valid_dir'],
        hr_valid_dir=training_set['hr_valid_dir'],
        loss_weights=conf['loss_weights'],
        dataname=training_set['data_name'],
        logs_dir=conf['dirs']['logs'],
        weights_dir=conf['dirs']['weights'],
        weights_generator=conf['weights_paths']['generator'],
        weights_discriminator=conf['weights_paths']['discriminator'],
        n_validation=session_conf['n_validation_samples'],
        lr_decay_frequency=session_conf['lr_decay_frequency'],
        lr_decay_factor=session_conf['lr_decay_factor'],
        T=0.01,
    )
    trainer.train(
        epochs=session_conf['epochs'],
        steps_per_epoch=session_conf['steps_per_epoch'],
        batch_size=session_conf['batch_size'],
    )