def setUpClass(cls):
     """Create class-level fixtures from ``tests/data/config.yml``.

     Sets on ``cls``:
       * ``setup`` -- the parsed YAML config;
       * ``RRDN``, ``f_ext``, ``discr`` -- an RRDN generator, a Cut_VGG19
         feature extractor (layers 1 and 2) and a Discriminator, all built
         with the config's ``patch_size``;
       * ``weights_path`` -- paths of the test generator / discriminator
         weight files inside the config's ``weights_dir``;
       * ``TH`` -- a TrainerHelper wired to the objects above.
     """
     # Use a context manager so the config file handle is always closed, and
     # ``safe_load`` because ``yaml.load`` without an explicit ``Loader``
     # warns on PyYAML 5.x and raises TypeError on PyYAML >= 6.
     config_path = os.path.join('tests', 'data', 'config.yml')
     with open(config_path, 'r') as config_file:
         cls.setup = yaml.safe_load(config_file)
     cls.RRDN = RRDN(arch_params=cls.setup['rrdn'],
                     patch_size=cls.setup['patch_size'])
     cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'],
                           layers_to_extract=[1, 2])
     cls.discr = Discriminator(patch_size=cls.setup['patch_size'])
     cls.weights_path = {
         'generator':
         os.path.join(cls.setup['weights_dir'], 'test_gen_weights.hdf5'),
         'discriminator':
         os.path.join(cls.setup['weights_dir'], 'test_dis_weights.hdf5'),
     }
     cls.TH = TrainerHelper(
         generator=cls.RRDN,
         weights_dir=cls.setup['weights_dir'],
         logs_dir=cls.setup['log_dir'],
         lr_train_dir=cls.setup['lr_input'],
         feature_extractor=cls.f_ext,
         discriminator=cls.discr,
         dataname='TEST',
         pretrained_weights_path={},
         fallback_save_every_n_epochs=2,
     )
# Example no. 2
0
 def setUpClass(cls):
     """Create class-level fixtures from ``tests/data/config.yml``.

     Sets on ``cls``:
       * ``setup`` -- the parsed YAML config;
       * ``RRDN``, ``f_ext``, ``discr`` -- an RRDN generator, a Cut_VGG19
         feature extractor (layers 1 and 2) and a Discriminator, all built
         with the config's ``patch_size``;
       * ``weights_path`` -- ``Path`` objects for the test generator /
         discriminator weight files inside the config's ``weights_dir``;
       * ``TH`` -- a TrainerHelper wired to the objects above, with
         ``session_id`` fixed to ``'0000'`` and its logger level set to 50.
     """
     # ``yaml.load`` without an explicit ``Loader`` warns on PyYAML 5.x and
     # raises TypeError on PyYAML >= 6; the test config is plain data, so
     # ``safe_load`` is sufficient.
     cls.setup = yaml.safe_load(Path('./tests/data/config.yml').read_text())
     cls.RRDN = RRDN(arch_params=cls.setup['rrdn'],
                     patch_size=cls.setup['patch_size'])
     cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'],
                           layers_to_extract=[1, 2])
     cls.discr = Discriminator(patch_size=cls.setup['patch_size'])
     cls.weights_path = {
         'generator':
         Path(cls.setup['weights_dir']) / 'test_gen_weights.hdf5',
         'discriminator':
         Path(cls.setup['weights_dir']) / 'test_dis_weights.hdf5',
     }
     cls.TH = TrainerHelper(
         generator=cls.RRDN,
         weights_dir=cls.setup['weights_dir'],
         logs_dir=cls.setup['log_dir'],
         lr_train_dir=cls.setup['lr_input'],
         feature_extractor=cls.f_ext,
         discriminator=cls.discr,
         dataname='TEST',
         weights_generator='',
         weights_discriminator='',
         fallback_save_every_n_epochs=2,
     )
     # Fixed session id so test expectations do not depend on run time.
     cls.TH.session_id = '0000'
     # 50 is the numeric value of logging.CRITICAL (assuming ``logger`` is a
     # stdlib logging.Logger): only critical records pass, keeping test
     # output quiet.
     cls.TH.logger.setLevel(50)
