def __init__(self, scale, channel, **kwargs):
    """Wrap the `ran2.RAN` network as a model.

    Args:
        scale: upscale factor.
        channel: number of image channels (forwarded as `n_colors`).
        kwargs: extra options consumed by `Config` and the base class.
    """
    cfg = Config(kwargs)
    cfg.scale = [scale]
    cfg.n_colors = channel
    self.rgb_range = cfg.rgb_range
    # build the network before the base constructor runs
    self.ran = ran2.RAN(cfg)
    super(RAN, self).__init__(channel=channel, scale=scale, **kwargs)
def __init__(self, scale, channel, **kwargs):
    """Wrap the `frn.FRN_UPDOWN` network as a model.

    Args:
        scale: upscale factor.
        channel: number of image channels (forwarded as `n_colors`).
        kwargs: extra options consumed by `Config` and the base class.
    """
    cfg = Config(kwargs)
    cfg.scale = [scale]
    cfg.n_colors = channel
    self.rgb_range = cfg.rgb_range
    # build the network before the base constructor runs
    self.frn = frn.FRN_UPDOWN(cfg)
    super(FRN, self).__init__(channel=channel, scale=scale, **kwargs)
def __init__(self, scale, **kwargs):
    """Wrap the `msrn.MSRN` network as a 3-channel model.

    Args:
        scale: upscale factor.
        kwargs: extra options consumed by `Config`.
    """
    super(MSRN, self).__init__(scale, 3)
    cfg = Config(kwargs)
    cfg.scale = [scale]
    self.rgb_range = cfg.rgb_range
    self.msrn = msrn.MSRN(cfg)
    # plain Adam on every trainable parameter, fixed 1e-4 learning rate
    self.opt = torch.optim.Adam(self.trainable_variables(), 1e-4)
def __init__(self, scale, **kwargs):
    """Wrap the `edsr.EDSR` network as a 3-channel model.

    Args:
        scale: upscale factor.
        kwargs: extra options consumed by `Config`.
    """
    super(EDSR, self).__init__(scale, 3)
    cfg = Config(kwargs)
    cfg.scale = [scale]
    self.rgb_range = cfg.rgb_range
    self.edsr = edsr.EDSR(cfg)
    # plain Adam on every trainable parameter, fixed 1e-4 learning rate
    self.opt = torch.optim.Adam(self.trainable_variables(), 1e-4)
def __init__(self, scale, channel, **kwargs):
    """Wrap the `edrn.EDRN` network as a model.

    Args:
        scale: upscale factor.
        channel: number of image channels (forwarded as `n_colors`).
        kwargs: extra options consumed by `Config` and the base class.
    """
    cfg = Config(kwargs)
    cfg.scale = [scale]
    cfg.n_colors = channel
    self.rgb_range = cfg.rgb_range
    # build the network before the base constructor runs
    self.edrn = edrn.EDRN(cfg)
    super(EDRN, self).__init__(channel=channel, scale=scale, **kwargs)
def dummy_test_config():
    """Smoke-test `Config`: kwargs init, update, attribute set, merge, pop."""
    cfg = Config(a=1, b=2)
    cfg.update(a=2, b=3)
    cfg.a = 9
    # merging another Config should add new keys and overwrite existing ones
    cfg.update(Config(b=6, f=5))
    cfg.pop('b')
    print(cfg)
def __init__(self, scale, channel, **kwargs):
    """Wrap the `ran2.RAN` network as a model (torch backend).

    Args:
        scale: upscale factor.
        channel: number of image channels (forwarded as `n_colors`).
        kwargs: extra options consumed by `Config`.
    """
    super(RAN, self).__init__(channel=channel, scale=scale)
    cfg = Config(kwargs)
    cfg.scale = [scale]
    cfg.n_colors = channel
    self.rgb_range = cfg.rgb_range
    self.ran = ran2.RAN(cfg)
    # plain Adam on every trainable parameter, fixed 1e-4 learning rate
    self.opt = torch.optim.Adam(self.trainable_variables(), 1e-4)
def __init__(self, scale, channel, **kwargs):
    """Wrap the `edrn.EDRN` network as a model (torch backend).

    Args:
        scale: upscale factor.
        channel: number of image channels (forwarded as `n_colors`).
        kwargs: extra options consumed by `Config`.
    """
    super(EDRN, self).__init__(scale, channel)
    args = Config(kwargs)
    args.scale = [scale]
    args.n_colors = channel
    self.rgb_range = args.rgb_range
    self.edrn = edrn.EDRN(args)
    self.adam = torch.optim.Adam(self.trainable_variables(), 1e-4)
    # Consistency fix: sibling models (MSRN/EDSR/RAN) expose their optimizer
    # as `self.opt`. Keep `self.adam` for backward compatibility and add the
    # common alias so generic code can locate the optimizer either way.
    self.opt = self.adam
def check_args(opt):
    """Validate command-line options.

    Merges an extra config file given via `opt.c` into `opt`, then verifies
    that every mandatory option is present and non-empty.

    Args:
        opt: a dict-like option container (`Config`).

    Raises:
        ValueError: if a required option is absent or empty.
    """
    if opt.c:
        opt.update(Config(opt.c))
    for key in ('model',):
        if key not in opt or not opt.get(key):
            raise ValueError('--' + key + ' must be set')
