コード例 #1
0
    def setUpClass(cls):
        """Build the models, placeholder image folders and Trainer shared by the tests.

        Loads ``tests/data/config.yml``, instantiates an RRDN generator, a
        Cut_VGG19 feature extractor and a Discriminator, records the test weight
        paths, creates four folders of empty placeholder files under
        ``tests/temporary_test_data``, and constructs a ``Trainer`` on the
        ``matching_*`` folders while ``DataHandler._check_dataset`` is patched to
        return True.
        """
        # safe_load inside a context manager: yaml.load() without an explicit
        # Loader is deprecated since PyYAML 5.1 and raises TypeError since 6.0,
        # and the previous bare open() was never closed.
        with open(os.path.join('tests', 'data', 'config.yml'), 'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=cls.setup['patch_size'])
        cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'] * 2, layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=cls.setup['patch_size'] * 2)
        cls.weights_path = {
            'generator': os.path.join(cls.setup['weights_dir'], 'test_gen_weights.hdf5'),
            'discriminator': os.path.join(cls.setup['weights_dir'], 'test_dis_weights.hdf5'),
        }
        cls.temp_data = Path('tests/temporary_test_data')

        def make_fake_dir(name, file_names):
            """Create ``temp_data/name`` holding empty files ``file_names``; return its Path."""
            folder = cls.temp_data / name
            # exist_ok: a previous run that died before cleanup must not make
            # every subsequent run fail with FileExistsError.
            folder.mkdir(parents=True, exist_ok=True)
            for file_name in file_names:
                (folder / file_name).touch()
            return folder

        cls.not_matching_hr = make_fake_dir('not_matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.not_matching_lr = make_fake_dir('not_matching_lr', ['data1.png'])
        cls.matching_hr = make_fake_dir('matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.matching_lr = make_fake_dir('matching_lr', ['data1.png', 'data0.jpeg'])

        # The patch is only active while the Trainer is being constructed.
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            cls.trainer = Trainer(
                generator=cls.RRDN,
                discriminator=cls.discr,
                feature_extractor=cls.f_ext,
                lr_train_dir=str(cls.matching_lr),
                hr_train_dir=str(cls.matching_hr),
                lr_valid_dir=str(cls.matching_lr),
                hr_valid_dir=str(cls.matching_hr),
                learning_rate={'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 5},
                log_dirs={
                    'logs': './tests/temporary_test_data/logs',
                    'weights': './tests/temporary_test_data/weights',
                },
                dataname='TEST',
                weights_generator=None,
                weights_discriminator=None,
                n_validation=2,
                flatness={'min': 0.01, 'max': 0.3, 'increase': 0.01, 'increase_frequency': 5},
                adam_optimizer={'beta1': 0.9, 'beta2': 0.999, 'epsilon': None},
                losses={'generator': 'mae', 'discriminator': 'mse', 'feature_extractor': 'mse'},
                loss_weights={'generator': 1.0, 'discriminator': 1.0, 'feature_extractor': 0.5},
            )
コード例 #2
0
    def setUpClass(cls):
        """Build the models and a Trainer whose data folders are faked via ``os.listdir``.

        Loads ``tests/data/config.yml``, instantiates an RRDN generator, a
        Cut_VGG19 feature extractor and a Discriminator, and records the test
        weight paths. The ``Trainer`` is then constructed with ``os.listdir``
        patched: the ``*_dir`` arguments are dicts of the form
        ``{'res': 'lr' | 'hr', 'matching': bool}`` that ``fake_folders`` maps to
        canned file listings. ``DataHandler._check_dataset`` is also patched to
        return True.
        """
        # safe_load inside a context manager: yaml.load() without an explicit
        # Loader is deprecated since PyYAML 5.1 and raises TypeError since 6.0,
        # and the previous bare open() was never closed.
        with open(os.path.join('tests', 'data', 'config.yml'), 'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=cls.setup['patch_size'])
        cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'] * 2, layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=cls.setup['patch_size'] * 2)
        cls.weights_path = {
            'generator': os.path.join(cls.setup['weights_dir'], 'test_gen_weights.hdf5'),
            'discriminator': os.path.join(cls.setup['weights_dir'], 'test_dis_weights.hdf5'),
        }

        # Canned directory listings keyed by (matching, res).
        fake_listings = {
            (False, 'hr'): ['data2.gif', 'data1.png', 'data0.jpeg'],
            (False, 'lr'): ['data1.png'],
            (True, 'hr'): ['data2.gif', 'data1.png', 'data0.jpeg'],
            (True, 'lr'): ['data1.png', 'data0.jpeg'],
        }

        def fake_folders(kind):
            """Stand-in for ``os.listdir``: return the canned listing described by ``kind``.

            Raises:
                ValueError: if ``kind`` does not describe a known fake folder
                    (previously a bare ``raise`` with no active exception, or a
                    silent ``None`` return).
            """
            key = (kind['matching'], kind['res'])
            if key not in fake_listings:
                raise ValueError('unexpected fake folder spec: {!r}'.format(kind))
            # Fresh list per call, like the original literals, so callers may mutate it.
            return list(fake_listings[key])

        # Both patches are only active while the Trainer is being constructed.
        with patch('os.listdir', side_effect=fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                cls.trainer = Trainer(
                    generator=cls.RRDN,
                    discriminator=cls.discr,
                    feature_extractor=cls.f_ext,
                    lr_train_dir={'res': 'lr', 'matching': True},
                    hr_train_dir={'res': 'hr', 'matching': True},
                    lr_valid_dir={'res': 'lr', 'matching': True},
                    hr_valid_dir={'res': 'hr', 'matching': True},
                    learning_rate=0.0004,
                    loss_weights={'MSE': 1.0, 'discriminator': 1.0, 'feat_extr': 1.0},
                    logs_dir='./tests/temporary_test_data/logs',
                    weights_dir='./tests/temporary_test_data/weights',
                    dataname='TEST',
                    weights_generator=None,
                    weights_discriminator=None,
                    n_validation=2,
                    lr_decay_factor=0.5,
                    lr_decay_frequency=5,
                    T=0.01,
                )
コード例 #3
0
    def setUpClass(cls):
        """Build the models, placeholder image folders and Trainer for the tests.

        Loads ``tests/data/config.yml``, instantiates an RRDN generator, a
        Cut_VGG19 feature extractor and a Discriminator, records the test
        weight paths, creates four folders of empty placeholder files under
        ``tests/temporary_test_data``, and constructs a ``Trainer`` on the
        ``matching_*`` folders while ``DataHandler._check_dataset`` is patched
        to return True.
        """
        # safe_load inside a context manager: yaml.load() without an explicit
        # Loader is deprecated since PyYAML 5.1 and raises TypeError since
        # 6.0, and the previous bare open() was never closed.
        with open(os.path.join('tests', 'data', 'config.yml'),
                  'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'],
                        patch_size=cls.setup['patch_size'])
        cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'] * 2,
                              layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=cls.setup['patch_size'] * 2)
        cls.weights_path = {
            'generator':
            os.path.join(cls.setup['weights_dir'], 'test_gen_weights.hdf5'),
            'discriminator':
            os.path.join(cls.setup['weights_dir'], 'test_dis_weights.hdf5'),
        }
        cls.temp_data = Path('tests/temporary_test_data')

        def make_fake_dir(name, file_names):
            """Create ``temp_data/name`` with empty ``file_names``; return it."""
            folder = cls.temp_data / name
            # exist_ok: a previous run that died before cleanup must not make
            # every subsequent run fail with FileExistsError.
            folder.mkdir(parents=True, exist_ok=True)
            for file_name in file_names:
                (folder / file_name).touch()
            return folder

        cls.not_matching_hr = make_fake_dir(
            'not_matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.not_matching_lr = make_fake_dir('not_matching_lr', ['data1.png'])
        cls.matching_hr = make_fake_dir(
            'matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.matching_lr = make_fake_dir('matching_lr',
                                        ['data1.png', 'data0.jpeg'])

        # The patch is only active while the Trainer is being constructed.
        with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                   return_value=True):
            cls.trainer = Trainer(
                generator=cls.RRDN,
                discriminator=cls.discr,
                feature_extractor=cls.f_ext,
                lr_train_dir=str(cls.matching_lr),
                hr_train_dir=str(cls.matching_hr),
                lr_valid_dir=str(cls.matching_lr),
                hr_valid_dir=str(cls.matching_hr),
                learning_rate=0.0004,
                loss_weights={
                    'MSE': 1.0,
                    'discriminator': 1.0,
                    'feat_extr': 1.0
                },
                logs_dir='./tests/temporary_test_data/logs',
                weights_dir='./tests/temporary_test_data/weights',
                dataname='TEST',
                weights_generator=None,
                weights_discriminator=None,
                n_validation=2,
                lr_decay_factor=0.5,
                lr_decay_frequency=5,
                T=0.01,
            )
コード例 #4
0
def run(config_file, default=False, training=False, prediction=False):
    """Set up a session from ``config_file`` and run prediction or training.

    ``setup`` resolves the session type, generator name, configuration dict
    and dataset name from the arguments. The generator model is always built
    first. For a ``'prediction'`` session a ``Predictor`` on the selected test
    set runs ``get_predictions`` with the configured generator weights. For a
    ``'training'`` session a ``Trainer`` is built (with the VGG19 feature
    extractor and/or discriminator when enabled in ``conf['default']``) and
    trained. Any other session type only logs an error. Always returns None.
    """
    # Suppress TensorFlow's C++ log output.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    logger = get_logger(__name__)
    session_type, generator, conf, dataset = setup(config_file, default, training, prediction)

    session_conf = conf['session'][session_type]
    lr_patch_size = session_conf['patch_size']
    scale = conf['generators'][generator]['x']

    model = _get_module(generator).make_model(conf['generators'][generator], lr_patch_size)

    if session_type == 'prediction':
        from ISR.predict.predictor import Predictor

        predictor = Predictor(input_dir=conf['test_sets'][dataset])
        predictor.get_predictions(model, conf['weights_paths']['generator'])
        return

    if session_type != 'training':
        logger.error('Invalid choice.')
        return

    from ISR.train.trainer import Trainer

    hr_patch_size = lr_patch_size * scale

    # Optional auxiliary networks, enabled by flags in conf['default'].
    feature_extractor = None
    if conf['default']['feat_ext']:
        from ISR.models.cut_vgg19 import Cut_VGG19

        feature_extractor = Cut_VGG19(
            patch_size=hr_patch_size,
            layers_to_extract=conf['feat_extr']['vgg19']['layers_to_extract'],
        )

    discriminator = None
    if conf['default']['discriminator']:
        from ISR.models.discriminator import Discriminator

        discriminator = Discriminator(patch_size=hr_patch_size, kernel_size=3)

    dataset_conf = conf['training_sets'][dataset]
    trainer = Trainer(
        generator=model,
        discriminator=discriminator,
        feature_extractor=feature_extractor,
        lr_train_dir=dataset_conf['lr_train_dir'],
        hr_train_dir=dataset_conf['hr_train_dir'],
        lr_valid_dir=dataset_conf['lr_valid_dir'],
        hr_valid_dir=dataset_conf['hr_valid_dir'],
        loss_weights=conf['loss_weights'],
        dataname=dataset_conf['data_name'],
        logs_dir=conf['dirs']['logs'],
        weights_dir=conf['dirs']['weights'],
        weights_generator=conf['weights_paths']['generator'],
        weights_discriminator=conf['weights_paths']['discriminator'],
        n_validation=session_conf['n_validation_samples'],
        lr_decay_frequency=session_conf['lr_decay_frequency'],
        lr_decay_factor=session_conf['lr_decay_factor'],
        T=0.01,
    )
    trainer.train(
        epochs=session_conf['epochs'],
        steps_per_epoch=session_conf['steps_per_epoch'],
        batch_size=session_conf['batch_size'],
    )