def __init__(self, lr_dir, hr_dir, patch_size, scale, n_validation_samples=None):
    """Set up a handler for paired low-/high-resolution image folders.

    Args:
        lr_dir: folder with the low-resolution images.
        hr_dir: folder with the high-resolution images.
        patch_size: low-resolution patch size; the high-resolution patch
            size is derived as ``patch_size * scale``.
        scale: resolution factor between the LR and HR images.
        n_validation_samples: stored as-is; it is consumed by methods not
            visible here.
    """
    self.folders = {'hr': hr_dir, 'lr': lr_dir}  # image folders
    self.extensions = ('.png', '.jpeg', '.jpg')  # admissible extensions
    self.img_list = {}  # file names; presumably filled by _make_img_list
    self.n_validation_samples = n_validation_samples
    self.scale = scale
    # Patch sizes for both resolutions, keyed by 'lr' / 'hr'.
    self.patch_size = {'lr': patch_size, 'hr': patch_size * self.scale}
    self.logger = get_logger(__name__)
    self._make_img_list()
    # NOTE(review): presumably validates the LR/HR file lists — not visible here.
    self._check_dataset()
def __init__(self, patch_size, layers_to_extract):
    """Prepare a truncated VGG network for square RGB patches.

    Args:
        patch_size: side of the square input; the input shape is
            ``(patch_size, patch_size, 3)``.
        layers_to_extract: layers handed to ``_cut_vgg`` (presumably the
            VGG layers whose outputs are exposed); must be non-empty.

    Raises:
        ValueError: if ``layers_to_extract`` is empty.
    """
    self.patch_size = patch_size
    self.input_shape = (patch_size, patch_size, 3)
    self.layers_to_extract = layers_to_extract
    self.logger = get_logger(__name__)

    # Refuse to build a network that would extract nothing.
    if len(self.layers_to_extract) < 1:
        message = 'Invalid VGG instantiation: extracted layer must be > 0'
        self.logger.error(message)
        raise ValueError(message)

    self._cut_vgg()
def __init__(self, input_dir, output_dir='./data/output', verbose=True):
    """Collect the input images and prepare the results folder.

    Args:
        input_dir: folder holding the images to process. Its final path
            component names the results sub-folder.
        output_dir: root folder for results; output goes to
            ``output_dir / <name of input_dir>``.
        verbose: when False the logger level is raised to 40
            (``logging.ERROR``), hiding info messages.

    Raises:
        ValueError: if ``input_dir`` holds no regular file with an admitted
            extension.
    """
    self.input_dir = Path(input_dir)
    self.data_name = self.input_dir.name
    self.output_dir = Path(output_dir) / self.data_name
    self.logger = get_logger(__name__)
    if not verbose:
        self.logger.setLevel(40)  # 40 == logging.ERROR
    self.extensions = ('.jpeg', '.jpg', '.png')  # file extensions that are admitted
    # Match extensions case-insensitively (e.g. '.JPG'), skip directories
    # that merely look like images, and sort for a deterministic order
    # (iterdir() yields entries in arbitrary order).
    self.img_ls = sorted(
        f
        for f in self.input_dir.iterdir()
        if f.is_file() and f.suffix.lower() in self.extensions
    )
    if not self.img_ls:
        self.logger.error('No valid image files found (check config file).')
        raise ValueError('No valid image files found (check config file).')
    # Create results folder
    if not self.output_dir.exists():
        self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
        # exist_ok: the folder may appear between exists() and mkdir().
        self.output_dir.mkdir(parents=True, exist_ok=True)