def __init__(self, name='gan', patch_size=32, z_dim=128, init_filter=512,
             linear=False, norm_g=None, norm_d=None, use_bias=False,
             optimizer=None, arch=None, nd_iter=1, **kwargs):
    """Build a GAN model with a selectable generator/discriminator pair.

    Args:
        name: model name.
        patch_size: spatial size of generated/discriminated patches.
        z_dim: dimension of the latent noise vector.
        init_filter: filters of the first generator layer.
        linear: whether the generator starts with a linear projection.
        norm_g: generator normalization spec; a string containing
            'bn'/'batch' enables batch norm, 'sn'/'spectral' spectral norm.
        norm_d: discriminator normalization spec (forwarded as-is).
        use_bias: whether conv layers use bias.
        optimizer: optimizer config; defaults to Adam.
        arch: 'dcgan' (default) or 'resnet'.
        nd_iter: discriminator updates per generator update.
        kwargs: forwarded to the base class.
    """
    super(GAN, self).__init__(**kwargs)
    self.name = name
    self._trainer = GanTrainer
    self.output_size = patch_size
    self.z_dim = z_dim
    self.init_filter = init_filter
    self.linear = linear
    self.bias = use_bias
    self.nd_iter = nd_iter
    # FIX: previously `bn`/`sn` were only assigned inside the isinstance
    # branch, so with the default norm_g=None the attributes never existed.
    self.bn = False
    self.sn = False
    if isinstance(norm_g, str):
        self.bn = np.any([word in norm_g for word in ('bn', 'batch')])
        self.sn = np.any([word in norm_g for word in ('sn', 'spectral')])
    self.d_outputs = []  # (real, fake)
    self.g_outputs = []  # (real, fake)
    self.opt = optimizer
    if self.opt is None:
        self.opt = Config(name='adam')
    # NOTE(review): an unrecognized `arch` silently leaves G/D unset;
    # callers presumably only pass None/'dcgan'/'resnet' — confirm.
    if arch is None or arch == 'dcgan':
        self.G = self.dcgan_g
        self.D = Discriminator.dcgan_d(
            self, [patch_size, patch_size, self.channel],
            norm=norm_d, name_or_scope='D')
    elif arch == 'resnet':
        self.G = self.resnet_g
        self.D = Discriminator.resnet_d(
            self, [patch_size, patch_size, self.channel],
            times_pooling=4, norm=norm_d, name_or_scope='D')
def query_config(self, config, **kwargs):
    """Merge `config` and keyword overrides into the shared state `self.v`.

    Args:
        config: a mapping of run options (may be None/empty).
        kwargs: overrides applied on top of `config`.

    Returns:
        The populated state object `self.v`.
    """
    cfg = Config(config or {})
    cfg.update(kwargs)
    v = self.v
    v.epochs = cfg.epochs or 1  # total epochs
    v.batch_shape = cfg.batch_shape or [1, -1, -1, -1]
    v.steps = cfg.steps or 200
    v.val_steps = cfg.val_steps or -1
    v.lr = cfg.lr or 1e-4  # learning rate
    v.lr_schedule = cfg.lr_schedule
    v.memory_limit = cfg.memory_limit
    v.inference_results_hooks = cfg.inference_results_hooks or []
    v.validate_every_n_epoch = cfg.validate_every_n_epoch or 1
    v.traced_val = cfg.traced_val
    v.ensemble = cfg.ensemble
    v.cuda = cfg.cuda
    v.caching = cfg.caching_dataset
    return v
def query_config(self, config, **kwargs):
    """Merge `config` and keyword overrides into the shared state `self.v`.

    Args:
        config: a mapping of run options (may be None/empty).
        kwargs: overrides applied on top of `config`.

    Returns:
        The populated state object `self.v`.
    """
    cfg = Config(config or {})
    cfg.update(kwargs)
    v = self.v
    v.epochs = cfg.epochs or 1  # total epochs
    v.batch_shape = cfg.batch_shape or [1, -1, -1, -1]
    v.steps = cfg.steps or 200
    v.val_steps = cfg.val_steps or -1
    v.lr = cfg.lr or 1e-4  # learning rate
    v.lr_schedule = cfg.lr_schedule
    v.memory_limit = cfg.memory_limit
    v.inference_results_hooks = cfg.inference_results_hooks or []
    v.validate_every_n_epoch = cfg.validate_every_n_epoch or 1
    v.traced_val = cfg.traced_val
    v.ensemble = cfg.ensemble
    v.cuda = cfg.cuda
    # place tensors on GPU 0 only when requested AND available
    v.map_location = 'cuda:0' if cfg.cuda and torch.cuda.is_available() else 'cpu'
    return v
