Пример #1
0
 def setUpClass(cls):
     """Build the shared fixtures for the TrainerHelper tests.

     Loads ``tests/data/config.yml`` and uses it to create:

     - an RRDN generator,
     - a Cut_VGG19 feature extractor (layers 1 and 2),
     - a Discriminator with the configured patch size,
     - the paths of the test generator/discriminator weight files,
     - a TrainerHelper wired to all of the above, with a fixed session id
       and a quiet logger.
     """
     # safe_load: yaml.load() without an explicit Loader is deprecated since
     # PyYAML 5.1 and raises TypeError on PyYAML >= 6.0.
     cls.setup = yaml.safe_load(Path('./tests/data/config.yml').read_text())
     patch_size = cls.setup['patch_size']
     weights_dir = Path(cls.setup['weights_dir'])

     cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=patch_size)
     cls.f_ext = Cut_VGG19(patch_size=patch_size, layers_to_extract=[1, 2])
     cls.discr = Discriminator(patch_size=patch_size)
     cls.weights_path = {
         'generator': weights_dir / 'test_gen_weights.hdf5',
         'discriminator': weights_dir / 'test_dis_weights.hdf5',
     }
     cls.TH = TrainerHelper(
         generator=cls.RRDN,
         weights_dir=cls.setup['weights_dir'],
         logs_dir=cls.setup['log_dir'],
         lr_train_dir=cls.setup['lr_input'],
         feature_extractor=cls.f_ext,
         discriminator=cls.discr,
         dataname='TEST',
         # Empty strings: start without pretrained weights.
         weights_generator='',
         weights_discriminator='',
         fallback_save_every_n_epochs=2,
     )
     # Fixed session id so generated paths are predictable in assertions.
     cls.TH.session_id = '0000'
     # 50 == logging.CRITICAL (assuming TH.logger is a stdlib logging.Logger):
     # keep test output quiet.
     cls.TH.logger.setLevel(50)
 def setUpClass(cls):
     """Build the shared fixtures for the TrainerHelper tests.

     Loads ``tests/data/config.yml`` and uses it to create:

     - an RRDN generator,
     - a Cut_VGG19 feature extractor (layers 1 and 2),
     - a Discriminator with the configured patch size,
     - the paths of the test generator/discriminator weight files,
     - a TrainerHelper with no pretrained weights.
     """
     config_path = os.path.join('tests', 'data', 'config.yml')
     # Close the config file deterministically instead of leaking the handle.
     # safe_load: yaml.load() without an explicit Loader is deprecated since
     # PyYAML 5.1 and raises TypeError on PyYAML >= 6.0.
     with open(config_path, 'r') as config_file:
         cls.setup = yaml.safe_load(config_file)
     patch_size = cls.setup['patch_size']
     weights_dir = cls.setup['weights_dir']

     cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=patch_size)
     cls.f_ext = Cut_VGG19(patch_size=patch_size, layers_to_extract=[1, 2])
     cls.discr = Discriminator(patch_size=patch_size)
     cls.weights_path = {
         'generator': os.path.join(weights_dir, 'test_gen_weights.hdf5'),
         'discriminator': os.path.join(weights_dir, 'test_dis_weights.hdf5'),
     }
     cls.TH = TrainerHelper(
         generator=cls.RRDN,
         weights_dir=weights_dir,
         logs_dir=cls.setup['log_dir'],
         lr_train_dir=cls.setup['lr_input'],
         feature_extractor=cls.f_ext,
         discriminator=cls.discr,
         dataname='TEST',
         # Empty mapping: no pretrained weights are loaded.
         pretrained_weights_path={},
         fallback_save_every_n_epochs=2,
     )