def __init__(self, input_dir, output_dir='./data/output', verbose=True):
    """Collect the input image file names and prepare the results folder.

    Args:
        input_dir: folder holding the images to process. Its last path
            component names the results sub-folder.
        output_dir: root folder for results; output goes to
            ``os.path.join(output_dir, <name of input_dir>)``.
        verbose: when False the logger level is raised to 40
            (``logging.ERROR``), hiding info messages.

    Raises:
        ValueError: if ``input_dir`` holds no regular file with an admitted
            extension.
    """
    self.input_dir = input_dir
    # normpath strips a trailing separator so basename is not ''.
    self.data_name = os.path.basename(os.path.normpath(self.input_dir))
    self.output_dir = os.path.join(output_dir, self.data_name)
    file_ls = os.listdir(self.input_dir)
    self.logger = get_logger(__name__)
    if not verbose:
        self.logger.setLevel(40)  # 40 == logging.ERROR
    self.extensions = ('.jpeg', '.jpg', '.png')  # file extensions that are admitted
    # Match extensions case-insensitively (e.g. '.JPG'), skip directories
    # that merely look like images, and sort for a deterministic order
    # (os.listdir() returns entries in arbitrary order).
    self.img_ls = sorted(
        name
        for name in file_ls
        if name.lower().endswith(self.extensions)
        and os.path.isfile(os.path.join(self.input_dir, name))
    )
    if not self.img_ls:
        self.logger.error('No valid image files found (check config file).')
        raise ValueError('No valid image files found (check config file).')
    # Create results folder
    if not os.path.exists(self.output_dir):
        self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
        os.makedirs(self.output_dir, exist_ok=True)
def __init__(
    self,
    generator,
    weights_dir,
    logs_dir,
    lr_train_dir,
    feature_extractor=None,
    discriminator=None,
    dataname=None,
    weights_generator=None,
    weights_discriminator=None,
    fallback_save_every_n_epochs=2,
    max_n_other_weights=5,
    max_n_best_weights=5,
):
    """Store training-session bookkeeping: directories, networks, naming.

    Args:
        generator: generator network of the session.
        weights_dir: folder for saved weights (stored as a ``Path``).
        logs_dir: folder for logs (stored as a ``Path``).
        lr_train_dir: low-resolution training folder (stored as a ``Path``).
        feature_extractor: optional feature-extractor network.
        discriminator: optional discriminator network.
        dataname: optional dataset name.
        weights_generator: optional path to pretrained generator weights;
            stored as a ``Path``, or None when falsy.
        weights_discriminator: same, for the discriminator.
        fallback_save_every_n_epochs: stored as-is for methods not shown here.
        max_n_other_weights: stored as-is for methods not shown here.
        max_n_best_weights: stored as-is for methods not shown here.
    """
    # Create the logger before any helper method runs, so helpers that log
    # do not fail on a missing attribute.
    self.logger = get_logger(__name__)
    self.generator = generator
    self.dirs = {'logs': Path(logs_dir), 'weights': Path(weights_dir)}
    self.feature_extractor = feature_extractor
    self.discriminator = discriminator
    self.dataname = dataname
    self.pretrained_generator_weights = (
        Path(weights_generator) if weights_generator else None
    )
    self.pretrained_discriminator_weights = (
        Path(weights_discriminator) if weights_discriminator else None
    )
    self.fallback_save_every_n_epochs = fallback_save_every_n_epochs
    self.lr_dir = Path(lr_train_dir)
    self.basename = self._make_basename()
    # NOTE(review): basename=None is passed although self.basename was just
    # computed — confirm this is intended.
    self.session_id = self.get_session_id(basename=None)
    self.session_config_name = 'session_config.yml'
    self.callback_paths = self._make_callback_paths()
    self.weights_name = self._weights_name(self.callback_paths)
    self.best_metrics = {}
    self.since_last_epoch = 0
    self.max_n_other_weights = max_n_other_weights
    self.max_n_best_weights = max_n_best_weights
