Example #1
    def __init__(self, lr_dir, hr_dir, patch_size, scale, n_validation_samples=None):
        self.folders = {'hr': hr_dir, 'lr': lr_dir}  # image folders
        self.extensions = ('.png', '.jpeg', '.jpg')  # admissible extensions
        self.img_list = {}  # list of file names
        self.n_validation_samples = n_validation_samples
        self.scale = scale
        # patch sizes for the LR inputs and the corresponding HR targets
        self.patch_size = {'lr': patch_size, 'hr': patch_size * self.scale}
        self.logger = get_logger(__name__)
        self._make_img_list()
        self._check_dataset()
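A minimal usage sketch (the directory paths are assumptions; _make_img_list and _check_dataset run at construction time, so both folders must already contain matching LR/HR images):

# Hypothetical paths; the handler scans both folders as soon as it is built.
dh = DataHandler(lr_dir='data/lr_train', hr_dir='data/hr_train', patch_size=40, scale=2)
print(dh.patch_size)  # {'lr': 40, 'hr': 80}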
Example #2
    def __init__(self, patch_size, layers_to_extract):
        self.patch_size = patch_size
        self.input_shape = (patch_size,) * 2 + (3,)
        self.layers_to_extract = layers_to_extract
        self.logger = get_logger(__name__)

        if len(self.layers_to_extract) > 0:
            self._cut_vgg()
        else:
            self.logger.error('Invalid VGG instantiation: the number of layers to extract must be > 0')
            raise ValueError('Invalid VGG instantiation: the number of layers to extract must be > 0')
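A sketch of typical use (the layer indices are illustrative; as Example #10 shows, the extractor receives the HR patch size):

# Illustrative VGG19 layer indices; patch_size must match the HR patches.
f_ext = Cut_VGG19(patch_size=80, layers_to_extract=[5, 9])

# An empty layer list is rejected at construction:
try:
    Cut_VGG19(patch_size=80, layers_to_extract=[])
except ValueError as err:
    print(err)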
Example #3
    def __init__(self, input_dir, output_dir='./data/output', verbose=True):

        self.input_dir = Path(input_dir)
        self.data_name = self.input_dir.name
        self.output_dir = Path(output_dir) / self.data_name
        self.logger = get_logger(__name__)
        if not verbose:
            self.logger.setLevel(40)  # 40 == logging.ERROR: silence info and warning output
        self.extensions = ('.jpeg', '.jpg', '.png')  # file extensions that are admitted
        self.img_ls = [f for f in self.input_dir.iterdir() if f.suffix in self.extensions]
        if len(self.img_ls) < 1:
            self.logger.error('No valid image files found (check config file).')
            raise ValueError('No valid image files found (check config file).')
        # Create results folder
        if not self.output_dir.exists():
            self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
            self.output_dir.mkdir(parents=True)
Example #4
    def __init__(self, input_dir, output_dir='./data/output', verbose=True):

        self.input_dir = input_dir
        self.data_name = os.path.basename(os.path.normpath(self.input_dir))
        self.output_dir = os.path.join(output_dir, self.data_name)
        file_ls = os.listdir(self.input_dir)
        self.logger = get_logger(__name__)
        if not verbose:
            self.logger.setLevel(40)
        self.extensions = ('.jpeg', '.jpg', '.png')  # file extensions that are admitted
        self.img_ls = [file for file in file_ls if file.endswith(self.extensions)]
        if len(self.img_ls) < 1:
            self.logger.error('No valid image files found (check config file).')
            raise ValueError('No valid image files found (check config file).')
        # Create results folder
        if not os.path.exists(self.output_dir):
            self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
            os.makedirs(self.output_dir, exist_ok=True)
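Examples #3 and #4 are the same Predictor constructor written with pathlib and with os.path respectively; both resolve the output directory to output_dir/<input folder name> and create it on demand. A minimal usage sketch (the input path is an assumption):

# Hypothetical input folder; results go to ./data/output/my_images.
predictor = Predictor(input_dir='data/my_images', verbose=False)
print(predictor.output_dir)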
Example #5
    def __init__(
        self,
        generator,
        weights_dir,
        logs_dir,
        lr_train_dir,
        feature_extractor=None,
        discriminator=None,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        fallback_save_every_n_epochs=2,
        max_n_other_weights=5,
        max_n_best_weights=5,
    ):
        self.generator = generator
        self.dirs = {'logs': Path(logs_dir), 'weights': Path(weights_dir)}
        self.feature_extractor = feature_extractor
        self.discriminator = discriminator
        self.dataname = dataname

        if weights_generator:
            self.pretrained_generator_weights = Path(weights_generator)
        else:
            self.pretrained_generator_weights = None

        if weights_discriminator:
            self.pretrained_discriminator_weights = Path(weights_discriminator)
        else:
            self.pretrained_discriminator_weights = None

        self.fallback_save_every_n_epochs = fallback_save_every_n_epochs
        self.lr_dir = Path(lr_train_dir)
        self.basename = self._make_basename()
        self.session_id = self.get_session_id(basename=None)
        self.session_config_name = 'session_config.yml'
        self.callback_paths = self._make_callback_paths()
        self.weights_name = self._weights_name(self.callback_paths)
        self.best_metrics = {}
        self.since_last_epoch = 0
        self.max_n_other_weights = max_n_other_weights
        self.max_n_best_weights = max_n_best_weights
        self.logger = get_logger(__name__)
Example #6
    def __init__(
        self,
        generator,
        weights_dir,
        logs_dir,
        lr_train_dir,
        feature_extractor=None,
        discriminator=None,
        dataname=None,
        pretrained_weights_path={},
        fallback_save_every_n_epochs=2,
    ):
        self.generator = generator
        self.dirs = {'logs': logs_dir, 'weights': weights_dir}
        self.feature_extractor = feature_extractor
        self.discriminator = discriminator
        self.dataname = dataname
        self.pretrained_weights_path = pretrained_weights_path
        self.lr_dir = lr_train_dir
        self.best_metrics = {}
        self.fallback_save_every_n_epochs = fallback_save_every_n_epochs
        self.since_last_epoch = 0
        self.logger = get_logger(__name__)
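Examples #5 and #6 are two revisions of the same TrainerHelper constructor: the newer one takes separate weights_generator/weights_discriminator paths and wraps them in pathlib.Path objects, while the older one bundles them into a single pretrained_weights_path dict. Note that pretrained_weights_path={} is a mutable default argument; it is harmless here because the dict is only stored, never mutated, but the idiomatic pattern is a None sentinel, as in this standalone sketch (the function name is illustrative):

# Mutable defaults are shared across calls; a None sentinel gives each call its own dict.
def make_helper_config(pretrained_weights_path=None):
    if pretrained_weights_path is None:
        pretrained_weights_path = {}
    return pretrained_weights_path

assert make_helper_config() is not make_helper_config()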
Example #7
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        loss_weights={
            'generator': 1.0,
            'discriminator': 0.003,
            'feature_extractor': 1 / 12
        },
        log_dirs={
            'logs': 'logs',
            'weights': 'weights'
        },
        fallback_save_every_n_epochs=2,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        flatness={
            'min': 0.0,
            'increase_frequency': None,
            'increase': 0.0,
            'max': 0.0
        },
        learning_rate={
            'initial_value': 0.0004,
            'decay_frequency': 100,
            'decay_factor': 0.5
        },
        adam_optimizer={
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': None
        },
        losses={
            'generator': 'mae',
            'discriminator': 'binary_crossentropy',
            'feature_extractor': 'mse',
        },
        metrics={'generator': 'PSNR_Y'},
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.adam_optimizer = adam_optimizer
        self.dataname = dataname
        self.flatness = flatness
        self.n_validation = n_validation
        self.losses = losses
        self.log_dirs = log_dirs
        self.metrics = metrics
        if self.metrics['generator'] == 'PSNR_Y':
            self.metrics['generator'] = PSNR_Y
        elif self.metrics['generator'] == 'PSNR':
            self.metrics['generator'] = PSNR
        self._parameters_sanity_check()
        self.model = self._combine_networks()

        self.settings = {}
        self.settings['training_parameters'] = locals()
        self.settings['training_parameters']['lr_patch_size'] = self.lr_patch_size
        self.settings = self.update_training_config(self.settings)

        self.logger = get_logger(__name__)

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=log_dirs['weights'],
            logs_dir=log_dirs['logs'],
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
        )
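The learning_rate dict above parameterizes a step-decay schedule. How ISR applies it is not shown in this excerpt, but the key names suggest the standard rule sketched below (an assumption: the rate is multiplied by decay_factor once every decay_frequency epochs):

# Step decay implied by the parameter names; not ISR's verbatim scheduler.
def lr_at_epoch(epoch, initial_value=0.0004, decay_frequency=100, decay_factor=0.5):
    return initial_value * decay_factor ** (epoch // decay_frequency)

print(lr_at_epoch(0), lr_at_epoch(100), lr_at_epoch(250))  # 0.0004 0.0002 0.0001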
Example #8
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,  # patch-flatness threshold; superseded by the 'flatness' dict in Example #7
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
        fallback_save_every_n_epochs=2,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=0.00001,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.dataname = dataname
        self.T = T
        self.n_validation = n_validation

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.0,
        )
        self.logger = get_logger(__name__)

        self.settings = self.get_training_config()
Example #9
import os
import argparse
import numpy as np
import yaml
from ISR.utils.logger import get_logger


logger = get_logger(__name__)


def _get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--prediction', action='store_true', dest='prediction')
    parser.add_argument('--training', action='store_true', dest='training')
    parser.add_argument('--summary', action='store_true', dest='summary')
    parser.add_argument('--default', action='store_true', dest='default')
    parser.add_argument('--config', action='store', dest='config_file')
    return parser


def parse_args():
    """ Parse CLI arguments. """
    parser = _get_parser()
    args = vars(parser.parse_args())
    if args['prediction'] and args['training']:
        logger.error('Select only prediction OR training.')
        raise ValueError('Select only prediction OR training.')
    return args
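The --prediction and --training flags are made mutually exclusive by this hand-rolled check rather than by an argparse group; the parser can be exercised directly (the argument list is illustrative):

# Feed the parser an illustrative argument list instead of sys.argv.
args = vars(_get_parser().parse_args(['--training', '--config', 'config.yml']))
print(args['training'], args['config_file'])  # True config.yml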


def get_config_from_weights(w_path, arch_params, name):
    # Body truncated in the original excerpt; a minimal sketch of the assumed
    # behavior: recover architecture parameters encoded in a standardized
    # weights file name such as '<name>-C:4-D:3-G:32_epoch001.hdf5'.
    parts = os.path.basename(w_path).split(name)[1].split('_')[0].split('-')
    for part in filter(None, parts):
        key, value = part.split(':')
        arch_params[key] = int(value)
    return arch_params


Example #10
def run(config_file, default=False, training=False, prediction=False):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    logger = get_logger(__name__)
    session_type, generator, conf, dataset = setup(config_file, default, training, prediction)

    lr_patch_size = conf['session'][session_type]['patch_size']
    scale = conf['generators'][generator]['x']

    module = _get_module(generator)
    gen = module.make_model(conf['generators'][generator], lr_patch_size)
    if session_type == 'prediction':
        from ISR.predict.predictor import Predictor

        pr_h = Predictor(input_dir=conf['test_sets'][dataset])
        pr_h.get_predictions(gen, conf['weights_paths']['generator'])

    elif session_type == 'training':
        from ISR.train.trainer import Trainer

        hr_patch_size = lr_patch_size * scale
        if conf['default']['feat_ext']:
            from ISR.models.cut_vgg19 import Cut_VGG19

            out_layers = conf['feat_extr']['vgg19']['layers_to_extract']
            f_ext = Cut_VGG19(patch_size=hr_patch_size, layers_to_extract=out_layers)
        else:
            f_ext = None

        if conf['default']['discriminator']:
            from ISR.models.discriminator import Discriminator

            discr = Discriminator(patch_size=hr_patch_size, kernel_size=3)
        else:
            discr = None

        trainer = Trainer(
            generator=gen,
            discriminator=discr,
            feature_extractor=f_ext,
            lr_train_dir=conf['training_sets'][dataset]['lr_train_dir'],
            hr_train_dir=conf['training_sets'][dataset]['hr_train_dir'],
            lr_valid_dir=conf['training_sets'][dataset]['lr_valid_dir'],
            hr_valid_dir=conf['training_sets'][dataset]['hr_valid_dir'],
            loss_weights=conf['loss_weights'],
            dataname=conf['training_sets'][dataset]['data_name'],
            logs_dir=conf['dirs']['logs'],
            weights_dir=conf['dirs']['weights'],
            weights_generator=conf['weights_paths']['generator'],
            weights_discriminator=conf['weights_paths']['discriminator'],
            n_validation=conf['session'][session_type]['n_validation_samples'],
            lr_decay_frequency=conf['session'][session_type]['lr_decay_frequency'],
            lr_decay_factor=conf['session'][session_type]['lr_decay_factor'],
            T=0.01,
        )
        trainer.train(
            epochs=conf['session'][session_type]['epochs'],
            steps_per_epoch=conf['session'][session_type]['steps_per_epoch'],
            batch_size=conf['session'][session_type]['batch_size'],
        )

    else:
        logger.error('Invalid choice.')
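run() indexes a parsed configuration with a fixed nested shape; the sketch below reconstructs that shape purely from the conf[...] lookups above (every concrete value, and the 'rdn'/'div2k' names, are placeholders):

# Shape of the config that run() expects, inferred from its conf[...] lookups only.
conf = {
    'default': {'feat_ext': True, 'discriminator': True},
    'generators': {'rdn': {'x': 2}},  # 'x' is the upscaling factor
    'feat_extr': {'vgg19': {'layers_to_extract': [5, 9]}},
    'weights_paths': {'generator': None, 'discriminator': None},
    'dirs': {'logs': 'logs', 'weights': 'weights'},
    'loss_weights': {'MSE': 1.0},
    'session': {
        'training': {
            'patch_size': 40, 'epochs': 100, 'steps_per_epoch': 500, 'batch_size': 16,
            'n_validation_samples': 100, 'lr_decay_frequency': 100, 'lr_decay_factor': 0.5,
        },
        'prediction': {'patch_size': 40},
    },
    'training_sets': {
        'div2k': {
            'data_name': 'div2k',
            'lr_train_dir': 'data/lr/train', 'hr_train_dir': 'data/hr/train',
            'lr_valid_dir': 'data/lr/valid', 'hr_valid_dir': 'data/hr/valid',
        },
    },
    'test_sets': {'div2k': 'data/test'},
}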
Example #11
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights={'MSE': 1.0},
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
    ):
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.best_metrics = {}
        self.pretrained_weights_path = {
            'generator': weights_generator,
            'discriminator': weights_discriminator,
        }
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            pretrained_weights_path=self.pretrained_weights_path,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.01,  # note: the later revision in Example #8 passes T=0.0 for validation
        )
        self.logger = get_logger(__name__)