Example #1
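These snippets appear to come from the test suite of the ISR (Image Super-Resolution) package and omit their imports. A minimal header that would make them runnable is sketched below (module paths taken from the patch targets; self.fake_folders, self.image_getter and self.path_giver are fixtures of the test class that are not shown on this page):

 import unittest
 import numpy as np
 from unittest.mock import patch

 from ISR.utils.datahandler import DataHandler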
 def test_get_validation_batches_requesting_more_than_available(self):
     patch_size = 3
     with patch('os.listdir', side_effect=self.fake_folders):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             # Requesting more validation samples (10) than the fake
             # folders provide should make the constructor raise.
             try:
                 DH = DataHandler(
                     lr_dir={'res': 'lr', 'matching': True},
                     hr_dir={'res': 'hr', 'matching': True},
                     patch_size=patch_size,
                     scale=2,
                     n_validation_samples=10,
                 )
             except Exception:
                 pass
             else:
                 self.fail('DataHandler did not raise')
Example #2
 def test__check_dataset_with_matching_data(self):
     # Construction must not raise when the LR and HR folder
     # contents match.
     with patch('os.listdir', side_effect=self.fake_folders):
         DH = DataHandler(
             lr_dir={
                 'res': 'lr',
                 'matching': True
             },
             hr_dir={
                 'res': 'hr',
                 'matching': True
             },
             patch_size=0,
             scale=0,
             n_validation_samples=None,
             T=None,
         )
Example #3
 def test__transform_batch(self):
     with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
             )
     I = np.ones((2, 2))
     A = I * 0
     B = I * 1
     C = I * 2
     D = I * 3
     image = np.block([[A, B], [C, D]])
     t_image_1 = np.block([[D, B], [C, A]])
     t_image_2 = np.block([[B, D], [A, C]])
     batch = np.array([image, image])
     expected = np.array([t_image_1, t_image_2])
     self.assertTrue(np.all(DH._transform_batch(batch, [[1, 1], [2, 0]]) == expected))
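The assertion above is satisfied by applying _apply_transform (see Example #5) to each image independently; a minimal sketch, assuming that is all DataHandler._transform_batch does:

 def _transform_batch(self, batch, transforms):
     # Apply each image's own transform selection and re-stack the batch.
     return np.array(
         [self._apply_transform(img, transforms[i]) for i, img in enumerate(batch)]
     )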
Example #4
 def test__crop_imgs_crops_shapes(self):
     with patch('ISR.utils.datahandler.DataHandler._make_img_list',
                return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                    return_value=True):
             DH = DataHandler(lr_dir=None,
                              hr_dir=None,
                              patch_size=3,
                              scale=2,
                              n_validation_samples=None)
     imgs = {
         'hr': np.random.random((20, 20, 3)),
         'lr': np.random.random((10, 10, 3))
     }
     crops = DH._crop_imgs(imgs, batch_size=2, flatness=0)
     self.assertEqual(crops['hr'].shape, (2, 6, 6, 3))
     self.assertEqual(crops['lr'].shape, (2, 3, 3, 3))
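The asserted shapes encode the invariant behind paired cropping: the LR patch is patch_size per side and the HR patch is cut at the scaled coordinates with side patch_size * scale. A sketch of one such paired crop (the function name and coordinate handling are illustrative, not the library's code):

 def paired_random_crop(lr_img, hr_img, patch_size, scale):
     # Pick a top-left corner in the LR image, then crop the HR image
     # at the corresponding scaled position.
     h, w, _ = lr_img.shape
     y = np.random.randint(0, h - patch_size + 1)
     x = np.random.randint(0, w - patch_size + 1)
     lr_crop = lr_img[y:y + patch_size, x:x + patch_size]
     hr_crop = hr_img[y * scale:(y + patch_size) * scale,
                      x * scale:(x + patch_size) * scale]
     return lr_crop, hr_crop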
Example #5
 def test__apply_transform(self):
     I = np.ones((2, 2))
     A = I * 0
     B = I * 1
     C = I * 2
     D = I * 3
     image = np.block([[A, B], [C, D]])
     with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
             )
     transf = [[1, 0], [0, 1], [2, 0], [0, 2], [1, 1], [0, 0]]
     self.assertTrue(np.all(np.block([[C, A], [D, B]]) == DH._apply_transform(image, transf[0])))
     self.assertTrue(np.all(np.block([[C, D], [A, B]]) == DH._apply_transform(image, transf[1])))
     self.assertTrue(np.all(np.block([[B, D], [A, C]]) == DH._apply_transform(image, transf[2])))
     self.assertTrue(np.all(np.block([[B, A], [D, C]]) == DH._apply_transform(image, transf[3])))
     self.assertTrue(np.all(np.block([[D, B], [C, A]]) == DH._apply_transform(image, transf[4])))
     self.assertTrue(np.all(image == DH._apply_transform(image, transf[5])))
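The six assertions pin down the transform encoding: the first element selects a rotation (0 = none, 1 = 90° clockwise, 2 = 90° counterclockwise) and the second a flip axis (0 = none, 1 = flip rows, 2 = flip columns). A sketch that satisfies all six cases (an assumption consistent with the test, not necessarily the library's implementation):

 def _apply_transform(self, img, transform_selection):
     rotate = {
         0: lambda x: x,
         1: lambda x: np.rot90(x, k=1, axes=(1, 0)),  # 90 degrees clockwise
         2: lambda x: np.rot90(x, k=1, axes=(0, 1)),  # 90 degrees counterclockwise
     }
     flip = {
         0: lambda x: x,
         1: lambda x: np.flip(x, 0),  # flip rows
         2: lambda x: np.flip(x, 1),  # flip columns
     }
     return flip[transform_selection[1]](rotate[transform_selection[0]](img))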
Example #6
 def test_validation_set(self):
     patch_size = 3
     with patch('os.listdir', side_effect=self.fake_folders):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir={'res': 'lr', 'matching': True},
                 hr_dir={'res': 'hr', 'matching': True},
                 patch_size=patch_size,
                 scale=2,
                 n_validation_samples=2,
             )
     
     with patch('imageio.imread', side_effect=self.image_getter):
         with patch('os.path.join', side_effect=self.path_giver):
             batch = DH.get_validation_set(batch_size=12)
     
     self.assertIsInstance(batch, dict)
     self.assertEqual(len(batch), 2)
     self.assertEqual(batch['hr'].shape, (24, patch_size * 2, patch_size * 2, 3))
     self.assertEqual(batch['lr'].shape, (24, patch_size, patch_size, 3))
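The expected shapes imply that get_validation_set draws one batch of batch_size patches from each of the n_validation_samples validation images and flattens the result (2 * 12 = 24 patches here). A sketch of that behavior, assumed from the shapes rather than shown in these snippets:

 def get_validation_set(self, batch_size):
     # One batch of patches per validation image, flattened into a
     # single array per resolution.
     batches = self.get_validation_batches(batch_size)
     return {res: np.concatenate([b[res] for b in batches]) for res in ('lr', 'hr')}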
Example #7
 def test_get_validation_batches_invalid_number_of_samples(self):
     patch_size = 3
     with patch('os.listdir', side_effect=self.fake_folders):
         with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
             DH = DataHandler(
                 lr_dir={'res': 'lr', 'matching': True},
                 hr_dir={'res': 'hr', 'matching': True},
                 patch_size=patch_size,
                 scale=2,
                 n_validation_samples=None,
             )
     
     with patch('imageio.imread', side_effect=self.image_getter):
         with patch('os.path.join', side_effect=self.path_giver):
             # get_validation_batches must raise when the DataHandler was
             # built without n_validation_samples.
             try:
                 batch = DH.get_validation_batches(batch_size=5)
             except Exception:
                 pass
             else:
                 self.fail('get_validation_batches did not raise')
Example #8
    def test__check_dataset_with_mismatching_data(self):
        try:
            with patch('os.listdir', side_effect=self.fake_folders):

                DH = DataHandler(
                    lr_dir={
                        'res': 'lr',
                        'matching': False
                    },
                    hr_dir={
                        'res': 'hr',
                        'matching': False
                    },
                    patch_size=0,
                    scale=0,
                    n_validation_samples=None,
                )
        except Exception:
            pass
        else:
            self.fail('DataHandler did not raise on mismatching data')
Example #9
    def test__make_img_list_non_validation(self):
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                       return_value=True):
                DH = DataHandler(
                    lr_dir={
                        'res': 'lr',
                        'matching': False
                    },
                    hr_dir={
                        'res': 'hr',
                        'matching': False
                    },
                    patch_size=0,
                    scale=0,
                    n_validation_samples=None,
                )

        expected_ls = {'hr': ['data0.jpeg', 'data1.png'], 'lr': ['data1.png']}
        self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr']))
        self.assertTrue(np.all(DH.img_list['lr'] == expected_ls['lr']))
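The expected lists suggest that the fake_folders fixture returns deliberately mismatched directory contents and that _make_img_list keeps only valid image files. A hypothetical fixture consistent with the expectation (the extra non-image file is invented for illustration):

    def fake_folders(self, directory):
        # Stand-in for os.listdir: mismatched folders plus a non-image
        # file that _make_img_list should filter out.
        contents = {
            'hr': ['data0.jpeg', 'data1.png', 'notes.txt'],
            'lr': ['data1.png', 'notes.txt'],
        }
        return contents[directory]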
Example #10
    def test_get_batch_shape_and_diversity(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset',
                       return_value=True):
                DH = DataHandler(
                    lr_dir={
                        'res': 'lr',
                        'matching': True
                    },
                    hr_dir={
                        'res': 'hr',
                        'matching': True
                    },
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=None,
                    T=0.0,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                batch = DH.get_batch(batch_size=5, idx=None)

        self.assertIsInstance(batch, dict)
        self.assertEqual(batch['hr'].shape,
                         (5, patch_size * 2, patch_size * 2, 3))
        self.assertEqual(batch['lr'].shape, (5, patch_size, patch_size, 3))

        # Crops are drawn at random, so at least one pair of LR patches
        # should differ across the batch.
        self.assertTrue(
            np.any([
                batch['lr'][0] != batch['lr'][1],
                batch['lr'][1] != batch['lr'][2],
                batch['lr'][2] != batch['lr'][3],
                batch['lr'][3] != batch['lr'][4],
            ]))
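Examples #6 and #10 together show the two DataHandler entry points used during training. A hypothetical end-to-end usage, with directory names and sizes invented:

    dh = DataHandler(lr_dir='low_res/', hr_dir='high_res/',
                     patch_size=32, scale=2, n_validation_samples=None)
    batch = dh.get_batch(batch_size=16)
    lr_patches = batch['lr']  # shape (16, 32, 32, 3)
    hr_patches = batch['hr']  # shape (16, 64, 64, 3)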
Example #11
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        loss_weights={
            'generator': 1.0,
            'discriminator': 0.003,
            'feature_extractor': 1 / 12
        },
        log_dirs={
            'logs': 'logs',
            'weights': 'weights'
        },
        fallback_save_every_n_epochs=2,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        flatness={
            'min': 0.0,
            'increase_frequency': None,
            'increase': 0.0,
            'max': 0.0
        },
        learning_rate={
            'initial_value': 0.0004,
            'decay_frequency': 100,
            'decay_factor': 0.5
        },
        adam_optimizer={
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': None
        },
        losses={
            'generator': 'mae',
            'discriminator': 'binary_crossentropy',
            'feature_extractor': 'mse',
        },
        metrics={'generator': 'PSNR_Y'},
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.adam_optimizer = adam_optimizer
        self.dataname = dataname
        self.flatness = flatness
        self.n_validation = n_validation
        self.losses = losses
        self.log_dirs = log_dirs
        self.metrics = dict(metrics)  # copy: the mutable default must not be altered below
        if self.metrics['generator'] == 'PSNR_Y':
            self.metrics['generator'] = PSNR_Y
        elif self.metrics['generator'] == 'PSNR':
            self.metrics['generator'] = PSNR
        self._parameters_sanity_check()
        self.model = self._combine_networks()

        self.settings = {}
        self.settings['training_parameters'] = locals()
        self.settings['training_parameters'][
            'lr_patch_size'] = self.lr_patch_size
        self.settings = self.update_training_config(self.settings)

        self.logger = get_logger(__name__)

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=log_dirs['weights'],
            logs_dir=log_dirs['logs'],
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
        )
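A hypothetical instantiation of this trainer, wired the way the ISR README suggests (the model constructors and their parameters are assumptions, not taken from this snippet):

    from ISR.models import RRDN, Discriminator, Cut_VGG19
    from ISR.train import Trainer

    lr_size = 40
    generator = RRDN(arch_params={'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': 2},
                     patch_size=lr_size)
    discriminator = Discriminator(patch_size=lr_size * 2, kernel_size=3)
    feature_extractor = Cut_VGG19(patch_size=lr_size * 2, layers_to_extract=[5, 9])

    trainer = Trainer(
        generator=generator,
        discriminator=discriminator,
        feature_extractor=feature_extractor,
        lr_train_dir='low_res/training/images',
        hr_train_dir='high_res/training/images',
        lr_valid_dir='low_res/validation/images',
        hr_valid_dir='high_res/validation/images',
        dataname='my_dataset',
        n_validation=40,
    )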
Example #12
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
        fallback_save_every_n_epochs=2,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=0.00001,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.dataname = dataname
        self.T = T
        self.n_validation = n_validation

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.0,
        )
        self.logger = get_logger(__name__)

        self.settings = self.get_training_config()
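Examples #12 and #13 are earlier revisions of the same constructor; #12 adds explicit Adam parameters (beta_1, beta_2, epsilon) while still expressing learning-rate decay as two scalars. A sketch of the step-decay schedule those scalars imply (the schedule form is an assumption, not shown in these snippets):

    def lr_schedule(epoch, initial_lr=0.0004, decay_frequency=100, decay_factor=0.5):
        # Multiply the learning rate by decay_factor every decay_frequency epochs.
        return initial_lr * decay_factor ** (epoch // decay_frequency)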
Example #13
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.best_metrics = {}
        self.pretrained_weights_path = {
            'generator': weights_generator,
            'discriminator': weights_discriminator,
        }
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            pretrained_weights_path=self.pretrained_weights_path,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.01,
        )
        self.logger = get_logger(__name__)