def main(*args):
    """Compute PSNR/SSIM between images under --input_dir and a reference
    dataset (--dataset), then log the per-item and mean metrics.

    NOTE(review): relies on module-level `opt` and on `.eval()` working on
    the metric tensors — presumably called inside an active TF session.
    """
    if not opt.input_dir:
        raise ValueError("--input_dir is required")
    if not opt.dataset.upper() in DATASETS.keys():
        raise ValueError("--dataset is missing, or can't be found")
    data_ref = DATASETS.get(opt.dataset.upper())
    data = load_folder(opt.input_dir)
    skip = opt.offset  # leading reference frames to skip when comparing
    metric_config = Config(depth=opt.clip, batch=1, scale=1, modcrop=False)
    loader = BasicLoader(data, 'test', metric_config, False)
    ref_loader = BasicLoader(data_ref, 'test', metric_config, False)
    # make sure len(ref_loader) == len(loader)
    loader_iter = loader.make_one_shot_iterator()
    ref_iter = ref_loader.make_one_shot_iterator()
    for ref, _, name in ref_iter:
        name = str(name)
        img, _, _ = next(loader_iter)
        # reduce the batch dimension for video clips
        if img.ndim == 5:
            img = img[0]
        if ref.ndim == 5:
            ref = ref[0]
        if opt.shave:
            # crop borders before computing metrics
            img = shave(img, opt.shave)
            ref = shave(ref, opt.shave)
        if opt.l_only:
            # luminance-only metrics: keep just the Y channel
            img = rgb_to_yuv(img, max_val=255, standard=opt.l_standard)[..., 0:1]
            ref = rgb_to_yuv(ref, max_val=255, standard=opt.l_standard)[..., 0:1]
        if ref.shape[0] - skip != img.shape[0]:
            # align clip lengths when they mismatch (after the skip offset)
            b_min = np.minimum(ref.shape[0] - skip, img.shape[0])
            ref = ref[:b_min + skip, ...]
            img = img[:b_min, ...]
        img = tf.constant(img.astype(np.float32))
        ref = tf.constant(ref.astype(np.float32))
        # metrics can be individually disabled via --no_psnr / --no_ssim
        psnr = tf.reduce_mean(tf.image.psnr(
            ref[skip:], img, 255)).eval() if not opt.no_psnr else 0
        ssim = tf.reduce_mean(tf.image.ssim(
            ref[skip:], img, 255)).eval() if not opt.no_ssim else 0
        tf.logging.info(f'[{name}] PSNR = {psnr}, SSIM = {ssim}')
        # accumulate per-item scores in TF collections for the final mean
        tf.add_to_collection('PSNR', psnr)
        tf.add_to_collection('SSIM', ssim)
    for key in ('PSNR', 'SSIM'):
        mp = np.mean(tf.get_collection(key))
        tf.logging.info(f'Mean {key}: {mp}')