# Example no. 3
0
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        loss_weights={
            'generator': 1.0,
            'discriminator': 0.003,
            'feature_extractor': 1 / 12
        },
        log_dirs={
            'logs': 'logs',
            'weights': 'weights'
        },
        fallback_save_every_n_epochs=2,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        flatness={
            'min': 0.0,
            'increase_frequency': None,
            'increase': 0.0,
            'max': 0.0
        },
        learning_rate={
            'initial_value': 0.0004,
            'decay_frequency': 100,
            'decay_factor': 0.5
        },
        adam_optimizer={
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': None
        },
        losses={
            'generator': 'mae',
            'discriminator': 'binary_crossentropy',
            'feature_extractor': 'mse',
        },
        metrics={'generator': 'PSNR_Y'},
    ):
        """Set up training from a generator, discriminator and feature extractor.

        Args:
            generator: generator network; its ``scale`` and ``patch_size``
                attributes are read and used for the data handlers.
            discriminator: discriminator network, stored and forwarded to
                TrainerHelper.
            feature_extractor: feature-extractor network, stored and
                forwarded to TrainerHelper.
            lr_train_dir, hr_train_dir: directories for the training
                DataHandler (``lr_train_dir`` is also given to TrainerHelper).
            lr_valid_dir, hr_valid_dir: directories for the validation
                DataHandler.
            loss_weights, losses: per-network loss weights / loss names,
                stored on ``self`` (presumably consumed by
                ``_combine_networks`` -- confirm).
            log_dirs: mapping whose ``'logs'`` and ``'weights'`` entries are
                passed to TrainerHelper as ``logs_dir`` / ``weights_dir``.
            fallback_save_every_n_epochs: forwarded to TrainerHelper.
            dataname: forwarded to TrainerHelper and stored.
            weights_generator, weights_discriminator: stored and forwarded to
                TrainerHelper.
            n_validation: ``n_validation_samples`` of the validation
                DataHandler.
            flatness, learning_rate, adam_optimizer: stored on ``self`` for
                use elsewhere in the class.
            metrics: mapping with a ``'generator'`` entry. The strings
                ``'PSNR_Y'`` and ``'PSNR'`` are replaced by the corresponding
                metric functions; any other value is kept unchanged.
        """
        # ``metrics`` is mutated below. Without a copy, that would modify the
        # shared default dict (one object for every instance) or the caller's
        # dict. The parameter name is rebound instead of introducing a new
        # local, so the ``locals()`` snapshot below records the same keys and
        # contents as before.
        # NOTE(review): the other dict defaults are shared across instances
        # too; they are not mutated in this method.
        metrics = dict(metrics)

        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.adam_optimizer = adam_optimizer
        self.dataname = dataname
        self.flatness = flatness
        self.n_validation = n_validation
        self.losses = losses
        self.log_dirs = log_dirs
        self.metrics = metrics
        # Resolve known metric names to their metric functions.
        if self.metrics['generator'] == 'PSNR_Y':
            self.metrics['generator'] = PSNR_Y
        elif self.metrics['generator'] == 'PSNR':
            self.metrics['generator'] = PSNR
        self._parameters_sanity_check()
        self.model = self._combine_networks()

        self.settings = {}
        # Snapshot of every argument (and ``self``) at this point, which is
        # then passed to ``update_training_config``. Do not introduce new
        # local variables above this line: they would end up in the
        # recorded settings.
        self.settings['training_parameters'] = locals()
        self.settings['training_parameters'][
            'lr_patch_size'] = self.lr_patch_size
        self.settings = self.update_training_config(self.settings)

        self.logger = get_logger(__name__)

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=log_dirs['weights'],
            logs_dir=log_dirs['logs'],
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
        )