def __init__(
    self,
    generator,
    weights_dir,
    logs_dir,
    lr_train_dir,
    feature_extractor=None,
    discriminator=None,
    dataname=None,
    pretrained_weights_path=None,
    fallback_save_every_n_epochs=2,
):
    """Store training-session bookkeeping: directories, networks, state.

    Args:
        generator: generator network of the session.
        weights_dir: folder for saved weights.
        logs_dir: folder for logs.
        lr_train_dir: low-resolution training folder.
        feature_extractor: optional feature-extractor network.
        discriminator: optional discriminator network.
        dataname: optional dataset name.
        pretrained_weights_path: optional dict of pretrained weight paths;
            defaults to a new empty dict for each instance.
        fallback_save_every_n_epochs: stored as-is for methods not shown here.
    """
    self.generator = generator
    self.dirs = {'logs': logs_dir, 'weights': weights_dir}
    self.feature_extractor = feature_extractor
    self.discriminator = discriminator
    self.dataname = dataname
    # A `{}` default would be one dict shared by every instance, so changes
    # made through one helper would leak into the others.
    if pretrained_weights_path is None:
        pretrained_weights_path = {}
    self.pretrained_weights_path = pretrained_weights_path
    self.lr_dir = lr_train_dir
    self.best_metrics = {}
    self.fallback_save_every_n_epochs = fallback_save_every_n_epochs
    self.since_last_epoch = 0
    self.logger = get_logger(__name__)
def __init__(
    self,
    generator,
    discriminator,
    feature_extractor,
    lr_train_dir,
    hr_train_dir,
    lr_valid_dir,
    hr_valid_dir,
    loss_weights={'generator': 1.0, 'discriminator': 0.003, 'feature_extractor': 1 / 12},
    log_dirs={'logs': 'logs', 'weights': 'weights'},
    fallback_save_every_n_epochs=2,
    dataname=None,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=None,
    flatness={'min': 0.0, 'increase_frequency': None, 'increase': 0.0, 'max': 0.0},
    learning_rate={'initial_value': 0.0004, 'decay_frequency': 100, 'decay_factor': 0.5},
    adam_optimizer={'beta1': 0.9, 'beta2': 0.999, 'epsilon': None},
    losses={
        'generator': 'mae',
        'discriminator': 'binary_crossentropy',
        'feature_extractor': 'mse',
    },
    metrics={'generator': 'PSNR_Y'},
):
    """Wire the networks, the training helper and the train/valid data handlers.

    Args:
        generator: generator network; its ``scale`` and ``patch_size`` set
            the LR patch size and upscaling factor used everywhere below.
        discriminator: discriminator network (may be None).
        feature_extractor: feature-extractor network (may be None).
        lr_train_dir, hr_train_dir: training image folders.
        lr_valid_dir, hr_valid_dir: validation image folders.
        loss_weights, flatness, learning_rate, adam_optimizer, losses: dicts
            stored on the instance for methods not shown here.
        log_dirs: dict with 'logs' and 'weights' folders for the helper.
        fallback_save_every_n_epochs: forwarded to TrainerHelper.
        dataname: dataset name forwarded to TrainerHelper.
        weights_generator, weights_discriminator: pretrained weight paths
            forwarded to TrainerHelper.
        n_validation: number of validation samples for the valid handler.
        metrics: dict whose 'generator' entry may be the name 'PSNR_Y' or
            'PSNR'; it is copied, then the name is resolved to the function.

    NOTE(review): the other dict defaults are shared between calls and are
    stored by reference — confirm no method mutates them.
    """
    self.logger = get_logger(__name__)
    self.generator = generator
    self.discriminator = discriminator
    self.feature_extractor = feature_extractor
    self.scale = generator.scale
    self.lr_patch_size = generator.patch_size
    self.learning_rate = learning_rate
    self.loss_weights = loss_weights
    self.weights_generator = weights_generator
    self.weights_discriminator = weights_discriminator
    self.adam_optimizer = adam_optimizer
    self.dataname = dataname
    self.flatness = flatness
    self.n_validation = n_validation
    self.losses = losses
    self.log_dirs = log_dirs
    # Copy before resolving metric names to functions, so neither the
    # caller's dict nor the shared default argument is modified. The
    # parameter name is rebound so the locals() snapshot below still records
    # the resolved metrics.
    metrics = dict(metrics)
    self.metrics = metrics
    if self.metrics['generator'] == 'PSNR_Y':
        self.metrics['generator'] = PSNR_Y
    elif self.metrics['generator'] == 'PSNR':
        self.metrics['generator'] = PSNR
    self._parameters_sanity_check()
    self.model = self._combine_networks()
    self.settings = {}
    # locals() captures `self` and every constructor argument; do not add
    # other local variables above this line or they will be recorded too.
    self.settings['training_parameters'] = locals()
    self.settings['training_parameters']['lr_patch_size'] = self.lr_patch_size
    self.settings = self.update_training_config(self.settings)
    self.helper = TrainerHelper(
        generator=self.generator,
        weights_dir=log_dirs['weights'],
        logs_dir=log_dirs['logs'],
        lr_train_dir=lr_train_dir,
        feature_extractor=self.feature_extractor,
        discriminator=self.discriminator,
        dataname=dataname,
        weights_generator=self.weights_generator,
        weights_discriminator=self.weights_discriminator,
        fallback_save_every_n_epochs=fallback_save_every_n_epochs,
    )
    self.train_dh = DataHandler(
        lr_dir=lr_train_dir,
        hr_dir=hr_train_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=None,
    )
    self.valid_dh = DataHandler(
        lr_dir=lr_valid_dir,
        hr_dir=hr_valid_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=n_validation,
    )