def main():
    """Entry point of the evaluation script: load the model, resolve test
    data (named datasets or glob patterns), then benchmark or infer.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # apply a 2-stage (or master-slave) configuration, master can be
        # override by slave
        model_config_root = Path('Parameters/root.{}'.format(_ext))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path('Parameters/{}.{}'.format(opt.model, _ext))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    model = get_model(flags.model)(**model_params)
    if flags.cuda:
        model.cuda()
    root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        root += '_' + flags.comment
    verbosity = logging.DEBUG if flags.verbose else logging.INFO
    trainer = model.trainer
    datasets = load_datasets(data_config_file)
    try:
        # --test entries may name known datasets...
        test_datas = [datasets[t.upper()] for t in flags.test]
        run_benchmark = True
    except KeyError:
        # ...or be filesystem glob patterns; then no ground truth → infer only
        test_datas = []
        for pattern in flags.test:
            test_data = Dataset(test=_glob_absolute_pattern(pattern),
                                mode='pil-image1', modcrop=False)
            # FIX: was `Path(flags.test)` — `flags.test` is the whole list of
            # patterns, which `Path` can't take; use the current pattern.
            father = Path(pattern)
            # walk up to the nearest existing directory to name the subset
            while not father.is_dir():
                if father.parent == father:
                    break
                father = father.parent
            test_data.name = father.stem
            test_datas.append(test_data)
        run_benchmark = False
    if opt.verbose:
        dump(opt)
    for test_data in test_datas:
        loader_config = Config(convert_to='rgb',
                               feature_callbacks=[],
                               label_callbacks=[],
                               output_callbacks=[],
                               **opt)
        loader_config.batch = 1
        loader_config.subdir = test_data.name
        loader_config.output_callbacks += [
            save_image(root, flags.output_index, flags.auto_rename)]
        if opt.channel == 1:
            loader_config.convert_to = 'gray'
        with trainer(model, root, verbosity, flags.pth) as t:
            if flags.seed is not None:
                t.set_seed(flags.seed)
            loader = QuickLoader(test_data, 'test', loader_config,
                                 n_threads=flags.thread)
            loader_config.epoch = flags.epoch
            if run_benchmark:
                t.benchmark(loader, loader_config)
            else:
                t.infer(loader, loader_config)
def init_loader_config(opt):
    """Derive train/benchmark/infer loader configs from the global options.

    Returns:
        (train_config, benchmark_config, infer_config) tuple of `Config`.
    """
    train_config = Config(**opt, crop='random',
                          feature_callbacks=[], label_callbacks=[])
    benchmark_config = Config(**opt, crop=None,
                              feature_callbacks=[], label_callbacks=[],
                              output_callbacks=[])
    infer_config = Config(**opt, feature_callbacks=[], label_callbacks=[],
                          output_callbacks=[])
    benchmark_config.batch = opt.test_batch or 1
    benchmark_config.steps_per_epoch = -1
    if opt.channel == 1:
        # train on luminance only
        train_config.convert_to = 'gray'
        benchmark_config.convert_to = 'gray'
        if opt.output_color == 'RGB':
            # evaluate in YUV: feed Y to the model, recombine to RGB on output
            benchmark_config.convert_to = 'yuv'
            benchmark_config.feature_callbacks = train_config.feature_callbacks + [to_gray()]
            benchmark_config.label_callbacks = train_config.label_callbacks + [to_gray()]
            benchmark_config.output_callbacks = [to_rgb()]
        benchmark_config.output_callbacks += [save_image(opt.root, opt.output_index)]
        infer_config.update(benchmark_config)
    else:
        train_config.convert_to = 'rgb'
        benchmark_config.convert_to = 'rgb'
        benchmark_config.output_callbacks += [save_image(opt.root, opt.output_index)]
        infer_config.update(benchmark_config)
    if opt.add_custom_callbacks is not None:
        # callbacks are referenced by NAME and resolved in this module's globals
        for fn in opt.add_custom_callbacks:
            train_config.feature_callbacks += [globals()[fn]]
            benchmark_config.feature_callbacks += [globals()[fn]]
            infer_config.feature_callbacks += [globals()[fn]]
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    # modcrop: A boolean to specify whether to crop the edge of images to be divisible
    # by `scale`. It's useful when to provide batches with original shapes.
    infer_config.modcrop = False
    return train_config, benchmark_config, infer_config
def main(*args):
    """Entry point of the TF training script: assemble options from flags and
    config files, build the model, then fit, benchmark, infer and export.
    """
    flags = tf.flags.FLAGS
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    check_args(opt)
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _suffix in ('json', 'yaml'):  # for compatibility
        # apply a 2-stage (or master-slave) configuration, master can be override by slave
        model_config_root = Path('parameters/{}.{}'.format('root', _suffix))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path('parameters/{}.{}'.format(opt.model, _suffix))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))
    model_params = opt.get(opt.model)
    opt.update(model_params)
    model = get_model(opt.model)(**model_params)
    root = '{}/{}'.format(opt.save_dir, model.name)
    if opt.comment:
        root += '_' + opt.comment
    opt.root = root
    verbosity = tf.logging.DEBUG if opt.v else tf.logging.INFO
    # map model to trainer, ~~manually~~ automatically, by setting _trainer attribute in models
    trainer = model.trainer
    train_data, test_data, infer_data = fetch_datasets(data_config_file, opt)
    train_config, test_config, infer_config = init_loader_config(opt)
    test_config.subdir = test_data.name
    infer_config.subdir = 'infer'
    # start fitting!
    dump(opt)
    with trainer(model, root, verbosity) as t:
        # prepare loader
        loader = partial(QuickLoader, n_threads=opt.threads)
        train_loader = loader(train_data, 'train', train_config, augmentation=True)
        val_loader = loader(train_data, 'val', train_config, crop='center',
                            steps_per_epoch=1)
        test_loader = loader(test_data, 'test', test_config)
        infer_loader = loader(infer_data, 'infer', infer_config)
        # fit
        t.fit([train_loader, val_loader], train_config)
        # validate
        t.benchmark(test_loader, test_config)
        # do inference
        t.infer(infer_loader, infer_config)
        if opt.export:
            t.export(opt.root + '/exported', opt.freeze)
def main():
    """Entry point of the torch training script: assemble options from CLI
    and config files, build the model, then fit (and optionally export).
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # apply a 2-stage (or master-slave) configuration, master can be
        # override by slave
        model_config_root = Path('Parameters/root.{}'.format(_ext))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path('Parameters/{}.{}'.format(
                opt.model, _ext))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))
    model_params = opt.get(opt.model, {})
    # FIX: apply CLI overrides BEFORE merging into `opt` (previously the
    # merge ran first, so overridden model params never reached `opt`;
    # the evaluation script already uses this order).
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    model = get_model(flags.model)(**model_params)
    if flags.cuda:
        model.cuda()
    root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        root += '_' + flags.comment
    verbosity = logging.DEBUG if flags.verbose else logging.INFO
    trainer = model.trainer
    datasets = load_datasets(data_config_file)
    dataset = datasets[flags.dataset.upper()]
    train_config = Config(crop=opt.train_data_crop,
                          feature_callbacks=[],
                          label_callbacks=[],
                          convert_to='rgb',
                          **opt)
    if opt.channel == 1:
        train_config.convert_to = 'gray'
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    train_config.random_val = not opt.traced_val
    train_config.cuda = flags.cuda
    if opt.verbose:
        dump(opt)
    with trainer(model, root, verbosity, opt.pth) as t:
        if opt.seed is not None:
            t.set_seed(opt.seed)
        tloader = QuickLoader(dataset, 'train', train_config, True,
                              flags.thread)
        vloader = QuickLoader(dataset, 'val', train_config, False,
                              flags.thread, batch=1,
                              crop=opt.val_data_crop,
                              steps_per_epoch=opt.val_num)
        t.fit([tloader, vloader], train_config)
        if opt.export:
            t.export(opt.export)