# Example no. 4
0
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
        fallback_save_every_n_epochs=2,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=0.00001,
    ):
        """Set up training from a generator and optional discriminator / feature extractor.

        Args:
            generator: generator network; its ``scale`` and ``patch_size``
                attributes are read.
            discriminator: discriminator network or a falsy value. When
                truthy, its ``patch_size`` must equal
                ``generator.patch_size * generator.scale``.
            feature_extractor: feature-extractor network or a falsy value,
                with the same ``patch_size`` requirement.
            lr_train_dir, hr_train_dir: directories for the training
                DataHandler (``lr_train_dir`` is also given to TrainerHelper).
            lr_valid_dir, hr_valid_dir: directories for the validation
                DataHandler.
            learning_rate, loss_weights, lr_decay_frequency, lr_decay_factor,
                beta_1, beta_2, epsilon: stored on ``self`` for use elsewhere
                in the class.
            logs_dir, weights_dir, fallback_save_every_n_epochs: forwarded to
                TrainerHelper.
            dataname, weights_generator, weights_discriminator: stored and
                forwarded to TrainerHelper.
            n_validation: ``n_validation_samples`` of the validation
                DataHandler.
            T: forwarded to the training DataHandler; the validation
                DataHandler always uses ``T=0.0`` (presumably a
                patch-flatness threshold -- confirm in DataHandler).

        Raises:
            AssertionError: if a given discriminator or feature extractor has
                a ``patch_size`` different from
                ``generator.patch_size * generator.scale``.
        """
        # Explicit checks instead of ``assert`` so the validation also runs
        # under ``python -O``. AssertionError is kept so callers and tests
        # that expect it still work.
        for net_name, net in (('discriminator', discriminator),
                              ('feature_extractor', feature_extractor)):
            if net and net.patch_size != generator.patch_size * generator.scale:
                raise AssertionError(
                    '{}.patch_size ({}) must equal generator.patch_size * '
                    'generator.scale ({})'.format(
                        net_name,
                        net.patch_size,
                        generator.patch_size * generator.scale,
                    )
                )
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.dataname = dataname
        self.T = T
        self.n_validation = n_validation

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.0,
        )
        self.logger = get_logger(__name__)

        self.settings = self.get_training_config()
# Example no. 5
0
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
    ):
        """Set up training from a generator and optional discriminator / feature extractor.

        Args:
            generator: generator network; its ``scale`` and ``patch_size``
                attributes are read.
            discriminator: discriminator network or a falsy value. When
                truthy, its ``patch_size`` must equal
                ``generator.patch_size * generator.scale``.
            feature_extractor: feature-extractor network or a falsy value,
                with the same ``patch_size`` requirement.
            lr_train_dir, hr_train_dir: directories for the training
                DataHandler (``lr_train_dir`` is also given to TrainerHelper).
            lr_valid_dir, hr_valid_dir: directories for the validation
                DataHandler.
            learning_rate, loss_weights, lr_decay_frequency, lr_decay_factor:
                stored on ``self`` for use elsewhere in the class.
            logs_dir, weights_dir, dataname: forwarded to TrainerHelper.
            weights_generator, weights_discriminator: collected into
                ``self.pretrained_weights_path`` (keys ``'generator'`` /
                ``'discriminator'``), which is forwarded to TrainerHelper.
            n_validation: ``n_validation_samples`` of the validation
                DataHandler.
            T: forwarded to the training DataHandler (presumably a
                patch-flatness threshold -- confirm in DataHandler).

        Raises:
            AssertionError: if a given discriminator or feature extractor has
                a ``patch_size`` different from
                ``generator.patch_size * generator.scale``.
        """
        # Explicit checks instead of ``assert`` so the validation also runs
        # under ``python -O``. AssertionError is kept so callers and tests
        # that expect it still work.
        for net_name, net in (('discriminator', discriminator),
                              ('feature_extractor', feature_extractor)):
            if net and net.patch_size != generator.patch_size * generator.scale:
                raise AssertionError(
                    '{}.patch_size ({}) must equal generator.patch_size * '
                    'generator.scale ({})'.format(
                        net_name,
                        net.patch_size,
                        generator.patch_size * generator.scale,
                    )
                )
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.best_metrics = {}
        self.pretrained_weights_path = {
            'generator': weights_generator,
            'discriminator': weights_discriminator,
        }
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            pretrained_weights_path=self.pretrained_weights_path,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        # NOTE(review): validation uses a hard-coded T=0.01 rather than the
        # ``T`` argument; confirm this is intended.
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.01,
        )
        self.logger = get_logger(__name__)