def __init__(
    self,
    generator,
    discriminator,
    feature_extractor,
    lr_train_dir,
    hr_train_dir,
    lr_valid_dir,
    hr_valid_dir,
    learning_rate=0.0004,
    loss_weights=None,
    logs_dir='logs',
    weights_dir='weights',
    dataname=None,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=None,
    T=0.01,
    lr_decay_frequency=100,
    lr_decay_factor=0.5,
    fallback_save_every_n_epochs=2,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=0.00001,
):
    """Wire the networks, the training helper and the train/valid data handlers.

    Args:
        generator: generator network; its ``scale`` and ``patch_size`` set
            the LR patch size and upscaling factor.
        discriminator, feature_extractor: optional networks; when given,
            their ``patch_size`` must equal
            ``generator.patch_size * generator.scale``.
        lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir: image folders.
        learning_rate, lr_decay_frequency, lr_decay_factor, beta_1, beta_2,
            epsilon: stored on the instance for methods not shown here.
        loss_weights: dict of loss weights; defaults to ``{'MSE': 1.0}``.
        logs_dir, weights_dir, dataname, weights_generator,
            weights_discriminator, fallback_save_every_n_epochs: forwarded
            to TrainerHelper.
        n_validation: number of validation samples for the valid handler.
        T: forwarded to the training DataHandler; validation uses 0.0.

    Raises:
        AssertionError: on a patch-size mismatch. Raised explicitly rather
            than with ``assert`` so the check is not stripped under
            ``python -O``; the exception type is unchanged for callers.
    """
    if discriminator and generator.patch_size * generator.scale != discriminator.patch_size:
        raise AssertionError(
            'discriminator.patch_size must equal generator.patch_size * generator.scale'
        )
    if feature_extractor and generator.patch_size * generator.scale != feature_extractor.patch_size:
        raise AssertionError(
            'feature_extractor.patch_size must equal generator.patch_size * generator.scale'
        )
    # A dict default would be shared by every Trainer; build a fresh one.
    if loss_weights is None:
        loss_weights = {'MSE': 1.0}
    self.generator = generator
    self.discriminator = discriminator
    self.feature_extractor = feature_extractor
    self.scale = generator.scale
    self.lr_patch_size = generator.patch_size
    self.learning_rate = learning_rate
    self.loss_weights = loss_weights
    self.weights_generator = weights_generator
    self.weights_discriminator = weights_discriminator
    self.lr_decay_factor = lr_decay_factor
    self.lr_decay_frequency = lr_decay_frequency
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.dataname = dataname
    self.T = T
    self.n_validation = n_validation
    self.helper = TrainerHelper(
        generator=self.generator,
        weights_dir=weights_dir,
        logs_dir=logs_dir,
        lr_train_dir=lr_train_dir,
        feature_extractor=self.feature_extractor,
        discriminator=self.discriminator,
        dataname=dataname,
        weights_generator=self.weights_generator,
        weights_discriminator=self.weights_discriminator,
        fallback_save_every_n_epochs=fallback_save_every_n_epochs,
    )
    self.model = self._combine_networks()
    self.train_dh = DataHandler(
        lr_dir=lr_train_dir,
        hr_dir=hr_train_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=None,
        T=T,
    )
    self.valid_dh = DataHandler(
        lr_dir=lr_valid_dir,
        hr_dir=hr_valid_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=n_validation,
        T=0.0,
    )
    self.logger = get_logger(__name__)
    self.settings = self.get_training_config()