class SRTrainer(Env):
    """Trainer for super-resolution models (torch backend).

    Drives the fit / benchmark / infer loops; per-run state is kept in `v`.

    NOTE(review): `v` is a mutable CLASS attribute, shared by every
    SRTrainer instance — presumably only one trainer lives per process;
    confirm before instantiating several.
    """
    v = Config()

    def query_config(self, config, **kwargs):
        """Merge `config` and keyword overrides into the shared state `self.v`."""
        assert isinstance(config, Config)
        config.update(kwargs)
        self.v.epoch = config.epoch  # current epoch
        self.v.epochs = config.epochs  # total epochs
        self.v.lr = config.lr  # learning rate
        self.v.lr_schedule = config.lr_schedule
        self.v.memory_limit = config.memory_limit
        self.v.feature_callbacks = config.feature_callbacks or []
        self.v.label_callbacks = config.label_callbacks or []
        self.v.output_callbacks = config.output_callbacks or []
        self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
        self.v.subdir = config.subdir
        self.v.random_val = config.random_val
        self.v.ensemble = config.ensemble
        self.v.cuda = config.cuda
        self.v.map_location = 'cuda:0' if config.cuda and torch.cuda.is_available() else 'cpu'
        return self.v

    def fit_init(self) -> bool:
        """Restore the latest checkpoint; return False if training is done."""
        v = self.v
        v.epoch = self._restore()
        if v.epoch >= v.epochs:
            self._logger.info(f'Found pre-trained epoch {v.epoch}>=target {v.epochs},'
                              ' quit fitting.')
            return False
        self._logger.info('Fitting: {}'.format(self.model.name.upper()))
        v.writer = Summarizer(str(self._logd), self.model.name)
        return True

    def fit_close(self):
        # flush all pending summaries to disk
        if isinstance(self.v.writer, Summarizer):
            self.v.writer.close()
        self._logger.info(f'Training {self.model.name.upper()} finished.')

    def fit(self, loaders, config, **kwargs):
        """Train the model.

        Args:
            loaders: (train_loader, val_loader) pair.
            config: training configuration (`Config`).
            kwargs: overrides for the same keys in config.
        """
        v = self.query_config(config, **kwargs)
        v.train_loader, v.val_loader = loaders
        if not self.fit_init():
            return
        mem = v.memory_limit
        for epoch in range(self.last_epoch + 1, v.epochs + 1):
            v.epoch = epoch
            train_iter = v.train_loader.make_one_shot_iterator(mem, shuffle=True)
            if hasattr(v.train_loader, 'prefetch'):
                v.train_loader.prefetch(mem)
            date = time.strftime('%Y-%m-%d %T', time.localtime())
            v.avg_meas = {}
            if v.lr_schedule and callable(v.lr_schedule):
                v.lr = v.lr_schedule(steps=v.epoch)
            print('| {} | Epoch: {}/{} | LR: {:.2g} |'.format(
                date, v.epoch, v.epochs, v.lr))
            with tqdm.tqdm(train_iter, unit='batch', ascii=True) as r:
                self.model.to_train()
                for items in r:
                    label, feature, name, post = items[:4]
                    self.fn_train_each_step(label, feature, name, post)
                    r.set_postfix(v.loss)
            for _k, _v in v.avg_meas.items():
                _v = np.mean(_v)
                if isinstance(self.v.writer, Summarizer):
                    v.writer.scalar(_k, _v, step=v.epoch, collection='train')
                print('| Epoch average {} = {:.6f} |'.format(_k, _v))
            if v.epoch % v.validate_every_n_epoch == 0:
                # Hard-coded memory limitation for validating
                self.benchmark(v.val_loader, v, memory_limit='1GB')
            self._save_model(v.epoch)
        self.fit_close()

    def fn_train_each_step(self, label=None, feature=None, name=None, post=None):
        """Run one training step and accumulate running losses in `v.avg_meas`."""
        v = self.v
        for fn in v.feature_callbacks:
            feature = fn(feature, name=name)
        for fn in v.label_callbacks:
            label = fn(label, name=name)
        feature = to_tensor(feature, v.cuda)
        label = to_tensor(label, v.cuda)
        loss = self.model.train([feature], [label], v.lr)
        for _k, _v in loss.items():
            v.avg_meas[_k] = \
                v.avg_meas[_k] + [_v] if v.avg_meas.get(_k) else [_v]
            loss[_k] = '{:08.5f}'.format(_v)
        v.loss = loss

    def benchmark(self, loader, config, **kwargs):
        """Benchmark/validate the model.

        Args:
            loader: a loader for enumerating LR images
            config: benchmark configuration, an instance of `Util.Config.Config`
            kwargs: additional arguments to override the same ones in config.
        """
        v = self.query_config(config, **kwargs)
        v.color_format = loader.color_format
        self._restore(config.epoch, v.map_location)
        v.mean_metrics = {}
        v.loader = loader
        it = v.loader.make_one_shot_iterator(v.memory_limit, shuffle=v.random_val)
        self.model.to_eval()
        for items in tqdm.tqdm(it, 'Test', ascii=True):
            label, feature, name, post = items[:4]
            with torch.no_grad():
                self.fn_benchmark_each_step(label, feature, name, post)
        for _k, _v in v.mean_metrics.items():
            _v = np.mean(_v)
            if isinstance(self.v.writer, Summarizer):
                v.writer.scalar(_k, _v, step=v.epoch, collection='eval')
            print('{}: {:.6f}'.format(_k, _v), end=', ')
        print('')

    def fn_benchmark_each_step(self, label=None, feature=None, name=None, post=None):
        """Evaluate one batch and collect its metrics into `v.mean_metrics`."""
        v = self.v
        origin_feat = feature
        for fn in v.feature_callbacks:
            feature = fn(feature, name=name)
        for fn in v.label_callbacks:
            label = fn(label, name=name)
        feature = to_tensor(feature, v.cuda)
        label = to_tensor(label, v.cuda)
        with torch.set_grad_enabled(False):
            outputs, metrics = self.model.eval([feature], [label], epoch=v.epoch)
            for _k, _v in metrics.items():
                if _k not in v.mean_metrics:
                    v.mean_metrics[_k] = []
                v.mean_metrics[_k] += [_v]
        outputs = [from_tensor(x) for x in outputs]
        for fn in v.output_callbacks:
            outputs = fn(outputs, input=origin_feat, label=label, name=name,
                         mode=v.color_format, subdir=v.subdir)

    def infer(self, loader, config, **kwargs):
        """Infer SR images.

        Args:
            loader: a loader for enumerating LR images
            config: inferring configuration, an instance of `Util.Config.Config`
            kwargs: additional arguments to override the same ones in config.
        """
        v = self.query_config(config, **kwargs)
        v.color_format = loader.color_format
        self._restore(config.epoch, v.map_location)
        it = loader.make_one_shot_iterator()
        # FIX: was `hasattr(it, 'len')`, which is (almost) never true —
        # `len()` requires `__len__`, so the empty-iterator early return and
        # the log line below were unreachable.
        if hasattr(it, '__len__'):
            if len(it):
                self._logger.info('Inferring {} at epoch {}'.format(
                    self.model.name, self.last_epoch))
            else:
                return
        # use original images in inferring
        self.model.to_eval()
        for items in tqdm.tqdm(it, 'Infer', ascii=True):
            feature = items[0]
            name = items[2]
            with torch.no_grad():
                self.fn_infer_each_step(None, feature, name)

    def fn_infer_each_step(self, label=None, feature=None, name=None, post=None):
        """Run inference on one batch, optionally with self-ensemble."""
        v = self.v
        origin_feat = feature
        for fn in v.feature_callbacks:
            feature = fn(feature, name=name)
        with torch.set_grad_enabled(False):
            if v.ensemble:
                # add self-ensemble boosting metric score
                feature_ensemble = _ensemble_expand(feature)
                outputs_ensemble = []
                for f in feature_ensemble:
                    f = to_tensor(f, v.cuda)
                    y, _ = self.model.eval([f])
                    y = [from_tensor(x) for x in y]
                    outputs_ensemble.append(y)
                outputs = []
                for i in range(len(outputs_ensemble[0])):
                    outputs.append([j[i] for j in outputs_ensemble])
                outputs = _ensemble_reduce_mean(outputs)
            else:
                feature = to_tensor(feature, v.cuda)
                outputs, _ = self.model.eval([feature])
                outputs = [from_tensor(x) for x in outputs]
        for fn in v.output_callbacks:
            outputs = fn(outputs, input=origin_feat, name=name,
                         subdir=v.subdir, mode=v.color_format)