Пример #3
0
    def setUpClass(cls):
        """Build shared fixtures for the Trainer tests on real temporary folders.

        Loads ``tests/data/config.yml`` and instantiates the generator, the
        feature extractor and the discriminator. The last two get twice the
        configured patch size.

        It then creates four folders under ``tests/temporary_test_data``,
        filled with empty placeholder image files:

        - ``not_matching_hr`` / ``not_matching_lr``
        - ``matching_hr`` / ``matching_lr``

        Finally it builds a Trainer on the "matching" folders, with
        ``DataHandler._check_dataset`` patched to always succeed.
        """
        config_path = os.path.join('tests', 'data', 'config.yml')
        # Close the config file deterministically instead of leaking the handle.
        # safe_load: yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError on PyYAML >= 6.0.
        with open(config_path, 'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        patch_size = cls.setup['patch_size']
        weights_dir = cls.setup['weights_dir']

        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=patch_size)
        cls.f_ext = Cut_VGG19(patch_size=patch_size * 2, layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=patch_size * 2)
        cls.weights_path = {
            'generator': os.path.join(weights_dir, 'test_gen_weights.hdf5'),
            'discriminator': os.path.join(weights_dir, 'test_dis_weights.hdf5'),
        }
        cls.temp_data = Path('tests/temporary_test_data')

        def make_folder(name, filenames):
            # Create temp_data/<name> and populate it with empty files.
            # exist_ok: unittest does not run tearDownClass when setUpClass
            # raises, so a crashed run can leave these folders behind. Without
            # it, every later run would fail here with FileExistsError.
            folder = cls.temp_data / name
            folder.mkdir(parents=True, exist_ok=True)
            for filename in filenames:
                (folder / filename).touch()
            return folder

        cls.not_matching_hr = make_folder(
            'not_matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.not_matching_lr = make_folder('not_matching_lr', ['data1.png'])
        cls.matching_hr = make_folder(
            'matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.matching_lr = make_folder(
            'matching_lr', ['data1.png', 'data0.jpeg'])

        # The placeholder files are empty, so dataset validation is bypassed.
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            cls.trainer = Trainer(
                generator=cls.RRDN,
                discriminator=cls.discr,
                feature_extractor=cls.f_ext,
                lr_train_dir=str(cls.matching_lr),
                hr_train_dir=str(cls.matching_hr),
                lr_valid_dir=str(cls.matching_lr),
                hr_valid_dir=str(cls.matching_hr),
                learning_rate={'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 5},
                log_dirs={
                    'logs': './tests/temporary_test_data/logs',
                    'weights': './tests/temporary_test_data/weights',
                },
                dataname='TEST',
                weights_generator=None,
                weights_discriminator=None,
                n_validation=2,
                flatness={'min': 0.01, 'max': 0.3, 'increase': 0.01, 'increase_frequency': 5},
                adam_optimizer={'beta1': 0.9, 'beta2': 0.999, 'epsilon': None},
                losses={'generator': 'mae', 'discriminator': 'mse', 'feature_extractor': 'mse'},
                loss_weights={'generator': 1.0, 'discriminator': 1.0, 'feature_extractor': 0.5},
            )
    def setUpClass(cls):
        """Build shared fixtures for the Trainer tests using a faked ``os.listdir``.

        No real folders are created. ``os.listdir`` is patched so that the
        "directory" arguments given to Trainer (dicts with ``'res'`` and
        ``'matching'`` keys) resolve to canned file listings.
        ``DataHandler._check_dataset`` is patched to always return True.
        """
        config_path = os.path.join('tests', 'data', 'config.yml')
        # Close the config file deterministically instead of leaking the handle.
        # safe_load: yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError on PyYAML >= 6.0.
        with open(config_path, 'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        patch_size = cls.setup['patch_size']
        weights_dir = cls.setup['weights_dir']

        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=patch_size)
        cls.f_ext = Cut_VGG19(patch_size=patch_size * 2, layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=patch_size * 2)
        cls.weights_path = {
            'generator': os.path.join(weights_dir, 'test_gen_weights.hdf5'),
            'discriminator': os.path.join(weights_dir, 'test_dis_weights.hdf5'),
        }

        # Canned directory listings keyed by (matching flag, resolution).
        fake_listings = {
            (False, 'hr'): ['data2.gif', 'data1.png', 'data0.jpeg'],
            (False, 'lr'): ['data1.png'],
            (True, 'hr'): ['data2.gif', 'data1.png', 'data0.jpeg'],
            (True, 'lr'): ['data1.png', 'data0.jpeg'],
        }

        def fake_folders(kind):
            # Stand-in for os.listdir. `kind` is one of the descriptor dicts
            # passed as a directory below.
            key = (kind['matching'], kind['res'])
            if key not in fake_listings:
                # Raise an explicit error. A bare `raise` outside an except
                # block would only give "RuntimeError: No active exception to
                # reraise" and hide the bad descriptor.
                raise ValueError('Unexpected fake folder descriptor: {!r}'.format(kind))
            # Fresh copy so callers may sort/mutate the listing safely.
            return list(fake_listings[key])

        with patch('os.listdir', side_effect=fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                cls.trainer = Trainer(
                    generator=cls.RRDN,
                    discriminator=cls.discr,
                    feature_extractor=cls.f_ext,
                    lr_train_dir={'res': 'lr', 'matching': True},
                    hr_train_dir={'res': 'hr', 'matching': True},
                    lr_valid_dir={'res': 'lr', 'matching': True},
                    hr_valid_dir={'res': 'hr', 'matching': True},
                    learning_rate=0.0004,
                    loss_weights={'MSE': 1.0, 'discriminator': 1.0, 'feat_extr': 1.0},
                    logs_dir='./tests/temporary_test_data/logs',
                    weights_dir='./tests/temporary_test_data/weights',
                    dataname='TEST',
                    weights_generator=None,
                    weights_discriminator=None,
                    n_validation=2,
                    lr_decay_factor=0.5,
                    lr_decay_frequency=5,
                    T=0.01,
                )
Пример #5
0
    def setUpClass(cls):
        """Build and compile the shared models for the tests.

        Loads ``tests/data/config.yml``, records the test weight paths and the
        HR patch shape, then instantiates RRDN, RDN, Cut_VGG19 (layers 1 and 2)
        and Discriminator. Each model is compiled with an Adam optimizer and
        MSE loss(es).
        """
        config_path = os.path.join('tests', 'data', 'config.yml')
        # Close the config file deterministically instead of leaking the handle.
        # safe_load: yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError on PyYAML >= 6.0.
        with open(config_path, 'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        patch_size = cls.setup['patch_size']
        weights_dir = cls.setup['weights_dir']

        cls.weights_path = {
            'generator': os.path.join(weights_dir, 'test_gen_weights.hdf5'),
            'discriminator': os.path.join(weights_dir, 'test_dis_weights.hdf5'),
        }
        # (height, width, channels): twice the configured patch size per side,
        # 3 channels.
        cls.hr_shape = (patch_size * 2,) * 2 + (3,)

        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=patch_size)
        cls.RRDN.model.compile(optimizer=Adam(), loss=['mse'])
        cls.RDN = RDN(arch_params=cls.setup['rdn'], patch_size=patch_size)
        cls.RDN.model.compile(optimizer=Adam(), loss=['mse'])
        cls.f_ext = Cut_VGG19(patch_size=patch_size * 2, layers_to_extract=[1, 2])
        # Two losses, presumably one per extracted layer output.
        cls.f_ext.model.compile(optimizer=Adam(), loss=['mse', 'mse'])
        cls.discr = Discriminator(patch_size=patch_size * 2)
        cls.discr.model.compile(optimizer=Adam(), loss=['mse'])
Пример #6
0
    def setUpClass(cls):
        """Build shared fixtures for the Trainer tests on real temporary folders.

        Loads ``tests/data/config.yml`` and instantiates the generator, the
        feature extractor and the discriminator. The last two get twice the
        configured patch size.

        It then creates four folders under ``tests/temporary_test_data``,
        filled with empty placeholder image files:

        - ``not_matching_hr`` / ``not_matching_lr``
        - ``matching_hr`` / ``matching_lr``

        Finally it builds a Trainer on the "matching" folders, with
        ``DataHandler._check_dataset`` patched to always succeed.
        """
        config_path = os.path.join('tests', 'data', 'config.yml')
        # Close the config file deterministically instead of leaking the handle.
        # safe_load: yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError on PyYAML >= 6.0.
        with open(config_path, 'r') as config_file:
            cls.setup = yaml.safe_load(config_file)
        patch_size = cls.setup['patch_size']
        weights_dir = cls.setup['weights_dir']

        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=patch_size)
        cls.f_ext = Cut_VGG19(patch_size=patch_size * 2, layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=patch_size * 2)
        cls.weights_path = {
            'generator': os.path.join(weights_dir, 'test_gen_weights.hdf5'),
            'discriminator': os.path.join(weights_dir, 'test_dis_weights.hdf5'),
        }
        cls.temp_data = Path('tests/temporary_test_data')

        def make_folder(name, filenames):
            # Create temp_data/<name> and populate it with empty files.
            # exist_ok: unittest does not run tearDownClass when setUpClass
            # raises, so a crashed run can leave these folders behind. Without
            # it, every later run would fail here with FileExistsError.
            folder = cls.temp_data / name
            folder.mkdir(parents=True, exist_ok=True)
            for filename in filenames:
                (folder / filename).touch()
            return folder

        cls.not_matching_hr = make_folder(
            'not_matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.not_matching_lr = make_folder('not_matching_lr', ['data1.png'])
        cls.matching_hr = make_folder(
            'matching_hr', ['data2.gif', 'data1.png', 'data0.jpeg'])
        cls.matching_lr = make_folder(
            'matching_lr', ['data1.png', 'data0.jpeg'])

        # The placeholder files are empty, so dataset validation is bypassed.
        with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                   return_value=True):
            cls.trainer = Trainer(
                generator=cls.RRDN,
                discriminator=cls.discr,
                feature_extractor=cls.f_ext,
                lr_train_dir=str(cls.matching_lr),
                hr_train_dir=str(cls.matching_hr),
                lr_valid_dir=str(cls.matching_lr),
                hr_valid_dir=str(cls.matching_hr),
                learning_rate=0.0004,
                loss_weights={
                    'MSE': 1.0,
                    'discriminator': 1.0,
                    'feat_extr': 1.0
                },
                logs_dir='./tests/temporary_test_data/logs',
                weights_dir='./tests/temporary_test_data/weights',
                dataname='TEST',
                weights_generator=None,
                weights_discriminator=None,
                n_validation=2,
                lr_decay_factor=0.5,
                lr_decay_frequency=5,
                T=0.01,
            )