import os import argparse import numpy as np import yaml from ISR.utils.logger import get_logger logger = get_logger(__name__) def _get_parser(): parser = argparse.ArgumentParser() parser.add_argument('--prediction', action='store_true', dest='prediction') parser.add_argument('--training', action='store_true', dest='training') parser.add_argument('--summary', action='store_true', dest='summary') parser.add_argument('--default', action='store_true', dest='default') parser.add_argument('--config', action='store', dest='config_file') return parser def parse_args(): """ Parse CLI arguments. """ parser = _get_parser() args = vars(parser.parse_args()) if args['prediction'] and args['training']: logger.error('Select only prediction OR training.') raise ValueError('Select only prediction OR training.') return args def get_config_from_weights(w_path, arch_params, name):
def run(config_file, default=False, training=False, prediction=False):
    """Run a prediction or training session described by a config file.

    Args:
        config_file: config path handed to ``setup``.
        default, training, prediction: flags handed to ``setup``, which
            returns the session type, generator name, config dict and
            dataset name.

    Any session type other than 'prediction' or 'training' is only logged
    as an error; nothing is raised.
    """
    # Suppress TensorFlow's C++ log output.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    logger = get_logger(__name__)

    session_type, generator, conf, dataset = setup(config_file, default, training, prediction)
    session_conf = conf['session'][session_type]
    lr_patch_size = session_conf['patch_size']
    scale = conf['generators'][generator]['x']
    gen = _get_module(generator).make_model(conf['generators'][generator], lr_patch_size)

    if session_type not in ('prediction', 'training'):
        logger.error('Invalid choice.')
        return

    if session_type == 'prediction':
        from ISR.predict.predictor import Predictor

        predictor = Predictor(input_dir=conf['test_sets'][dataset])
        predictor.get_predictions(gen, conf['weights_paths']['generator'])
        return

    # Training session from here on.
    from ISR.train.trainer import Trainer

    hr_patch_size = lr_patch_size * scale

    f_ext = None
    if conf['default']['feat_ext']:
        from ISR.models.cut_vgg19 import Cut_VGG19

        f_ext = Cut_VGG19(
            patch_size=hr_patch_size,
            layers_to_extract=conf['feat_extr']['vgg19']['layers_to_extract'],
        )

    discr = None
    if conf['default']['discriminator']:
        from ISR.models.discriminator import Discriminator

        discr = Discriminator(patch_size=hr_patch_size, kernel_size=3)

    data_conf = conf['training_sets'][dataset]
    trainer = Trainer(
        generator=gen,
        discriminator=discr,
        feature_extractor=f_ext,
        lr_train_dir=data_conf['lr_train_dir'],
        hr_train_dir=data_conf['hr_train_dir'],
        lr_valid_dir=data_conf['lr_valid_dir'],
        hr_valid_dir=data_conf['hr_valid_dir'],
        loss_weights=conf['loss_weights'],
        dataname=data_conf['data_name'],
        logs_dir=conf['dirs']['logs'],
        weights_dir=conf['dirs']['weights'],
        weights_generator=conf['weights_paths']['generator'],
        weights_discriminator=conf['weights_paths']['discriminator'],
        n_validation=session_conf['n_validation_samples'],
        lr_decay_frequency=session_conf['lr_decay_frequency'],
        lr_decay_factor=session_conf['lr_decay_factor'],
        T=0.01,
    )
    trainer.train(
        epochs=session_conf['epochs'],
        steps_per_epoch=session_conf['steps_per_epoch'],
        batch_size=session_conf['batch_size'],
    )