def main(*args, **kwargs):
    """Build a degradation dataset: iterate HR/LR training batches and write
    (HR png, LR png, degraded-LR png, name) records to a .tfrecords file.
    """
    flags = tf.flags.FLAGS
    check_args(flags)
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    opt.steps_per_epoch = opt.num
    # set random seed at first
    np.random.seed(opt.seed)
    # check output dir
    output_dir = Path(flags.save_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    writer = tf.io.TFRecordWriter(
        str(output_dir / "{}.tfrecords".format(opt.dataset)))
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    # camera response function matrix, optional (used by `process`)
    crf_matrix = np.load(opt.crf) if opt.crf else None
    # init loader config
    train_data, _, _ = Run.fetch_datasets(data_config_file, opt)
    train_config, _, _ = Run.init_loader_config(opt)
    loader = QuickLoader(train_data, opt.method, train_config,
                         n_threads=opt.threads, augmentation=opt.augment)
    it = loader.make_one_shot_iterator(opt.memory_limit, shuffle=True)
    with tqdm.tqdm(it, unit='batch', ascii=True) as r:
        for items in r:
            label, feature, names = items[:3]
            # label is usually HR image, feature is usually LR image
            batch_label = np.split(label, label.shape[0])
            batch_feature = np.split(feature, feature.shape[0])
            batch_name = np.split(names, names.shape[0])
            for hr, lr, name in zip(batch_label, batch_feature, batch_name):
                hr = np.squeeze(hr)
                lr = np.squeeze(lr)
                name = np.squeeze(name)
                # losslessly encode clean HR and LR patches as PNG bytes
                with io.BytesIO() as fp:
                    Image.fromarray(hr, 'RGB').save(fp, format='png')
                    fp.seek(0)
                    hr_png = fp.read()
                with io.BytesIO() as fp:
                    Image.fromarray(lr, 'RGB').save(fp, format='png')
                    fp.seek(0)
                    lr_png = fp.read()
                # apply the degradation model (noise via CRF, sigma range)
                lr_post = process(lr, crf_matrix, (opt.sigma[0], opt.sigma[1]))
                with io.BytesIO() as fp:
                    # optionally add JPEG compression artifacts to the
                    # degraded image; otherwise keep it lossless
                    if opt.jpeg_quality:
                        Image.fromarray(lr_post, 'RGB').save(
                            fp, format='jpeg', quality=opt.jpeg_quality)
                    else:
                        Image.fromarray(lr_post, 'RGB').save(fp, format='png')
                    fp.seek(0)
                    post_png = fp.read()
                label = "{}_{}_{}".format(*name).encode()
                make_tensor_label_records(
                    [hr_png, lr_png, label, post_png],
                    ["image/hr", "image/lr", "name", "image/post"],
                    writer)
class SRTrainer(Env):
    """Trainer for super-resolution models (torch backend).

    Drives the fit / benchmark / infer loops; loaders yield dict packs with
    'lr', 'hr' and 'name' entries. Per-run state is kept in `v`.

    NOTE(review): `v` is a mutable CLASS attribute, shared by every
    SRTrainer instance — presumably only one trainer lives per process;
    confirm before instantiating several.
    """
    v = Config()

    def query_config(self, config, **kwargs):
        """Merge `config` and keyword overrides into the shared state `self.v`."""
        config = Config(config or {})
        config.update(kwargs)
        self.v.epochs = config.epochs or 1  # total epochs
        self.v.batch_shape = config.batch_shape or [1, -1, -1, -1]
        self.v.steps = config.steps or 200
        self.v.val_steps = config.val_steps or -1
        self.v.lr = config.lr or 1e-4  # learning rate
        self.v.lr_schedule = config.lr_schedule
        self.v.memory_limit = config.memory_limit
        self.v.inference_results_hooks = config.inference_results_hooks or []
        self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
        self.v.traced_val = config.traced_val
        self.v.ensemble = config.ensemble
        self.v.cuda = config.cuda
        # place tensors on GPU 0 only when requested AND available
        self.v.map_location = 'cuda:0' if config.cuda and torch.cuda.is_available(
        ) else 'cpu'
        self.v.caching = config.caching
        return self.v

    def fit_init(self) -> bool:
        """Restore the latest checkpoint; return False if training is done."""
        v = self.v
        v.epoch = self._restore()
        if v.epoch >= v.epochs:
            LOG.info(f'Found pre-trained epoch {v.epoch}>=target {v.epochs},'
                     ' quit fitting.')
            return False
        LOG.info(f'Fitting: {self.model.name.upper()}')
        if self._logd:
            v.writer = Summarizer(str(self._logd), self.model.name)
        return True

    def fit_close(self):
        # flush all pending summaries to disk
        if isinstance(self.v.writer, Summarizer):
            self.v.writer.close()
        LOG.info(f'Training {self.model.name.upper()} finished.')

    def fit(self, loaders, config, **kwargs):
        """Train the model.

        Args:
            loaders: (train_loader, val_loader) pair.
            config: training configuration (`Config`).
            kwargs: overrides for the same keys in config.
        """
        v = self.query_config(config, **kwargs)
        v.train_loader, v.val_loader = loaders
        if not self.fit_init():
            return
        mem = v.memory_limit
        for epoch in range(self.last_epoch + 1, v.epochs + 1):
            v.epoch = epoch
            train_iter = v.train_loader.make_one_shot_iterator(
                v.batch_shape, v.steps, shuffle=True,
                memory_limit=mem, caching=v.caching)
            v.train_loader.prefetch(shuffle=True, memory_usage=mem)
            v.avg_meas = {}
            # lr_schedule (if any) overrides the static learning rate per epoch
            if v.lr_schedule and callable(v.lr_schedule):
                v.lr = v.lr_schedule(steps=v.epoch)
            LOG.info(f"| Epoch: {v.epoch}/{v.epochs} | LR: {v.lr:.2g} |")
            with tqdm.tqdm(train_iter, unit='batch', ascii=True) as r:
                self.model.to_train()
                for items in r:
                    self.fn_train_each_step(items)
                    r.set_postfix(v.loss)
            for _k, _v in v.avg_meas.items():
                _v = np.mean(_v)
                if isinstance(self.v.writer, Summarizer):
                    v.writer.scalar(_k, _v, step=v.epoch, collection='train')
                LOG.info(f"| Epoch average {_k} = {_v:.6f} |")
            if v.epoch % v.validate_every_n_epoch == 0 and v.val_loader:
                # Hard-coded memory limitation for validating
                self.benchmark(v.val_loader, v, memory_limit='1GB')
            self._save_model(v.epoch)
        self.fit_close()

    def fn_train_each_step(self, pack):
        """Run one training step on `pack` ('lr'/'hr' arrays) and accumulate
        running losses in `v.avg_meas`."""
        v = self.v
        feature = to_tensor(pack['lr'], v.cuda)
        label = to_tensor(pack['hr'], v.cuda)
        loss = self.model.train([feature], [label], v.lr)
        for _k, _v in loss.items():
            v.avg_meas[_k] = \
                v.avg_meas[_k] + [_v] if v.avg_meas.get(_k) else [_v]
            loss[_k] = '{:08.5f}'.format(_v)
        v.loss = loss

    def benchmark(self, loader, config, **kwargs):
        """Benchmark/validate the model.

        Args:
            loader: a loader for enumerating LR images
            config: benchmark configuration, an instance of `Util.Config.Config`
            kwargs: additional arguments to override the same ones in config.
        """
        v = self.query_config(config, **kwargs)
        self._restore(config.epoch, v.map_location)
        v.mean_metrics = {}
        v.loader = loader
        it = v.loader.make_one_shot_iterator(
            v.batch_shape, v.val_steps, shuffle=not v.traced_val,
            memory_limit=v.memory_limit, caching=v.caching)
        self.model.to_eval()
        for items in tqdm.tqdm(it, 'Test', ascii=True):
            with torch.no_grad():
                self.fn_benchmark_each_step(items)
        log_message = str()
        for _k, _v in v.mean_metrics.items():
            _v = np.mean(_v)
            if isinstance(self.v.writer, Summarizer):
                v.writer.scalar(_k, _v, step=v.epoch, collection='eval')
            log_message += f"{_k}: {_v:.6f}, "
        # swap the trailing ", " for a period
        log_message = log_message[:-2] + "."
        LOG.info(log_message)

    def fn_benchmark_each_step(self, pack):
        """Evaluate one batch and collect its metrics into `v.mean_metrics`."""
        v = self.v
        feature = to_tensor(pack['lr'], v.cuda)
        label = to_tensor(pack['hr'], v.cuda)
        with torch.set_grad_enabled(False):
            outputs, metrics = self.model.eval([feature], [label],
                                               epoch=v.epoch)
            for _k, _v in metrics.items():
                if _k not in v.mean_metrics:
                    v.mean_metrics[_k] = []
                v.mean_metrics[_k] += [_v]
        outputs = [from_tensor(x) for x in outputs]
        for fn in v.inference_results_hooks:
            # hooks may consume the outputs and return None to stop the chain
            outputs = fn(outputs, names=pack['name'])
            if outputs is None:
                break

    def infer(self, loader, config, **kwargs):
        """Infer SR images.

        Args:
            loader: a loader for enumerating LR images
            config: inferring configuration, an instance of `Util.Config.Config`
            kwargs: additional arguments to override the same ones in config.
        """
        v = self.query_config(config, **kwargs)
        self._restore(config.epoch, v.map_location)
        it = loader.make_one_shot_iterator(v.batch_shape, -1)
        if hasattr(it, '__len__'):
            if len(it) == 0:
                return
            LOG.info(f"Inferring {self.model.name} at epoch {self.last_epoch}")
        # use original images in inferring
        self.model.to_eval()
        for items in tqdm.tqdm(it, 'Infer', ascii=True):
            with torch.no_grad():
                self.fn_infer_each_step(items)

    def fn_infer_each_step(self, pack):
        """Run inference on one batch, optionally with self-ensemble."""
        v = self.v
        with torch.set_grad_enabled(False):
            if v.ensemble:
                # add self-ensemble boosting metric score
                feature_ensemble = Ensembler.expand(pack['lr'])
                outputs_ensemble = []
                for f in feature_ensemble:
                    f = to_tensor(f, v.cuda)
                    y, _ = self.model.eval([f])
                    y = [from_tensor(x) for x in y]
                    outputs_ensemble.append(y)
                outputs = []
                for i in range(len(outputs_ensemble[0])):
                    outputs.append([j[i] for j in outputs_ensemble])
                outputs = Ensembler.merge(outputs)
            else:
                feature = to_tensor(pack['lr'], v.cuda)
                outputs, _ = self.model.eval([feature])
                outputs = [from_tensor(x) for x in outputs]
        for fn in v.inference_results_hooks:
            # hooks may consume the outputs and return None to stop the chain
            outputs = fn(outputs, names=pack['name'])
            if outputs is None:
                break