Example #1
    def test_get_validation_batches_invalid_number_of_samples(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                       return_value=True):
                DH = DataHandler(
                    lr_dir={
                        'res': 'lr',
                        'matching': True
                    },
                    hr_dir={
                        'res': 'hr',
                        'matching': True
                    },
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=None,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                try:
                    batch = DH.get_validation_batches(batch_size=5)
                except:
                    self.assertTrue(True)
                else:
                    self.assertTrue(False)
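A note on the fixtures: the tests on this page rely on helpers defined on the test case itself (self.fake_folders, self.image_getter, self.path_giver) which the examples do not show. The sketch below is a hypothetical setUp that would make them runnable; the folder listings, image size and file names are assumptions chosen to be consistent with the expectations in the examples, not the project's actual fixture code.

import unittest
from unittest.mock import patch  # used by the test methods on this page

import numpy as np

from ISR.utils.datahandler import DataHandler


class DataHandlerTest(unittest.TestCase):
    def setUp(self):
        def fake_folders(folder_spec):
            # Stand-in for os.listdir: the tests pass {'res': 'lr'|'hr', 'matching': bool}
            # as the "directory", so the fake can return matching or mismatching listings.
            if folder_spec['matching']:
                return ['data0.png', 'data1.png']
            if folder_spec['res'] == 'hr':
                return ['data0.jpeg', 'data1.png', 'notes.txt']
            return ['data1.png', 'readme.md']

        def image_getter(path):
            # Stand-in for imageio.imread: a random RGB image large enough to crop patches from.
            return np.random.random((20, 20, 3))

        def path_giver(folder, name):
            # Stand-in for os.path.join; the fake imread ignores the file name anyway.
            return folder

        self.fake_folders = fake_folders
        self.image_getter = image_getter
        self.path_giver = path_giver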
Example #2
    def test_get_validation_batches_valid_request(self):
        patch_size = 3
        with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                   return_value=True):
            with patch('os.listdir', side_effect=self.fake_folders):

                DH = DataHandler(
                    lr_dir={
                        'res': 'lr',
                        'matching': True
                    },
                    hr_dir={
                        'res': 'hr',
                        'matching': True
                    },
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=2,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                batch = DH.get_validation_batches(batch_size=12)

        self.assertTrue(len(batch) == 2)
        self.assertTrue(type(batch) is list)
        self.assertTrue(type(batch[0]) is dict)
        self.assertTrue(batch[0]['hr'].shape == (12, patch_size * 2,
                                                 patch_size * 2, 3))
        self.assertTrue(batch[0]['lr'].shape == (12, patch_size, patch_size,
                                                 3))
        self.assertTrue(batch[1]['hr'].shape == (12, patch_size * 2,
                                                 patch_size * 2, 3))
        self.assertTrue(batch[1]['lr'].shape == (12, patch_size, patch_size,
                                                 3))
Example #3
    def test_get_batch_shape_and_diversity(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                       return_value=True):
                DH = DataHandler(
                    lr_dir={
                        'res': 'lr',
                        'matching': True
                    },
                    hr_dir={
                        'res': 'hr',
                        'matching': True
                    },
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=None,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                batch = DH.get_batch(batch_size=5)

        self.assertTrue(type(batch) is dict)
        self.assertTrue(batch['hr'].shape == (5, patch_size * 2,
                                              patch_size * 2, 3))
        self.assertTrue(batch['lr'].shape == (5, patch_size, patch_size, 3))

        self.assertTrue(
            np.any([
                batch['lr'][0] != batch['lr'][1],
                batch['lr'][1] != batch['lr'][2],
                batch['lr'][2] != batch['lr'][3],
                batch['lr'][3] != batch['lr'][4],
            ]))
Example #4
 def test__not_flat_with_non_flat_patch(self):
     lr_patch = np.random.random((5, 5, 3))
     with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir=None, hr_dir=None, patch_size=0, scale=0, n_validation_samples=None
             )
     self.assertTrue(DH._not_flat(lr_patch, flatness=0.00001))
Example #5
 def test__crop_imgs_crops_shapes(self):
     with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
             )
     imgs = {'hr': np.random.random((20, 20, 3)), 'lr': np.random.random((10, 10, 3))}
     crops = DH._crop_imgs(imgs, batch_size=2, flatness=0)
     self.assertTrue(crops['hr'].shape == (2, 6, 6, 3))
     self.assertTrue(crops['lr'].shape == (2, 3, 3, 3))
Example #6
 def test__not_flat_with_flat_patch(self):
     lr_patch = np.zeros((5, 5, 3))
     with patch('ISR.utils.datahandler.DataHandler._make_img_list',
                return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                    return_value=True):
             DH = DataHandler(
                 lr_dir=None,
                 hr_dir=None,
                 patch_size=0,
                 scale=0,
                 n_validation_samples=None,
                 T=0.01,
             )
     self.assertFalse(DH._not_flat(lr_patch))
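Read together with the non-flat example above, these two tests pin down the intended behaviour of _not_flat: a random patch passes even a tiny threshold, while an all-zero patch is rejected. A minimal criterion consistent with that behaviour, shown here as an assumption about the idea rather than the library's actual implementation, compares the pixel spread of the patch with the threshold:

import numpy as np

def not_flat(patch, flatness):
    # A patch counts as "not flat" when its pixel values vary by more than the threshold;
    # a constant (e.g. all-zero) patch has zero spread and is rejected.
    return float(patch.max() - patch.min()) > flatness

assert not_flat(np.random.random((5, 5, 3)), flatness=0.00001)
assert not not_flat(np.zeros((5, 5, 3)), flatness=0.01)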
Example #7
 def test__transform_batch(self):
     with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
             )
     I = np.ones((2, 2))
     A = I * 0
     B = I * 1
     C = I * 2
     D = I * 3
     image = np.block([[A, B], [C, D]])
     t_image_1 = np.block([[D, B], [C, A]])
     t_image_2 = np.block([[B, D], [A, C]])
     batch = np.array([image, image])
     expected = np.array([t_image_1, t_image_2])
     self.assertTrue(np.all(DH._transform_batch(batch, [[1, 1], [2, 0]]) == expected))
Example #8
 def test__check_dataset_with_matching_data(self):
     with patch('os.listdir', side_effect=self.fake_folders):
         DH = DataHandler(
             lr_dir={'res': 'lr', 'matching': True},
             hr_dir={'res': 'hr', 'matching': True},
             patch_size=0,
             scale=0,
             n_validation_samples=None,
         )
Example #9
 def test__make_img_list_non_validation(self):
     with patch('os.listdir', side_effect=self.fake_folders):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir={'res': 'lr', 'matching': False},
                 hr_dir={'res': 'hr', 'matching': False},
                 patch_size=0,
                 scale=0,
                 n_validation_samples=None,
             )
     
     expected_ls = {'hr': ['data0.jpeg', 'data1.png'], 'lr': ['data1.png']}
     self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr']))
     self.assertTrue(np.all(DH.img_list['lr'] == expected_ls['lr']))
Example #10
 def test__check_dataset_with_mismatching_data(self):
     try:
         with patch('os.listdir', side_effect=self.fake_folders):
             
             DH = DataHandler(
                 lr_dir={'res': 'lr', 'matching': False},
                 hr_dir={'res': 'hr', 'matching': False},
                 patch_size=0,
                 scale=0,
                 n_validation_samples=None,
             )
     except:
         self.assertTrue(True)
     else:
         self.assertTrue(False)
Example #11
 def test_get_validation_batches_requesting_more_than_available(self):
     patch_size = 3
     with patch('os.listdir', side_effect=self.fake_folders):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             try:
                 DH = DataHandler(
                     lr_dir={'res': 'lr', 'matching': True},
                     hr_dir={'res': 'hr', 'matching': True},
                     patch_size=patch_size,
                     scale=2,
                     n_validation_samples=10,
                 )
             except:
                 self.assertTrue(True)
             else:
                 self.assertTrue(False)
Example #12
 def test__apply_transorm(self):
     I = np.ones((2, 2))
     A = I * 0
     B = I * 1
     C = I * 2
     D = I * 3
     image = np.block([[A, B], [C, D]])
     with patch('ISR.utils.datahandler.DataHandler._make_img_list',
                return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                    return_value=True):
             DH = DataHandler(
                 lr_dir=None,
                 hr_dir=None,
                 patch_size=3,
                 scale=2,
                 n_validation_samples=None,
                 T=0.0,
             )
     transf = [[1, 0], [0, 1], [2, 0], [0, 2], [1, 1], [0, 0]]
     self.assertTrue(
         np.all(
             np.block([[C, A], [D, B]]) == DH._apply_transform(
                 image, transf[0])))
     self.assertTrue(
         np.all(
             np.block([[C, D], [A, B]]) == DH._apply_transform(
                 image, transf[1])))
     self.assertTrue(
         np.all(
             np.block([[B, D], [A, C]]) == DH._apply_transform(
                 image, transf[2])))
     self.assertTrue(
         np.all(
             np.block([[B, A], [D, C]]) == DH._apply_transform(
                 image, transf[3])))
     self.assertTrue(
         np.all(
             np.block([[D, B], [C, A]]) == DH._apply_transform(
                 image, transf[4])))
     self.assertTrue(np.all(image == DH._apply_transform(image, transf[5])))
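The transform codes used by _transform_batch and _apply_transform read as [rotation, flip]. The sketch below reproduces the expected outputs asserted above, assuming numpy rotation/flip conventions; it is an illustration, not a quote of the library's implementation.

import numpy as np

def apply_transform(img, transform_selection):
    # transform_selection = [rotation_code, flip_code]
    rotate = {
        0: lambda x: x,                  # no rotation
        1: lambda x: np.rot90(x, k=-1),  # 90 degrees clockwise         -> [[C, A], [D, B]]
        2: lambda x: np.rot90(x, k=1),   # 90 degrees counter-clockwise -> [[B, D], [A, C]]
    }
    flip = {
        0: lambda x: x,                # no flip
        1: lambda x: np.flip(x, 0),    # flip rows    -> [[C, D], [A, B]]
        2: lambda x: np.flip(x, 1),    # flip columns -> [[B, A], [D, C]]
    }
    rotation_code, flip_code = transform_selection
    return flip[flip_code](rotate[rotation_code](img))

# e.g. [1, 1]: rotate clockwise, then flip rows -> [[D, B], [C, A]], as asserted in the test.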
Example #13
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        loss_weights={
            'generator': 1.0,
            'discriminator': 0.003,
            'feature_extractor': 1 / 12
        },
        log_dirs={
            'logs': 'logs',
            'weights': 'weights'
        },
        fallback_save_every_n_epochs=2,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        flatness={
            'min': 0.0,
            'increase_frequency': None,
            'increase': 0.0,
            'max': 0.0
        },
        learning_rate={
            'initial_value': 0.0004,
            'decay_frequency': 100,
            'decay_factor': 0.5
        },
        adam_optimizer={
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': None
        },
        losses={
            'generator': 'mae',
            'discriminator': 'binary_crossentropy',
            'feature_extractor': 'mse',
        },
        metrics={'generator': 'PSNR_Y'},
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.adam_optimizer = adam_optimizer
        self.dataname = dataname
        self.flatness = flatness
        self.n_validation = n_validation
        self.losses = losses
        self.log_dirs = log_dirs
        self.metrics = metrics
        if self.metrics['generator'] == 'PSNR_Y':
            self.metrics['generator'] = PSNR_Y
        elif self.metrics['generator'] == 'PSNR':
            self.metrics['generator'] = PSNR
        self._parameters_sanity_check()
        self.model = self._combine_networks()

        self.settings = {}
        self.settings['training_parameters'] = locals()
        self.settings['training_parameters'][
            'lr_patch_size'] = self.lr_patch_size
        self.settings = self.update_training_config(self.settings)

        self.logger = get_logger(__name__)

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=log_dirs['weights'],
            logs_dir=log_dirs['logs'],
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
        )
Example #14
class Trainer:
    """Class object to setup and carry the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
        to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GAN fashion if a discriminator is provided, otherwise
    carries out regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep features
            component of perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: dictionary, with 'initial_value' and optionally 'decay_factor' and
            'decay_frequency' for the learning rate schedule.
        loss_weights: dictionary, used to weigh the components of the loss function.
            Contains 'generator' for the generator loss component, and can contain 'discriminator' and 'feature_extractor'
            for the discriminator and deep features components respectively.
        log_dirs: dictionary, with the paths 'logs' and 'weights' where the tensorboard logs and the
            weights are saved respectively.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation: integer, number of validation samples used at training from the validation set.
        flatness: dictionary. Determines the 'flatness' threshold level for the training patches.
            See the TrainerHelper class for more details.
        adam_optimizer: dictionary, with 'beta1', 'beta2' and optionally 'epsilon' for the Adam optimizer.
        losses: dictionary, with the loss functions used for the 'generator', 'discriminator' and
            'feature_extractor' outputs.
        metrics: dictionary, with the metric used on the generator output ('PSNR_Y' or 'PSNR').

    Methods:
        train: combines the networks and triggers training with the specified settings.

    """
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        loss_weights={
            'generator': 1.0,
            'discriminator': 0.003,
            'feature_extractor': 1 / 12
        },
        log_dirs={
            'logs': 'logs',
            'weights': 'weights'
        },
        fallback_save_every_n_epochs=2,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        flatness={
            'min': 0.0,
            'increase_frequency': None,
            'increase': 0.0,
            'max': 0.0
        },
        learning_rate={
            'initial_value': 0.0004,
            'decay_frequency': 100,
            'decay_factor': 0.5
        },
        adam_optimizer={
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': None
        },
        losses={
            'generator': 'mae',
            'discriminator': 'binary_crossentropy',
            'feature_extractor': 'mse',
        },
        metrics={'generator': 'PSNR_Y'},
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.adam_optimizer = adam_optimizer
        self.dataname = dataname
        self.flatness = flatness
        self.n_validation = n_validation
        self.losses = losses
        self.log_dirs = log_dirs
        self.metrics = metrics
        if self.metrics['generator'] == 'PSNR_Y':
            self.metrics['generator'] = PSNR_Y
        elif self.metrics['generator'] == 'PSNR':
            self.metrics['generator'] = PSNR
        self._parameters_sanity_check()
        self.model = self._combine_networks()

        self.settings = {}
        self.settings['training_parameters'] = locals()
        self.settings['training_parameters'][
            'lr_patch_size'] = self.lr_patch_size
        self.settings = self.update_training_config(self.settings)

        self.logger = get_logger(__name__)

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=log_dirs['weights'],
            logs_dir=log_dirs['logs'],
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
        )

    def _parameters_sanity_check(self):
        """ Parameteres sanity check. """

        if self.discriminator:
            assert self.lr_patch_size * self.scale == self.discriminator.patch_size
            self.adam_optimizer
        if self.feature_extractor:
            assert self.lr_patch_size * self.scale == self.feature_extractor.patch_size

        check_parameter_keys(
            self.learning_rate,
            needed_keys=['initial_value'],
            optional_keys=['decay_factor', 'decay_frequency'],
            default_value=None,
        )
        check_parameter_keys(
            self.flatness,
            needed_keys=[],
            optional_keys=['min', 'increase_frequency', 'increase', 'max'],
            default_value=0.0,
        )
        check_parameter_keys(
            self.adam_optimizer,
            needed_keys=['beta1', 'beta2'],
            optional_keys=['epsilon'],
            default_value=None,
        )
        check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights'])

    def _combine_networks(self):
        """
        Constructs the combined model which contains the generator network,
        as well as the discriminator and feature extractor, if any are defined.
        """

        lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, ))
        sr = self.generator.model(lr)
        outputs = [sr]
        losses = [self.losses['generator']]
        loss_weights = [self.loss_weights['generator']]

        if self.discriminator:
            self.discriminator.model.trainable = False
            validity = self.discriminator.model(sr)
            outputs.append(validity)
            losses.append(self.losses['discriminator'])
            loss_weights.append(self.loss_weights['discriminator'])
        if self.feature_extractor:
            self.feature_extractor.model.trainable = False
            sr_feats = self.feature_extractor.model(sr)
            outputs.extend([*sr_feats])
            losses.extend([self.losses['feature_extractor']] * len(sr_feats))
            loss_weights.extend(
                [self.loss_weights['feature_extractor'] / len(sr_feats)] *
                len(sr_feats))
        combined = Model(inputs=lr, outputs=outputs)
        # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows
        optimizer = Adam(
            beta_1=self.adam_optimizer['beta1'],
            beta_2=self.adam_optimizer['beta2'],
            lr=self.learning_rate['initial_value'],
            epsilon=self.adam_optimizer['epsilon'],
        )
        combined.compile(loss=losses,
                         loss_weights=loss_weights,
                         optimizer=optimizer,
                         metrics=self.metrics)
        return combined

    def _lr_scheduler(self, epoch):
        """ Scheduler for the learning rate updates. """

        n_decays = epoch // self.learning_rate['decay_frequency']
        lr = self.learning_rate['initial_value'] * (
            self.learning_rate['decay_factor']**n_decays)
        # do not let the learning rate drop below 1e-7
        return max(1e-7, lr)

    def _flatness_scheduler(self, epoch):
        if self.flatness['increase']:
            n_increases = epoch // self.flatness['increase_frequency']
        else:
            return self.flatness['min']

        f = self.flatness['min'] + n_increases * self.flatness['increase']

        return min(self.flatness['max'], f)

    def _load_weights(self):
        """
        Loads the pretrained weights from the given path, if any is provided.
        If a discriminator is defined, does the same.
        """

        if self.weights_generator:
            self.model.get_layer('generator').load_weights(
                self.weights_generator)

        if self.discriminator:
            if self.weights_discriminator:
                self.model.get_layer('discriminator').load_weights(
                    self.weights_discriminator)
                self.discriminator.model.load_weights(
                    self.weights_discriminator)

    def _format_losses(self, prefix, losses, model_metrics):
        """ Creates a dictionary for tensorboard tracking. """

        return dict(zip([prefix + m for m in model_metrics], losses))

    def update_training_config(self, settings):
        """ Summarizes training setting. """

        _ = settings['training_parameters'].pop('weights_generator')
        _ = settings['training_parameters'].pop('self')
        _ = settings['training_parameters'].pop('generator')
        _ = settings['training_parameters'].pop('discriminator')
        _ = settings['training_parameters'].pop('feature_extractor')
        settings['generator'] = {}
        settings['generator']['name'] = self.generator.name
        settings['generator']['parameters'] = self.generator.params
        settings['generator']['weights_generator'] = self.weights_generator

        _ = settings['training_parameters'].pop('weights_discriminator')
        if self.discriminator:
            settings['discriminator'] = {}
            settings['discriminator']['name'] = self.discriminator.name
            settings['discriminator'][
                'weights_discriminator'] = self.weights_discriminator
        else:
            settings['discriminator'] = None

        if self.feature_extractor:
            settings['feature_extractor'] = {}
            settings['feature_extractor']['name'] = self.feature_extractor.name
            settings['feature_extractor'][
                'layers'] = self.feature_extractor.layers_to_extract
        else:
            settings['feature_extractor'] = None

        return settings

    def train(self, epochs, steps_per_epoch, batch_size, monitored_metrics):
        """
        Carries on the training for the given number of epochs.
        Sends the losses to Tensorboard.

        Args:
            epochs: how many epochs to train for.
            steps_per_epoch: how many batches per epoch.
            batch_size: amount of images per batch.
            monitored_metrics: dictionary, the keys are the metrics that are monitored for the weights
                saving logic. The values are the modes that trigger the weights saving ('min' vs 'max').
        """

        self.settings['training_parameters'][
            'steps_per_epoch'] = steps_per_epoch
        self.settings['training_parameters']['batch_size'] = batch_size
        starting_epoch = self.helper.initialize_training(
            self)  # load_weights, creates folders, creates basename

        self.tensorboard = TensorBoard(
            log_dir=self.helper.callback_paths['logs'])
        self.tensorboard.set_model(self.model)

        # validation data
        validation_set = self.valid_dh.get_validation_set(batch_size)
        y_validation = [validation_set['hr']]
        if self.discriminator:
            discr_out_shape = list(
                self.discriminator.model.outputs[0].shape)[1:4]
            valid = np.ones([batch_size] + discr_out_shape)
            fake = np.zeros([batch_size] + discr_out_shape)
            validation_valid = np.ones([len(validation_set['hr'])] +
                                       discr_out_shape)
            y_validation.append(validation_valid)
        if self.feature_extractor:
            validation_feats = self.feature_extractor.model.predict(
                validation_set['hr'])
            y_validation.extend([*validation_feats])

        for epoch in range(starting_epoch, epochs):
            self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch,
                                                          tot_eps=epochs))
            K.set_value(self.model.optimizer.lr,
                        self._lr_scheduler(epoch=epoch))
            self.logger.info('Current learning rate: {}'.format(
                K.eval(self.model.optimizer.lr)))

            flatness = self._flatness_scheduler(epoch)
            if flatness:
                self.logger.info(
                    'Current flatness threshold: {}'.format(flatness))

            epoch_start = time()
            for step in tqdm(range(steps_per_epoch)):
                batch = self.train_dh.get_batch(batch_size, flatness=flatness)
                y_train = [batch['hr']]
                training_losses = {}

                ## Discriminator training
                if self.discriminator:
                    sr = self.generator.model.predict(batch['lr'])
                    d_loss_real = self.discriminator.model.train_on_batch(
                        batch['hr'], valid)
                    d_loss_fake = self.discriminator.model.train_on_batch(
                        sr, fake)
                    d_loss_fake = self._format_losses(
                        'train_d_fake_', d_loss_fake,
                        self.discriminator.model.metrics_names)
                    d_loss_real = self._format_losses(
                        'train_d_real_', d_loss_real,
                        self.discriminator.model.metrics_names)
                    training_losses.update(d_loss_real)
                    training_losses.update(d_loss_fake)
                    y_train.append(valid)

                ## Generator training
                if self.feature_extractor:
                    hr_feats = self.feature_extractor.model.predict(
                        batch['hr'])
                    y_train.extend([*hr_feats])

                model_losses = self.model.train_on_batch(batch['lr'], y_train)
                model_losses = self._format_losses('train_', model_losses,
                                                   self.model.metrics_names)
                training_losses.update(model_losses)

                self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step,
                                              training_losses)
                self.logger.debug('Losses at step {s}:\n {l}'.format(
                    s=step, l=training_losses))

            elapsed_time = time() - epoch_start
            self.logger.info('Epoch {} took {:10.1f}s'.format(
                epoch, elapsed_time))

            validation_losses = self.model.evaluate(validation_set['lr'],
                                                    y_validation,
                                                    batch_size=batch_size)
            validation_losses = self._format_losses('val_', validation_losses,
                                                    self.model.metrics_names)

            if epoch == starting_epoch:
                remove_metrics = []
                for metric in monitored_metrics:
                    if (metric not in training_losses) and (
                            metric not in validation_losses):
                        msg = ' '.join([
                            metric,
                            'is NOT among the model metrics, removing it.'
                        ])
                        self.logger.error(msg)
                        remove_metrics.append(metric)
                for metric in remove_metrics:
                    _ = monitored_metrics.pop(metric)

            # should average train metrics
            end_losses = {}
            end_losses.update(validation_losses)
            end_losses.update(training_losses)

            self.helper.on_epoch_end(
                epoch=epoch,
                losses=end_losses,
                generator=self.model.get_layer('generator'),
                discriminator=self.discriminator,
                metrics=monitored_metrics,
            )
            self.tensorboard.on_epoch_end(epoch, validation_losses)

        self.tensorboard.on_train_end(None)
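For context, a typical way to drive this Trainer, loosely following the usage shown in the ISR README; the architecture parameters, directory paths, dataset name and monitored metric below are placeholders rather than values taken from this page.

from ISR.models import RRDN, Discriminator, Cut_VGG19
from ISR.train import Trainer

lr_patch_size = 40
scale = 2

# Generator, discriminator and VGG feature extractor for the perceptual loss.
rrdn = RRDN(arch_params={'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': scale},
            patch_size=lr_patch_size)
discr = Discriminator(patch_size=lr_patch_size * scale, kernel_size=3)
f_ext = Cut_VGG19(patch_size=lr_patch_size * scale, layers_to_extract=[5, 9])

trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='low_res/training/images',
    hr_train_dir='high_res/training/images',
    lr_valid_dir='low_res/validation/images',
    hr_valid_dir='high_res/validation/images',
    loss_weights={'generator': 0.0, 'feature_extractor': 0.0833, 'discriminator': 0.01},
    learning_rate={'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30},
    flatness={'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5},
    dataname='div2k',
    log_dirs={'logs': './logs', 'weights': './weights'},
    n_validation=40,
)

# The monitored metric name must match one of the compiled model's metrics_names;
# unknown names are logged and dropped at the end of the first epoch.
trainer.train(
    epochs=80,
    steps_per_epoch=500,
    batch_size=16,
    monitored_metrics={'val_generator_PSNR_Y': 'max'},
)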
Example #15
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
        fallback_save_every_n_epochs=2,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=0.00001,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.dataname = dataname
        self.T = T
        self.n_validation = n_validation

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.0,
        )
        self.logger = get_logger(__name__)

        self.settings = self.get_training_config()
Example #16
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.best_metrics = {}
        self.pretrained_weights_path = {
            'generator': weights_generator,
            'discriminator': weights_discriminator,
        }
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            pretrained_weights_path=self.pretrained_weights_path,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.01,
        )
        self.logger = get_logger(__name__)
Example #17
class Trainer:
    """Class object to setup and carry the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
        to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GAN fashion if a discriminator is provided, otherwise
    carries out regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep features
            component of perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: float.
        loss_weights: dictionary, used to weigh the components of the loss function.
            Contains 'MSE' for the MSE loss component, and can contain 'discriminator' and 'feat_extr'
            for the discriminator and deep features components respectively.
        logs_dir: path to the directory where the tensorboard logs are saved.
        weights_dir: path to the directory where the weights are saved.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation: integer, number of validation samples used at training from the validation set.
        T: 0 < float < 1, determines the 'flatness' threshold level for the training patches.
            See the TrainerHelper class for more details.
        lr_decay_frequency: integer, every how many epochs the learning rate is reduced.
        lr_decay_factor: 0 < float <1, learning rate reduction multiplicative factor.

    Methods:
        train: combines the networks and triggers training with the specified settings.

    """
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.best_metrics = {}
        self.pretrained_weights_path = {
            'generator': weights_generator,
            'discriminator': weights_discriminator,
        }
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            pretrained_weights_path=self.pretrained_weights_path,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.01,
        )
        self.logger = get_logger(__name__)

    def _combine_networks(self):
        """
        Constructs the combined model which contains the generator network,
        as well as the discriminator and feature extractor, if any are defined.
        """

        lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, ))
        sr = self.generator.model(lr)
        outputs = [sr]
        losses = ['mse']
        loss_weights = [self.loss_weights['MSE']]
        if self.discriminator:
            self.discriminator.model.trainable = False
            validity = self.discriminator.model(sr)
            outputs.append(validity)
            losses.append('binary_crossentropy')
            loss_weights.append(self.loss_weights['discriminator'])
        if self.feature_extractor:
            self.feature_extractor.model.trainable = False
            sr_feats = self.feature_extractor.model(sr)
            outputs.extend([*sr_feats])
            losses.extend(['mse'] * len(sr_feats))
            loss_weights.extend(
                [self.loss_weights['feat_extr'] / len(sr_feats)] *
                len(sr_feats))
        combined = Model(inputs=lr, outputs=outputs)
        # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows
        optimizer = Adam(epsilon=0.0000001)
        combined.compile(loss=losses,
                         loss_weights=loss_weights,
                         optimizer=optimizer,
                         metrics={'generator': PSNR})
        return combined

    def _lr_scheduler(self, epoch):
        """ Scheduler for the learning rate updates. """

        n_decays = epoch // self.lr_decay_frequency
        # do not let the learning rate drop below 1e-6
        return max(1e-6, self.learning_rate * (self.lr_decay_factor**n_decays))

    def _load_weights(self):
        """
        Loads the pretrained weights from the given path, if any is provided.
        If a discriminator is defined, does the same.
        """

        gen_w = self.pretrained_weights_path['generator']
        if gen_w:
            self.model.get_layer('generator').load_weights(gen_w)
        if self.discriminator:
            dis_w = self.pretrained_weights_path['discriminator']
            if dis_w:
                self.model.get_layer('discriminator').load_weights(dis_w)
                self.discriminator.model.load_weights(dis_w)

    def train(self, epochs, steps_per_epoch, batch_size):
        """
        Carries on the training for the given number of epochs.
        Sends the losses to Tensorboard.
        """

        starting_epoch = self.helper.initialize_training(
            self)  # load_weights, creates folders, creates basename
        self.tensorboard = TensorBoard(
            log_dir=self.helper.callback_paths['logs'])
        self.tensorboard.set_model(self.model)

        # validation data
        validation_set = self.valid_dh.get_validation_set(batch_size)
        y_validation = [validation_set['hr']]
        if self.discriminator:
            discr_out_shape = list(
                self.discriminator.model.outputs[0].shape)[1:4]
            valid = np.ones([batch_size] + discr_out_shape)
            fake = np.zeros([batch_size] + discr_out_shape)
            validation_valid = np.ones([len(validation_set['hr'])] +
                                       discr_out_shape)
            y_validation.append(validation_valid)
        if self.feature_extractor:
            validation_feats = self.feature_extractor.model.predict(
                validation_set['hr'])
            y_validation.extend([*validation_feats])

        for epoch in range(starting_epoch, epochs):
            self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch,
                                                          tot_eps=epochs))
            K.set_value(self.model.optimizer.lr,
                        self._lr_scheduler(epoch=epoch))
            self.logger.info('Current learning rate: {}'.format(
                K.eval(self.model.optimizer.lr)))
            epoch_start = time()
            for step in tqdm(range(steps_per_epoch)):
                batch = self.train_dh.get_batch(batch_size)
                sr = self.generator.model.predict(batch['lr'])
                y_train = [batch['hr']]
                losses = {}

                ## Discriminator training
                if self.discriminator:
                    d_loss_real = self.discriminator.model.train_on_batch(
                        batch['hr'], valid)
                    d_loss_fake = self.discriminator.model.train_on_batch(
                        sr, fake)
                    d_loss_real = dict(
                        zip(
                            [
                                'train_d_real_' + m
                                for m in self.discriminator.model.metrics_names
                            ],
                            d_loss_real,
                        ))
                    d_loss_fake = dict(
                        zip(
                            [
                                'train_d_fake_' + m
                                for m in self.discriminator.model.metrics_names
                            ],
                            d_loss_fake,
                        ))
                    losses.update(d_loss_real)
                    losses.update(d_loss_fake)
                    y_train.append(valid)

                ## Generator training
                if self.feature_extractor:
                    hr_feats = self.feature_extractor.model.predict(
                        batch['hr'])
                    y_train.extend([*hr_feats])

                training_loss = self.model.train_on_batch(batch['lr'], y_train)
                losses.update(
                    dict(
                        zip(['train_' + m for m in self.model.metrics_names],
                            training_loss)))
                self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step,
                                              losses)
                self.logger.debug('Losses at step {s}:\n {l}'.format(s=step,
                                                                     l=losses))

            elapsed_time = time() - epoch_start
            self.logger.info('Epoch {} took {:10.1f}s'.format(
                epoch, elapsed_time))

            validation_loss = self.model.evaluate(validation_set['lr'],
                                                  y_validation,
                                                  batch_size=batch_size)
            losses = dict(
                zip(['val_' + m for m in self.model.metrics_names],
                    validation_loss))

            monitored_metrics = {}
            if (not self.discriminator) and (not self.feature_extractor):
                monitored_metrics.update({'val_loss': 'min'})
            else:
                monitored_metrics.update({'val_generator_loss': 'min'})

            self.helper.on_epoch_end(
                epoch=epoch,
                losses=losses,
                generator=self.model.get_layer('generator'),
                discriminator=self.discriminator,
                metrics=monitored_metrics,
            )
            self.tensorboard.on_epoch_end(epoch, losses)
        self.tensorboard.on_train_end(None)
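This older variant exposes a flatter interface: a scalar learning_rate, the flatness threshold T, 'MSE'/'discriminator'/'feat_extr' keys in loss_weights, and a train() without monitored_metrics. A hypothetical call against this signature, reusing the rrdn/discr/f_ext models from the sketch after Example #14 and the same placeholder paths, might look like:

trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir='low_res/training/images',
    hr_train_dir='high_res/training/images',
    lr_valid_dir='low_res/validation/images',
    hr_valid_dir='high_res/validation/images',
    learning_rate=0.0004,
    loss_weights={'MSE': 1.0, 'discriminator': 0.01, 'feat_extr': 0.0833},
    dataname='div2k',
    n_validation=40,
    T=0.01,
)
trainer.train(epochs=80, steps_per_epoch=500, batch_size=16)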