def test_get_validation_batches_invalid_number_of_samples(self):
    patch_size = 3
    with patch('os.listdir', side_effect=self.fake_folders):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': True},
                hr_dir={'res': 'hr', 'matching': True},
                patch_size=patch_size,
                scale=2,
                n_validation_samples=None,
            )
    with patch('imageio.imread', side_effect=self.image_getter):
        with patch('os.path.join', side_effect=self.path_giver):
            # n_validation_samples was never set, so requesting validation batches must raise
            try:
                batch = DH.get_validation_batches(batch_size=5)
            except:
                self.assertTrue(True)
            else:
                self.assertTrue(False)
def test_get_validation_batches_valid_request(self):
    patch_size = 3
    with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
        with patch('os.listdir', side_effect=self.fake_folders):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': True},
                hr_dir={'res': 'hr', 'matching': True},
                patch_size=patch_size,
                scale=2,
                n_validation_samples=2,
            )
    with patch('imageio.imread', side_effect=self.image_getter):
        with patch('os.path.join', side_effect=self.path_giver):
            batch = DH.get_validation_batches(batch_size=12)
    self.assertTrue(len(batch) == 2)
    self.assertTrue(type(batch) is list)
    self.assertTrue(type(batch[0]) is dict)
    self.assertTrue(batch[0]['hr'].shape == (12, patch_size * 2, patch_size * 2, 3))
    self.assertTrue(batch[0]['lr'].shape == (12, patch_size, patch_size, 3))
    self.assertTrue(batch[1]['hr'].shape == (12, patch_size * 2, patch_size * 2, 3))
    self.assertTrue(batch[1]['lr'].shape == (12, patch_size, patch_size, 3))
def test_get_batch_shape_and_diversity(self):
    patch_size = 3
    with patch('os.listdir', side_effect=self.fake_folders):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': True},
                hr_dir={'res': 'hr', 'matching': True},
                patch_size=patch_size,
                scale=2,
                n_validation_samples=None,
            )
    with patch('imageio.imread', side_effect=self.image_getter):
        with patch('os.path.join', side_effect=self.path_giver):
            batch = DH.get_batch(batch_size=5)
    self.assertTrue(type(batch) is dict)
    self.assertTrue(batch['hr'].shape == (5, patch_size * 2, patch_size * 2, 3))
    self.assertTrue(batch['lr'].shape == (5, patch_size, patch_size, 3))
    self.assertTrue(
        np.any([
            batch['lr'][0] != batch['lr'][1],
            batch['lr'][1] != batch['lr'][2],
            batch['lr'][2] != batch['lr'][3],
            batch['lr'][3] != batch['lr'][4],
        ])
    )
def test__not_flat_with_non_flat_patch(self):
    lr_patch = np.random.random((5, 5, 3))
    with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir=None, hr_dir=None, patch_size=0, scale=0, n_validation_samples=None
            )
    self.assertTrue(DH._not_flat(lr_patch, flatness=0.00001))
def test__crop_imgs_crops_shapes(self):
    with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
            )
    imgs = {'hr': np.random.random((20, 20, 3)), 'lr': np.random.random((10, 10, 3))}
    crops = DH._crop_imgs(imgs, batch_size=2, flatness=0)
    self.assertTrue(crops['hr'].shape == (2, 6, 6, 3))
    self.assertTrue(crops['lr'].shape == (2, 3, 3, 3))
def test__not_flat_with_flat_patch(self):
    lr_patch = np.zeros((5, 5, 3))
    with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir=None,
                hr_dir=None,
                patch_size=0,
                scale=0,
                n_validation_samples=None,
                T=0.01,
            )
    self.assertFalse(DH._not_flat(lr_patch))
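# Note on the two _not_flat tests in this file: they only pin down observable behaviour.
# A constant (all-zero) patch is treated as flat and rejected (_not_flat returns False),
# while a random patch passes the check. The exact flatness metric is internal to DataHandler.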
def test__transform_batch(self):
    with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
            )
    I = np.ones((2, 2))
    A = I * 0
    B = I * 1
    C = I * 2
    D = I * 3
    image = np.block([[A, B], [C, D]])
    t_image_1 = np.block([[D, B], [C, A]])
    t_image_2 = np.block([[B, D], [A, C]])
    batch = np.array([image, image])
    expected = np.array([t_image_1, t_image_2])
    self.assertTrue(np.all(DH._transform_batch(batch, [[1, 1], [2, 0]]) == expected))
def test__check_dataset_with_matching_data(self):
    # Construction with matching LR/HR folders should pass the dataset check without raising.
    with patch('os.listdir', side_effect=self.fake_folders):
        DH = DataHandler(
            lr_dir={'res': 'lr', 'matching': True},
            hr_dir={'res': 'hr', 'matching': True},
            patch_size=0,
            scale=0,
            n_validation_samples=None,
        )
def test__make_img_list_non_validation(self):
    with patch('os.listdir', side_effect=self.fake_folders):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': False},
                hr_dir={'res': 'hr', 'matching': False},
                patch_size=0,
                scale=0,
                n_validation_samples=None,
            )
    expected_ls = {'hr': ['data0.jpeg', 'data1.png'], 'lr': ['data1.png']}
    self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr']))
    self.assertTrue(np.all(DH.img_list['lr'] == expected_ls['lr']))
def test__check_dataset_with_mismatching_data(self):
    try:
        with patch('os.listdir', side_effect=self.fake_folders):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': False},
                hr_dir={'res': 'hr', 'matching': False},
                patch_size=0,
                scale=0,
                n_validation_samples=None,
            )
    except:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
def test_get_validation_batches_requesting_more_than_available(self):
    patch_size = 3
    with patch('os.listdir', side_effect=self.fake_folders):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            try:
                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': True},
                    hr_dir={'res': 'hr', 'matching': True},
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=10,
                )
            except:
                self.assertTrue(True)
            else:
                self.assertTrue(False)
def test__apply_transform(self):
    I = np.ones((2, 2))
    A = I * 0
    B = I * 1
    C = I * 2
    D = I * 3
    image = np.block([[A, B], [C, D]])
    with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            DH = DataHandler(
                lr_dir=None,
                hr_dir=None,
                patch_size=3,
                scale=2,
                n_validation_samples=None,
                T=0.0,
            )
    transf = [[1, 0], [0, 1], [2, 0], [0, 2], [1, 1], [0, 0]]
    self.assertTrue(np.all(np.block([[C, A], [D, B]]) == DH._apply_transform(image, transf[0])))
    self.assertTrue(np.all(np.block([[C, D], [A, B]]) == DH._apply_transform(image, transf[1])))
    self.assertTrue(np.all(np.block([[B, D], [A, C]]) == DH._apply_transform(image, transf[2])))
    self.assertTrue(np.all(np.block([[B, A], [D, C]]) == DH._apply_transform(image, transf[3])))
    self.assertTrue(np.all(np.block([[D, B], [C, A]]) == DH._apply_transform(image, transf[4])))
    self.assertTrue(np.all(image == DH._apply_transform(image, transf[5])))
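# The assertions above pin down the meaning of the two-element transform selector:
# index 0 picks a rotation (0 = none, 1 = 90 degrees clockwise, 2 = 90 degrees
# counter-clockwise) and index 1 picks a flip (0 = none, 1 = up/down, 2 = left/right),
# applied in that order. The helper below is a minimal sketch that reproduces exactly
# that behaviour with numpy; it is an illustration consistent with these tests, not
# necessarily DataHandler's own implementation.
def apply_transform_sketch(img, selection):
    """Rotate then flip `img` according to `selection = [rotation, flip]`."""
    rotate = {
        0: lambda x: x,
        1: lambda x: np.rot90(x, k=1, axes=(1, 0)),  # 90 degrees clockwise
        2: lambda x: np.rot90(x, k=1, axes=(0, 1)),  # 90 degrees counter-clockwise
    }
    flip = {
        0: lambda x: x,
        1: lambda x: np.flip(x, 0),  # flip rows (up/down)
        2: lambda x: np.flip(x, 1),  # flip columns (left/right)
    }
    return flip[selection[1]](rotate[selection[0]](img))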
class Trainer:
    """Class object to set up and carry out the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
    to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GAN fashion if a discriminator is provided,
    otherwise carries out a regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep features
            component of the perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: dictionary with the 'initial_value' of the learning rate and,
            optionally, 'decay_frequency' (every how many epochs the rate is reduced)
            and 'decay_factor' (0 < float < 1, multiplicative reduction factor).
        loss_weights: dictionary, used to weight the components of the loss function.
            Contains 'generator' for the generator loss component, and can contain
            'discriminator' and 'feature_extractor' for the discriminator and deep
            features components respectively.
        log_dirs: dictionary with the 'logs' and 'weights' paths where the tensorboard
            logs and the weights are saved respectively.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation: integer, number of validation samples used at training from the validation set.
        flatness: dictionary. Determines the 'flatness' threshold level for the training patches.
            See the TrainerHelper class for more details.

    Methods:
        train: combines the networks and triggers training with the specified settings.
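    Example:
        An illustrative sketch, not a prescription: the RRDN, Discriminator and
        Cut_VGG19 wrappers are assumed to come from ISR.models, the directory paths
        and hyper-parameters are placeholders, and the monitored metric name must
        match one of the compiled model's metrics.

            generator = RRDN(
                arch_params={'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': 2}, patch_size=40
            )
            discriminator = Discriminator(patch_size=80)
            feature_extractor = Cut_VGG19(patch_size=80, layers_to_extract=[5, 9])
            trainer = Trainer(
                generator=generator,
                discriminator=discriminator,
                feature_extractor=feature_extractor,
                lr_train_dir='data/lr/train',
                hr_train_dir='data/hr/train',
                lr_valid_dir='data/lr/valid',
                hr_valid_dir='data/hr/valid',
                dataname='div2k',
            )
            trainer.train(
                epochs=80,
                steps_per_epoch=500,
                batch_size=16,
                monitored_metrics={'val_generator_PSNR_Y': 'max'},
            )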
""" def __init__( self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights={ 'generator': 1.0, 'discriminator': 0.003, 'feature_extractor': 1 / 12 }, log_dirs={ 'logs': 'logs', 'weights': 'weights' }, fallback_save_every_n_epochs=2, dataname=None, weights_generator=None, weights_discriminator=None, n_validation=None, flatness={ 'min': 0.0, 'increase_frequency': None, 'increase': 0.0, 'max': 0.0 }, learning_rate={ 'initial_value': 0.0004, 'decay_frequency': 100, 'decay_factor': 0.5 }, adam_optimizer={ 'beta1': 0.9, 'beta2': 0.999, 'epsilon': None }, losses={ 'generator': 'mae', 'discriminator': 'binary_crossentropy', 'feature_extractor': 'mse', }, metrics={'generator': 'PSNR_Y'}, ): self.generator = generator self.discriminator = discriminator self.feature_extractor = feature_extractor self.scale = generator.scale self.lr_patch_size = generator.patch_size self.learning_rate = learning_rate self.loss_weights = loss_weights self.weights_generator = weights_generator self.weights_discriminator = weights_discriminator self.adam_optimizer = adam_optimizer self.dataname = dataname self.flatness = flatness self.n_validation = n_validation self.losses = losses self.log_dirs = log_dirs self.metrics = metrics if self.metrics['generator'] == 'PSNR_Y': self.metrics['generator'] = PSNR_Y elif self.metrics['generator'] == 'PSNR': self.metrics['generator'] = PSNR self._parameters_sanity_check() self.model = self._combine_networks() self.settings = {} self.settings['training_parameters'] = locals() self.settings['training_parameters'][ 'lr_patch_size'] = self.lr_patch_size self.settings = self.update_training_config(self.settings) self.logger = get_logger(__name__) self.helper = TrainerHelper( generator=self.generator, weights_dir=log_dirs['weights'], logs_dir=log_dirs['logs'], lr_train_dir=lr_train_dir, feature_extractor=self.feature_extractor, discriminator=self.discriminator, dataname=dataname, weights_generator=self.weights_generator, weights_discriminator=self.weights_discriminator, fallback_save_every_n_epochs=fallback_save_every_n_epochs, ) self.train_dh = DataHandler( lr_dir=lr_train_dir, hr_dir=hr_train_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=None, ) self.valid_dh = DataHandler( lr_dir=lr_valid_dir, hr_dir=hr_valid_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=n_validation, ) def _parameters_sanity_check(self): """ Parameteres sanity check. """ if self.discriminator: assert self.lr_patch_size * self.scale == self.discriminator.patch_size self.adam_optimizer if self.feature_extractor: assert self.lr_patch_size * self.scale == self.feature_extractor.patch_size check_parameter_keys( self.learning_rate, needed_keys=['initial_value'], optional_keys=['decay_factor', 'decay_frequency'], default_value=None, ) check_parameter_keys( self.flatness, needed_keys=[], optional_keys=['min', 'increase_frequency', 'increase', 'max'], default_value=0.0, ) check_parameter_keys( self.adam_optimizer, needed_keys=['beta1', 'beta2'], optional_keys=['epsilon'], default_value=None, ) check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights']) def _combine_networks(self): """ Constructs the combined model which contains the generator network, as well as discriminator and geature extractor, if any are defined. 
""" lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, )) sr = self.generator.model(lr) outputs = [sr] losses = [self.losses['generator']] loss_weights = [self.loss_weights['generator']] if self.discriminator: self.discriminator.model.trainable = False validity = self.discriminator.model(sr) outputs.append(validity) losses.append(self.losses['discriminator']) loss_weights.append(self.loss_weights['discriminator']) if self.feature_extractor: self.feature_extractor.model.trainable = False sr_feats = self.feature_extractor.model(sr) outputs.extend([*sr_feats]) losses.extend([self.losses['feature_extractor']] * len(sr_feats)) loss_weights.extend( [self.loss_weights['feature_extractor'] / len(sr_feats)] * len(sr_feats)) combined = Model(inputs=lr, outputs=outputs) # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows optimizer = Adam( beta_1=self.adam_optimizer['beta1'], beta_2=self.adam_optimizer['beta2'], lr=self.learning_rate['initial_value'], epsilon=self.adam_optimizer['epsilon'], ) combined.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer, metrics=self.metrics) return combined def _lr_scheduler(self, epoch): """ Scheduler for the learning rate updates. """ n_decays = epoch // self.learning_rate['decay_frequency'] lr = self.learning_rate['initial_value'] * ( self.learning_rate['decay_factor']**n_decays) # no lr below minimum control 10e-7 return max(1e-7, lr) def _flatness_scheduler(self, epoch): if self.flatness['increase']: n_increases = epoch // self.flatness['increase_frequency'] else: return self.flatness['min'] f = self.flatness['min'] + n_increases * self.flatness['increase'] return min(self.flatness['max'], f) def _load_weights(self): """ Loads the pretrained weights from the given path, if any is provided. If a discriminator is defined, does the same. """ if self.weights_generator: self.model.get_layer('generator').load_weights( self.weights_generator) if self.discriminator: if self.weights_discriminator: self.model.get_layer('discriminator').load_weights( self.weights_discriminator) self.discriminator.model.load_weights( self.weights_discriminator) def _format_losses(self, prefix, losses, model_metrics): """ Creates a dictionary for tensorboard tracking. """ return dict(zip([prefix + m for m in model_metrics], losses)) def update_training_config(self, settings): """ Summarizes training setting. 
""" _ = settings['training_parameters'].pop('weights_generator') _ = settings['training_parameters'].pop('self') _ = settings['training_parameters'].pop('generator') _ = settings['training_parameters'].pop('discriminator') _ = settings['training_parameters'].pop('feature_extractor') settings['generator'] = {} settings['generator']['name'] = self.generator.name settings['generator']['parameters'] = self.generator.params settings['generator']['weights_generator'] = self.weights_generator _ = settings['training_parameters'].pop('weights_discriminator') if self.discriminator: settings['discriminator'] = {} settings['discriminator']['name'] = self.discriminator.name settings['discriminator'][ 'weights_discriminator'] = self.weights_discriminator else: settings['discriminator'] = None if self.discriminator: settings['feature_extractor'] = {} settings['feature_extractor']['name'] = self.feature_extractor.name settings['feature_extractor'][ 'layers'] = self.feature_extractor.layers_to_extract else: settings['feature_extractor'] = None return settings def train(self, epochs, steps_per_epoch, batch_size, monitored_metrics): """ Carries on the training for the given number of epochs. Sends the losses to Tensorboard. Args: epochs: how many epochs to train for. steps_per_epoch: how many batches epoch. batch_size: amount of images per batch. monitored_metrics: dictionary, the keys are the metrics that are monitored for the weights saving logic. The values are the mode that trigger the weights saving ('min' vs 'max'). """ self.settings['training_parameters'][ 'steps_per_epoch'] = steps_per_epoch self.settings['training_parameters']['batch_size'] = batch_size starting_epoch = self.helper.initialize_training( self) # load_weights, creates folders, creates basename self.tensorboard = TensorBoard( log_dir=self.helper.callback_paths['logs']) self.tensorboard.set_model(self.model) # validation data validation_set = self.valid_dh.get_validation_set(batch_size) y_validation = [validation_set['hr']] if self.discriminator: discr_out_shape = list( self.discriminator.model.outputs[0].shape)[1:4] valid = np.ones([batch_size] + discr_out_shape) fake = np.zeros([batch_size] + discr_out_shape) validation_valid = np.ones([len(validation_set['hr'])] + discr_out_shape) y_validation.append(validation_valid) if self.feature_extractor: validation_feats = self.feature_extractor.model.predict( validation_set['hr']) y_validation.extend([*validation_feats]) for epoch in range(starting_epoch, epochs): self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch, tot_eps=epochs)) K.set_value(self.model.optimizer.lr, self._lr_scheduler(epoch=epoch)) self.logger.info('Current learning rate: {}'.format( K.eval(self.model.optimizer.lr))) flatness = self._flatness_scheduler(epoch) if flatness: self.logger.info( 'Current flatness treshold: {}'.format(flatness)) epoch_start = time() for step in tqdm(range(steps_per_epoch)): batch = self.train_dh.get_batch(batch_size, flatness=flatness) y_train = [batch['hr']] training_losses = {} ## Discriminator training if self.discriminator: sr = self.generator.model.predict(batch['lr']) d_loss_real = self.discriminator.model.train_on_batch( batch['hr'], valid) d_loss_fake = self.discriminator.model.train_on_batch( sr, fake) d_loss_fake = self._format_losses( 'train_d_fake_', d_loss_fake, self.discriminator.model.metrics_names) d_loss_real = self._format_losses( 'train_d_real_', d_loss_real, self.discriminator.model.metrics_names) training_losses.update(d_loss_real) training_losses.update(d_loss_fake) 
y_train.append(valid) ## Generator training if self.feature_extractor: hr_feats = self.feature_extractor.model.predict( batch['hr']) y_train.extend([*hr_feats]) model_losses = self.model.train_on_batch(batch['lr'], y_train) model_losses = self._format_losses('train_', model_losses, self.model.metrics_names) training_losses.update(model_losses) self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step, training_losses) self.logger.debug('Losses at step {s}:\n {l}'.format( s=step, l=training_losses)) elapsed_time = time() - epoch_start self.logger.info('Epoch {} took {:10.1f}s'.format( epoch, elapsed_time)) validation_losses = self.model.evaluate(validation_set['lr'], y_validation, batch_size=batch_size) validation_losses = self._format_losses('val_', validation_losses, self.model.metrics_names) if epoch == starting_epoch: remove_metrics = [] for metric in monitored_metrics: if (metric not in training_losses) and ( metric not in validation_losses): msg = ' '.join([ metric, 'is NOT among the model metrics, removing it.' ]) self.logger.error(msg) remove_metrics.append(metric) for metric in remove_metrics: _ = monitored_metrics.pop(metric) # should average train metrics end_losses = {} end_losses.update(validation_losses) end_losses.update(training_losses) self.helper.on_epoch_end( epoch=epoch, losses=end_losses, generator=self.model.get_layer('generator'), discriminator=self.discriminator, metrics=monitored_metrics, ) self.tensorboard.on_epoch_end(epoch, validation_losses) self.tensorboard.on_train_end(None)
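# A small, self-contained illustration (not part of the Trainer) of how the default
# learning-rate schedule above behaves: the rate is multiplied by 'decay_factor' once
# every 'decay_frequency' epochs and floored at 1e-7. The values below are simply the
# constructor defaults, not a recommendation.
def _example_lr_schedule(epoch, initial_value=0.0004, decay_frequency=100, decay_factor=0.5):
    n_decays = epoch // decay_frequency
    return max(1e-7, initial_value * decay_factor ** n_decays)


# _example_lr_schedule(0)   -> 0.0004
# _example_lr_schedule(100) -> 0.0002
# _example_lr_schedule(250) -> 0.0001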
def __init__(
    self,
    generator,
    discriminator,
    feature_extractor,
    lr_train_dir,
    hr_train_dir,
    lr_valid_dir,
    hr_valid_dir,
    learning_rate=0.0004,
    loss_weights={'MSE': 1.0},
    logs_dir='logs',
    weights_dir='weights',
    dataname=None,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=None,
    T=0.01,
    lr_decay_frequency=100,
    lr_decay_factor=0.5,
    fallback_save_every_n_epochs=2,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=0.00001,
):
    if discriminator:
        assert generator.patch_size * generator.scale == discriminator.patch_size
    if feature_extractor:
        assert generator.patch_size * generator.scale == feature_extractor.patch_size
    self.generator = generator
    self.discriminator = discriminator
    self.feature_extractor = feature_extractor
    self.scale = generator.scale
    self.lr_patch_size = generator.patch_size
    self.learning_rate = learning_rate
    self.loss_weights = loss_weights
    self.weights_generator = weights_generator
    self.weights_discriminator = weights_discriminator
    self.lr_decay_factor = lr_decay_factor
    self.lr_decay_frequency = lr_decay_frequency
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.dataname = dataname
    self.T = T
    self.n_validation = n_validation
    self.helper = TrainerHelper(
        generator=self.generator,
        weights_dir=weights_dir,
        logs_dir=logs_dir,
        lr_train_dir=lr_train_dir,
        feature_extractor=self.feature_extractor,
        discriminator=self.discriminator,
        dataname=dataname,
        weights_generator=self.weights_generator,
        weights_discriminator=self.weights_discriminator,
        fallback_save_every_n_epochs=fallback_save_every_n_epochs,
    )
    self.model = self._combine_networks()
    self.train_dh = DataHandler(
        lr_dir=lr_train_dir,
        hr_dir=hr_train_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=None,
        T=T,
    )
    self.valid_dh = DataHandler(
        lr_dir=lr_valid_dir,
        hr_dir=hr_valid_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=n_validation,
        T=0.0,
    )
    self.logger = get_logger(__name__)
    self.settings = self.get_training_config()
class Trainer:
    """Class object to set up and carry out the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
    to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GAN fashion if a discriminator is provided,
    otherwise carries out a regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep features
            component of the perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: float.
        loss_weights: dictionary, used to weight the components of the loss function.
            Contains 'MSE' for the MSE loss component, and can contain 'discriminator'
            and 'feat_extr' for the discriminator and deep features components respectively.
        logs_dir: path to the directory where the tensorboard logs are saved.
        weights_dir: path to the directory where the weights are saved.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation: integer, number of validation samples used at training from the validation set.
        T: 0 < float < 1, determines the 'flatness' threshold level for the training patches.
            See the TrainerHelper class for more details.
        lr_decay_frequency: integer, every how many epochs the learning rate is reduced.
        lr_decay_factor: 0 < float < 1, learning rate reduction multiplicative factor.

    Methods:
        train: combines the networks and triggers training with the specified settings.
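    Example:
        An illustrative sketch under the assumption that `rrdn` is a generator wrapper
        (for instance RRDN from ISR.models); directory paths and values are placeholders.

            trainer = Trainer(
                generator=rrdn,
                discriminator=None,
                feature_extractor=None,
                lr_train_dir='data/lr/train',
                hr_train_dir='data/hr/train',
                lr_valid_dir='data/lr/valid',
                hr_valid_dir='data/hr/valid',
                dataname='div2k',
                T=0.01,
            )
            trainer.train(epochs=100, steps_per_epoch=500, batch_size=16)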
""" def __init__( self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, learning_rate=0.0004, loss_weights={'MSE': 1.0}, logs_dir='logs', weights_dir='weights', dataname=None, weights_generator=None, weights_discriminator=None, n_validation=None, T=0.01, lr_decay_frequency=100, lr_decay_factor=0.5, ): if discriminator: assert generator.patch_size * generator.scale == discriminator.patch_size if feature_extractor: assert generator.patch_size * generator.scale == feature_extractor.patch_size self.generator = generator self.discriminator = discriminator self.feature_extractor = feature_extractor self.scale = generator.scale self.lr_patch_size = generator.patch_size self.learning_rate = learning_rate self.loss_weights = loss_weights self.best_metrics = {} self.pretrained_weights_path = { 'generator': weights_generator, 'discriminator': weights_discriminator, } self.lr_decay_factor = lr_decay_factor self.lr_decay_frequency = lr_decay_frequency self.helper = TrainerHelper( generator=self.generator, weights_dir=weights_dir, logs_dir=logs_dir, lr_train_dir=lr_train_dir, feature_extractor=self.feature_extractor, discriminator=self.discriminator, dataname=dataname, pretrained_weights_path=self.pretrained_weights_path, ) self.model = self._combine_networks() self.train_dh = DataHandler( lr_dir=lr_train_dir, hr_dir=hr_train_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=None, T=T, ) self.valid_dh = DataHandler( lr_dir=lr_valid_dir, hr_dir=hr_valid_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=n_validation, T=0.01, ) self.logger = get_logger(__name__) def _combine_networks(self): """ Constructs the combined model which contains the generator network, as well as discriminator and geature extractor, if any are defined. """ lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, )) sr = self.generator.model(lr) outputs = [sr] losses = ['mse'] loss_weights = [self.loss_weights['MSE']] if self.discriminator: self.discriminator.model.trainable = False validity = self.discriminator.model(sr) outputs.append(validity) losses.append('binary_crossentropy') loss_weights.append(self.loss_weights['discriminator']) if self.feature_extractor: self.feature_extractor.model.trainable = False sr_feats = self.feature_extractor.model(sr) outputs.extend([*sr_feats]) losses.extend(['mse'] * len(sr_feats)) loss_weights.extend( [self.loss_weights['feat_extr'] / len(sr_feats)] * len(sr_feats)) combined = Model(inputs=lr, outputs=outputs) # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows optimizer = Adam(epsilon=0.0000001) combined.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer, metrics={'generator': PSNR}) return combined def _lr_scheduler(self, epoch): """ Scheduler for the learning rate updates. """ n_decays = epoch // self.lr_decay_frequency # no lr below minimum control 10e-6 return max(1e-6, self.learning_rate * (self.lr_decay_factor**n_decays)) def _load_weights(self): """ Loads the pretrained weights from the given path, if any is provided. If a discriminator is defined, does the same. 
""" gen_w = self.pretrained_weights_path['generator'] if gen_w: self.model.get_layer('generator').load_weights(gen_w) if self.discriminator: dis_w = self.pretrained_weights_path['discriminator'] if dis_w: self.model.get_layer('discriminator').load_weights(dis_w) self.discriminator.model.load_weights(dis_w) def train(self, epochs, steps_per_epoch, batch_size): """ Carries on the training for the given number of epochs. Sends the losses to Tensorboard. """ starting_epoch = self.helper.initialize_training( self) # load_weights, creates folders, creates basename self.tensorboard = TensorBoard( log_dir=self.helper.callback_paths['logs']) self.tensorboard.set_model(self.model) # validation data validation_set = self.valid_dh.get_validation_set(batch_size) y_validation = [validation_set['hr']] if self.discriminator: discr_out_shape = list( self.discriminator.model.outputs[0].shape)[1:4] valid = np.ones([batch_size] + discr_out_shape) fake = np.zeros([batch_size] + discr_out_shape) validation_valid = np.ones([len(validation_set['hr'])] + discr_out_shape) y_validation.append(validation_valid) if self.feature_extractor: validation_feats = self.feature_extractor.model.predict( validation_set['hr']) y_validation.extend([*validation_feats]) for epoch in range(starting_epoch, epochs): self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch, tot_eps=epochs)) K.set_value(self.model.optimizer.lr, self._lr_scheduler(epoch=epoch)) self.logger.info('Current learning rate: {}'.format( K.eval(self.model.optimizer.lr))) epoch_start = time() for step in tqdm(range(steps_per_epoch)): batch = self.train_dh.get_batch(batch_size) sr = self.generator.model.predict(batch['lr']) y_train = [batch['hr']] losses = {} ## Discriminator training if self.discriminator: d_loss_real = self.discriminator.model.train_on_batch( batch['hr'], valid) d_loss_fake = self.discriminator.model.train_on_batch( sr, fake) d_loss_real = dict( zip( [ 'train_d_real_' + m for m in self.discriminator.model.metrics_names ], d_loss_real, )) d_loss_fake = dict( zip( [ 'train_d_fake_' + m for m in self.discriminator.model.metrics_names ], d_loss_fake, )) losses.update(d_loss_real) losses.update(d_loss_fake) y_train.append(valid) ## Generator training if self.feature_extractor: hr_feats = self.feature_extractor.model.predict( batch['hr']) y_train.extend([*hr_feats]) trainig_loss = self.model.train_on_batch(batch['lr'], y_train) losses.update( dict( zip(['train_' + m for m in self.model.metrics_names], trainig_loss))) self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step, losses) self.logger.debug('Losses at step {s}:\n {l}'.format(s=step, l=losses)) elapsed_time = time() - epoch_start self.logger.info('Epoch {} took {:10.1f}s'.format( epoch, elapsed_time)) validation_loss = self.model.evaluate(validation_set['lr'], y_validation, batch_size=batch_size) losses = dict( zip(['val_' + m for m in self.model.metrics_names], validation_loss)) monitored_metrics = {} if (not self.discriminator) and (not self.feature_extractor): monitored_metrics.update({'val_loss': 'min'}) else: monitored_metrics.update({'val_generator_loss': 'min'}) self.helper.on_epoch_end( epoch=epoch, losses=losses, generator=self.model.get_layer('generator'), discriminator=self.discriminator, metrics=monitored_metrics, ) self.tensorboard.on_epoch_end(epoch, losses) self.tensorboard.on_train_end(None)