def __init__(
    self,
    generator,
    discriminator,
    feature_extractor,
    lr_train_dir,
    hr_train_dir,
    lr_valid_dir,
    hr_valid_dir,
    learning_rate=0.0004,
    loss_weights=None,
    logs_dir='logs',
    weights_dir='weights',
    dataname=None,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=None,
    T=0.01,
    lr_decay_frequency=100,
    lr_decay_factor=0.5,
):
    """Wire the networks, the training helper and the train/valid data handlers.

    Args:
        generator: generator network; its ``scale`` and ``patch_size`` set
            the LR patch size and upscaling factor.
        discriminator, feature_extractor: optional networks; when given,
            their ``patch_size`` must equal
            ``generator.patch_size * generator.scale``.
        lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir: image folders.
        learning_rate, lr_decay_frequency, lr_decay_factor: stored on the
            instance for methods not shown here.
        loss_weights: dict of loss weights; defaults to ``{'MSE': 1.0}``.
        logs_dir, weights_dir, dataname: forwarded to TrainerHelper.
        weights_generator, weights_discriminator: collected into
            ``self.pretrained_weights_path`` and shared with the helper.
        n_validation: number of validation samples for the valid handler.
        T: forwarded to the training DataHandler.

    Raises:
        AssertionError: on a patch-size mismatch. Raised explicitly rather
            than with ``assert`` so the check is not stripped under
            ``python -O``; the exception type is unchanged for callers.
    """
    if discriminator and generator.patch_size * generator.scale != discriminator.patch_size:
        raise AssertionError(
            'discriminator.patch_size must equal generator.patch_size * generator.scale'
        )
    if feature_extractor and generator.patch_size * generator.scale != feature_extractor.patch_size:
        raise AssertionError(
            'feature_extractor.patch_size must equal generator.patch_size * generator.scale'
        )
    # A dict default would be shared by every Trainer; build a fresh one.
    if loss_weights is None:
        loss_weights = {'MSE': 1.0}
    self.generator = generator
    self.discriminator = discriminator
    self.feature_extractor = feature_extractor
    self.scale = generator.scale
    self.lr_patch_size = generator.patch_size
    self.learning_rate = learning_rate
    self.loss_weights = loss_weights
    self.best_metrics = {}
    self.pretrained_weights_path = {
        'generator': weights_generator,
        'discriminator': weights_discriminator,
    }
    self.lr_decay_factor = lr_decay_factor
    self.lr_decay_frequency = lr_decay_frequency
    self.helper = TrainerHelper(
        generator=self.generator,
        weights_dir=weights_dir,
        logs_dir=logs_dir,
        lr_train_dir=lr_train_dir,
        feature_extractor=self.feature_extractor,
        discriminator=self.discriminator,
        dataname=dataname,
        pretrained_weights_path=self.pretrained_weights_path,
    )
    self.model = self._combine_networks()
    self.train_dh = DataHandler(
        lr_dir=lr_train_dir,
        hr_dir=hr_train_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=None,
        T=T,
    )
    # NOTE(review): validation uses a fixed T=0.01 regardless of the `T`
    # argument — confirm this is intended.
    self.valid_dh = DataHandler(
        lr_dir=lr_valid_dir,
        hr_dir=hr_valid_dir,
        patch_size=self.lr_patch_size,
        scale=self.scale,
        n_validation_samples=n_validation,
        T=0.01,
    )
    self.logger = get_logger(__name__)