def __init__(self, scale, channel, **kwargs):
     args = Config(kwargs)
     args.scale = [scale]
     args.n_colors = channel
     self.rgb_range = args.rgb_range
     self.ran = ran2.RAN(args)
     super(RAN, self).__init__(channel=channel, scale=scale, **kwargs)
Example #2
 def __init__(self, scale, channel, **kwargs):
     args = Config(kwargs)
     args.scale = [scale]
     args.n_colors = channel
     self.rgb_range = args.rgb_range
     self.frn = frn.FRN_UPDOWN(args)
     super(FRN, self).__init__(channel=channel, scale=scale, **kwargs)
Example #3
 def __init__(self, scale, **kwargs):
     super(MSRN, self).__init__(scale, 3)
     args = Config(kwargs)
     args.scale = [scale]
     self.rgb_range = args.rgb_range
     self.msrn = msrn.MSRN(args)
     self.opt = torch.optim.Adam(self.trainable_variables(), 1e-4)
Example #4
 def __init__(self, scale, **kwargs):
   super(EDSR, self).__init__(scale, 3)
   args = Config(kwargs)
   args.scale = [scale]
   self.rgb_range = args.rgb_range
   self.edsr = edsr.EDSR(args)
   self.opt = torch.optim.Adam(self.trainable_variables(), 1e-4)
Example #5
 def __init__(self, scale, channel, **kwargs):
     args = Config(kwargs)
     args.scale = [scale]
     args.n_colors = channel
     self.rgb_range = args.rgb_range
     self.edrn = edrn.EDRN(args)
     super(EDRN, self).__init__(channel=channel, scale=scale, **kwargs)
Example #6
def dummy_test_config():
    d = Config(a=1, b=2)
    d.update(a=2, b=3)
    d.a = 9
    d.update(Config(b=6, f=5))
    d.pop('b')
    print(d)
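
Note: the snippets here treat Config as a dict with attribute access: missing keys read as None (hence the `config.epochs or 1` idiom in the trainer examples below), `update` accepts dicts, Configs or kwargs, and `Config(str(path))` loads settings from a file. A minimal sketch under those assumptions (the real `Util.Config.Config` also parses YAML and more):

import json

class Config(dict):
    """Dict with attribute access; missing keys read as None (a sketch)."""

    def __init__(self, obj=None, **kwargs):
        super(Config, self).__init__()
        if isinstance(obj, str):  # assumed: path to a JSON config file
            with open(obj) as fd:
                self.update(json.load(fd))
        elif isinstance(obj, dict):
            self.update(obj)
        self.update(kwargs)

    def __getattr__(self, name):
        return self.get(name)  # None for absent keys, as the examples rely on

    def __setattr__(self, name, value):
        self[name] = value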
Example #7
 def __init__(self, scale, channel, **kwargs):
     super(RAN, self).__init__(channel=channel, scale=scale)
     args = Config(kwargs)
     args.scale = [scale]
     args.n_colors = channel
     self.rgb_range = args.rgb_range
     self.ran = ran2.RAN(args)
     self.opt = torch.optim.Adam(self.trainable_variables(), 1e-4)
Example #8
 def __init__(self, scale, channel, **kwargs):
     super(EDRN, self).__init__(scale, channel)
     args = Config(kwargs)
     args.scale = [scale]
     args.n_colors = channel
     self.rgb_range = args.rgb_range
     self.edrn = edrn.EDRN(args)
     self.adam = torch.optim.Adam(self.trainable_variables(), 1e-4)
Example #9
def check_args(opt):
    if opt.c:
        opt.update(Config(opt.c))
    _required = ('model',)
    for r in _required:
        if r not in opt or not opt.get(r):
            raise ValueError('--' + r + ' must be set')
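
A quick usage sketch of check_args (hypothetical values), reusing the Config sketch above:

check_args(Config(model='edsr'))  # passes: 'model' is present and truthy
try:
    check_args(Config())          # no 'model' key
except ValueError as e:
    print(e)                      # prints: --model must be set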
Example #10
 def __init__(self, name='gan', patch_size=32, z_dim=128, init_filter=512,
              linear=False, norm_g=None, norm_d=None, use_bias=False,
              optimizer=None, arch=None, nd_iter=1, **kwargs):
     super(GAN, self).__init__(**kwargs)
     self.name = name
     self._trainer = GanTrainer
     self.output_size = patch_size
     self.z_dim = z_dim
     self.init_filter = init_filter
     self.linear = linear
     self.bias = use_bias
     self.nd_iter = nd_iter
     if isinstance(norm_g, str):
         self.bn = np.any([word in norm_g for word in ('bn', 'batch')])
         self.sn = np.any([word in norm_g for word in ('sn', 'spectral')])
     else:
         self.bn = self.sn = False  # default when no norm string is given
     self.d_outputs = []  # (real, fake)
     self.g_outputs = []  # (real, fake)
     self.opt = optimizer
     if self.opt is None:
         self.opt = Config(name='adam')
     if arch is None or arch == 'dcgan':
         self.G = self.dcgan_g
         self.D = Discriminator.dcgan_d(
             self, [patch_size, patch_size, self.channel],
             norm=norm_d, name_or_scope='D')
     elif arch == 'resnet':
         self.G = self.resnet_g
         self.D = Discriminator.resnet_d(
             self, [patch_size, patch_size, self.channel], times_pooling=4,
             norm=norm_d, name_or_scope='D')
Example #11
 def query_config(self, config, **kwargs):
     config = Config(config or {})
     config.update(kwargs)
     self.v.epochs = config.epochs or 1  # total epochs
     self.v.batch_shape = config.batch_shape or [1, -1, -1, -1]
     self.v.steps = config.steps or 200
     self.v.val_steps = config.val_steps or -1
     self.v.lr = config.lr or 1e-4  # learning rate
     self.v.lr_schedule = config.lr_schedule
     self.v.memory_limit = config.memory_limit
     self.v.inference_results_hooks = config.inference_results_hooks or []
     self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
     self.v.traced_val = config.traced_val
     self.v.ensemble = config.ensemble
     self.v.cuda = config.cuda
     self.v.caching = config.caching_dataset
     return self.v
Example #12
 def query_config(self, config, **kwargs):
     config = Config(config or {})
     config.update(kwargs)
     self.v.epochs = config.epochs or 1  # total epochs
     self.v.batch_shape = config.batch_shape or [1, -1, -1, -1]
     self.v.steps = config.steps or 200
     self.v.val_steps = config.val_steps or -1
     self.v.lr = config.lr or 1e-4  # learning rate
     self.v.lr_schedule = config.lr_schedule
     self.v.memory_limit = config.memory_limit
     self.v.inference_results_hooks = config.inference_results_hooks or []
     self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
     self.v.traced_val = config.traced_val
     self.v.ensemble = config.ensemble
     self.v.cuda = config.cuda
     self.v.map_location = 'cuda:0' if config.cuda and torch.cuda.is_available() else 'cpu'
     return self.v
Example #13
def main(*args):
    if not opt.input_dir:
        raise ValueError("--input_dir is required")
    if opt.dataset.upper() not in DATASETS:
        raise ValueError("--dataset is missing, or can't be found")
    data_ref = DATASETS.get(opt.dataset.upper())
    data = load_folder(opt.input_dir)
    skip = opt.offset
    metric_config = Config(depth=opt.clip, batch=1, scale=1, modcrop=False)
    loader = BasicLoader(data, 'test', metric_config, False)
    ref_loader = BasicLoader(data_ref, 'test', metric_config, False)
    # make sure len(ref_loader) == len(loader)
    loader_iter = loader.make_one_shot_iterator()
    ref_iter = ref_loader.make_one_shot_iterator()
    for ref, _, name in ref_iter:
        name = str(name)
        img, _, _ = next(loader_iter)
        # reduce the batch dimension for video clips
        if img.ndim == 5: img = img[0]
        if ref.ndim == 5: ref = ref[0]
        if opt.shave:
            img = shave(img, opt.shave)
            ref = shave(ref, opt.shave)
        if opt.l_only:
            img = rgb_to_yuv(img, max_val=255, standard=opt.l_standard)[..., 0:1]
            ref = rgb_to_yuv(ref, max_val=255, standard=opt.l_standard)[..., 0:1]
        if ref.shape[0] - skip != img.shape[0]:
            b_min = np.minimum(ref.shape[0] - skip, img.shape[0])
            ref = ref[:b_min + skip, ...]
            img = img[:b_min, ...]
        img = tf.constant(img.astype(np.float32))
        ref = tf.constant(ref.astype(np.float32))
        psnr = tf.reduce_mean(tf.image.psnr(
            ref[skip:], img, 255)).eval() if not opt.no_psnr else 0
        ssim = tf.reduce_mean(tf.image.ssim(
            ref[skip:], img, 255)).eval() if not opt.no_ssim else 0
        tf.logging.info(f'[{name}] PSNR = {psnr}, SSIM = {ssim}')
        tf.add_to_collection('PSNR', psnr)
        tf.add_to_collection('SSIM', ssim)
    for key in ('PSNR', 'SSIM'):
        mp = np.mean(tf.get_collection(key))
        tf.logging.info(f'Mean {key}: {mp}')
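
Note the bare .eval() calls above: they only work inside a default TF 1.x session (the tf.flags/tf.logging usage implies TensorFlow 1.x), so main is presumably invoked under one. A minimal entry-point sketch under that assumption:

import tensorflow as tf  # TF 1.x assumed, matching tf.flags/tf.logging above

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Session().as_default():  # gives tensor.eval() a session to run in
        main()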
Example #14
def main():
  flags, args = parser.parse_known_args()
  opt = Config()
  for pair in flags._get_kwargs():
    opt.setdefault(*pair)
  data_config_file = Path(flags.data_config)
  if not data_config_file.exists():
    raise RuntimeError("dataset config file doesn't exist!")
  for _ext in ('json', 'yaml', 'yml'):  # for compat
    # apply a 2-stage (master/slave) configuration; the master can be
    # overridden by the slave
    model_config_root = Path('Parameters/root.{}'.format(_ext))
    if opt.p:
      model_config_file = Path(opt.p)
    else:
      model_config_file = Path('Parameters/{}.{}'.format(opt.model, _ext))
    if model_config_root.exists():
      opt.update(Config(str(model_config_root)))
    if model_config_file.exists():
      opt.update(Config(str(model_config_file)))

  model_params = opt.get(opt.model, {})
  suppress_opt_by_args(model_params, *args)
  opt.update(model_params)
  model = get_model(flags.model)(**model_params)
  if flags.cuda:
    model.cuda()
  root = f'{flags.save_dir}/{flags.model}'
  if flags.comment:
    root += '_' + flags.comment
  verbosity = logging.DEBUG if flags.verbose else logging.INFO
  trainer = model.trainer

  datasets = load_datasets(data_config_file)
  try:
    test_datas = [datasets[t.upper()] for t in flags.test]
    run_benchmark = True
  except KeyError:
    test_datas = []
    for pattern in flags.test:
      test_data = Dataset(test=_glob_absolute_pattern(pattern),
                          mode='pil-image1', modcrop=False)
      father = Path(pattern)  # flags.test is a list; walk up from this pattern
      while not father.is_dir():
        if father.parent == father:
          break
        father = father.parent
      test_data.name = father.stem
      test_datas.append(test_data)
    run_benchmark = False

  if opt.verbose:
    dump(opt)
  for test_data in test_datas:
    loader_config = Config(convert_to='rgb',
                           feature_callbacks=[], label_callbacks=[],
                           output_callbacks=[], **opt)
    loader_config.batch = 1
    loader_config.subdir = test_data.name
    loader_config.output_callbacks += [
      save_image(root, flags.output_index, flags.auto_rename)]
    if opt.channel == 1:
      loader_config.convert_to = 'gray'

    with trainer(model, root, verbosity, flags.pth) as t:
      if flags.seed is not None:
        t.set_seed(flags.seed)
      loader = QuickLoader(test_data, 'test', loader_config,
                           n_threads=flags.thread)
      loader_config.epoch = flags.epoch
      if run_benchmark:
        t.benchmark(loader, loader_config)
      else:
        t.infer(loader, loader_config)
Example #15
def init_loader_config(opt):
    train_config = Config(**opt, crop='random', feature_callbacks=[], label_callbacks=[])
    benchmark_config = Config(**opt, crop=None, feature_callbacks=[], label_callbacks=[], output_callbacks=[])
    infer_config = Config(**opt, feature_callbacks=[], label_callbacks=[], output_callbacks=[])
    benchmark_config.batch = opt.test_batch or 1
    benchmark_config.steps_per_epoch = -1
    if opt.channel == 1:
        train_config.convert_to = 'gray'
        benchmark_config.convert_to = 'gray'
        if opt.output_color == 'RGB':
            benchmark_config.convert_to = 'yuv'
            benchmark_config.feature_callbacks = train_config.feature_callbacks + [to_gray()]
            benchmark_config.label_callbacks = train_config.label_callbacks + [to_gray()]
            benchmark_config.output_callbacks = [to_rgb()]
        benchmark_config.output_callbacks += [save_image(opt.root, opt.output_index)]
        infer_config.update(benchmark_config)
    else:
        train_config.convert_to = 'rgb'
        benchmark_config.convert_to = 'rgb'
        benchmark_config.output_callbacks += [save_image(opt.root, opt.output_index)]
        infer_config.update(benchmark_config)
    if opt.add_custom_callbacks is not None:
        for fn in opt.add_custom_callbacks:
            train_config.feature_callbacks += [globals()[fn]]
            benchmark_config.feature_callbacks += [globals()[fn]]
            infer_config.feature_callbacks += [globals()[fn]]
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    # modcrop: a boolean specifying whether to crop image edges to be divisible
    #          by `scale`. Disabled here so inference keeps original image shapes.
    infer_config.modcrop = False
    return train_config, benchmark_config, infer_config
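
lr_decay is not shown here, but the trainers below call v.lr_schedule(steps=v.epoch), so it must return a callable mapping an epoch index to a learning rate. A minimal step-decay sketch under that assumption (parameter names are guesses; the real helper likely supports more policies):

def lr_decay(lr, decay_step=100, decay_rate=0.5, **_):
    """Sketch: decay `lr` by `decay_rate` every `decay_step` epochs."""
    def schedule(steps):
        return lr * decay_rate ** (steps // decay_step)
    return schedule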
Example #16
def main(*args):
    flags = tf.flags.FLAGS
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    check_args(opt)
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _suffix in ('json', 'yaml'):  # for compatibility
        # apply a 2-stage (master/slave) configuration; the master can be overridden by the slave
        model_config_root = Path('parameters/{}.{}'.format('root', _suffix))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path('parameters/{}.{}'.format(opt.model, _suffix))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    model_params = opt.get(opt.model, {})
    opt.update(model_params)
    model = get_model(opt.model)(**model_params)
    root = '{}/{}'.format(opt.save_dir, model.name)
    if opt.comment:
        root += '_' + opt.comment
    opt.root = root
    verbosity = tf.logging.DEBUG if opt.v else tf.logging.INFO
    # map model to trainer automatically, by setting the _trainer attribute in models
    trainer = model.trainer
    train_data, test_data, infer_data = fetch_datasets(data_config_file, opt)
    train_config, test_config, infer_config = init_loader_config(opt)
    test_config.subdir = test_data.name
    infer_config.subdir = 'infer'
    # start fitting!
    dump(opt)
    with trainer(model, root, verbosity) as t:
        # prepare loader
        loader = partial(QuickLoader, n_threads=opt.threads)
        train_loader = loader(train_data, 'train', train_config, augmentation=True)
        val_loader = loader(train_data, 'val', train_config, crop='center', steps_per_epoch=1)
        test_loader = loader(test_data, 'test', test_config)
        infer_loader = loader(infer_data, 'infer', infer_config)
        # fit
        t.fit([train_loader, val_loader], train_config)
        # validate
        t.benchmark(test_loader, test_config)
        # do inference
        t.infer(infer_loader, infer_config)
        if opt.export:
            t.export(opt.root + '/exported', opt.freeze)
Example #17
def main():
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)

    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # apply a 2-stage (master/slave) configuration; the master can be
        # overridden by the slave
        model_config_root = Path('Parameters/root.{}'.format(_ext))
        if opt.p:
            model_config_file = Path(opt.p)
        else:
            model_config_file = Path('Parameters/{}.{}'.format(
                opt.model, _ext))
        if model_config_root.exists():
            opt.update(Config(str(model_config_root)))
        if model_config_file.exists():
            opt.update(Config(str(model_config_file)))

    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    model = get_model(flags.model)(**model_params)
    if flags.cuda:
        model.cuda()
    root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        root += '_' + flags.comment
    verbosity = logging.DEBUG if flags.verbose else logging.INFO
    trainer = model.trainer

    datasets = load_datasets(data_config_file)
    dataset = datasets[flags.dataset.upper()]

    train_config = Config(crop=opt.train_data_crop,
                          feature_callbacks=[],
                          label_callbacks=[],
                          convert_to='rgb',
                          **opt)
    if opt.channel == 1:
        train_config.convert_to = 'gray'
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    train_config.random_val = not opt.traced_val
    train_config.cuda = flags.cuda

    if opt.verbose:
        dump(opt)
    with trainer(model, root, verbosity, opt.pth) as t:
        if opt.seed is not None:
            t.set_seed(opt.seed)
        tloader = QuickLoader(dataset, 'train', train_config, True,
                              flags.thread)
        vloader = QuickLoader(dataset,
                              'val',
                              train_config,
                              False,
                              flags.thread,
                              batch=1,
                              crop=opt.val_data_crop,
                              steps_per_epoch=opt.val_num)
        t.fit([tloader, vloader], train_config)
        if opt.export:
            t.export(opt.export)
class SRTrainer(Env):
  v = Config()

  def query_config(self, config, **kwargs):
    assert isinstance(config, Config)
    config.update(kwargs)
    self.v.epoch = config.epoch  # current epoch
    self.v.epochs = config.epochs  # total epochs
    self.v.lr = config.lr  # learning rate
    self.v.lr_schedule = config.lr_schedule
    self.v.memory_limit = config.memory_limit
    self.v.feature_callbacks = config.feature_callbacks or []
    self.v.label_callbacks = config.label_callbacks or []
    self.v.output_callbacks = config.output_callbacks or []
    self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
    self.v.subdir = config.subdir
    self.v.random_val = config.random_val
    self.v.ensemble = config.ensemble
    self.v.cuda = config.cuda
    self.v.map_location = 'cuda:0' if config.cuda and torch.cuda.is_available() else 'cpu'
    return self.v

  def fit_init(self) -> bool:
    v = self.v
    v.epoch = self._restore()
    if v.epoch >= v.epochs:
      self._logger.info(f'Found pre-trained epoch {v.epoch}>=target {v.epochs},'
                        ' quit fitting.')
      return False
    self._logger.info('Fitting: {}'.format(self.model.name.upper()))
    v.writer = Summarizer(str(self._logd), self.model.name)
    return True

  def fit_close(self):
    # flush all pending summaries to disk
    if isinstance(self.v.writer, Summarizer):
      self.v.writer.close()
    self._logger.info(f'Training {self.model.name.upper()} finished.')

  def fit(self, loaders, config, **kwargs):
    v = self.query_config(config, **kwargs)
    v.train_loader, v.val_loader = loaders
    if not self.fit_init():
      return
    mem = v.memory_limit
    for epoch in range(self.last_epoch + 1, v.epochs + 1):
      v.epoch = epoch
      train_iter = v.train_loader.make_one_shot_iterator(mem, shuffle=True)
      if hasattr(v.train_loader, 'prefetch'):
        v.train_loader.prefetch(mem)
      date = time.strftime('%Y-%m-%d %T', time.localtime())
      v.avg_meas = {}
      if v.lr_schedule and callable(v.lr_schedule):
        v.lr = v.lr_schedule(steps=v.epoch)
      print('| {} | Epoch: {}/{} | LR: {:.2g} |'.format(
        date, v.epoch, v.epochs, v.lr))
      with tqdm.tqdm(train_iter, unit='batch', ascii=True) as r:
        self.model.to_train()
        for items in r:
          label, feature, name, post = items[:4]
          self.fn_train_each_step(label, feature, name, post)
          r.set_postfix(v.loss)
      for _k, _v in v.avg_meas.items():
        _v = np.mean(_v)
        if isinstance(self.v.writer, Summarizer):
          v.writer.scalar(_k, _v, step=v.epoch, collection='train')
        print('| Epoch average {} = {:.6f} |'.format(_k, _v))
      if v.epoch % v.validate_every_n_epoch == 0:
        # Hard-coded memory limitation for validating
        self.benchmark(v.val_loader, v, memory_limit='1GB')
      self._save_model(v.epoch)
    self.fit_close()

  def fn_train_each_step(self, label=None, feature=None, name=None, post=None):
    v = self.v
    for fn in v.feature_callbacks:
      feature = fn(feature, name=name)
    for fn in v.label_callbacks:
      label = fn(label, name=name)
    feature = to_tensor(feature, v.cuda)
    label = to_tensor(label, v.cuda)
    loss = self.model.train([feature], [label], v.lr)
    for _k, _v in loss.items():
      v.avg_meas[_k] = \
        v.avg_meas[_k] + [_v] if v.avg_meas.get(_k) else [_v]
      loss[_k] = '{:08.5f}'.format(_v)
    v.loss = loss

  def benchmark(self, loader, config, **kwargs):
    """Benchmark/validate the model.

    Args:
        loader: a loader for enumerating LR images
        config: benchmark configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
    v = self.query_config(config, **kwargs)
    v.color_format = loader.color_format

    self._restore(config.epoch, v.map_location)
    v.mean_metrics = {}
    v.loader = loader
    it = v.loader.make_one_shot_iterator(v.memory_limit, shuffle=v.random_val)
    self.model.to_eval()
    for items in tqdm.tqdm(it, 'Test', ascii=True):
      label, feature, name, post = items[:4]
      with torch.no_grad():
        self.fn_benchmark_each_step(label, feature, name, post)
    for _k, _v in v.mean_metrics.items():
      _v = np.mean(_v)
      if isinstance(self.v.writer, Summarizer):
        v.writer.scalar(_k, _v, step=v.epoch, collection='eval')
      print('{}: {:.6f}'.format(_k, _v), end=', ')
    print('')

  def fn_benchmark_each_step(self, label=None, feature=None, name=None,
                             post=None):
    v = self.v
    origin_feat = feature
    for fn in v.feature_callbacks:
      feature = fn(feature, name=name)
    for fn in v.label_callbacks:
      label = fn(label, name=name)
    feature = to_tensor(feature, v.cuda)
    label = to_tensor(label, v.cuda)
    with torch.set_grad_enabled(False):
      outputs, metrics = self.model.eval([feature], [label], epoch=v.epoch)
    for _k, _v in metrics.items():
      if _k not in v.mean_metrics:
        v.mean_metrics[_k] = []
      v.mean_metrics[_k] += [_v]
    outputs = [from_tensor(x) for x in outputs]
    for fn in v.output_callbacks:
      outputs = fn(outputs, input=origin_feat, label=label, name=name,
                   mode=v.color_format, subdir=v.subdir)

  def infer(self, loader, config, **kwargs):
    """Infer SR images.

    Args:
        loader: a loader for enumerating LR images
        config: inferring configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
    v = self.query_config(config, **kwargs)
    v.color_format = loader.color_format

    self._restore(config.epoch, v.map_location)
    it = loader.make_one_shot_iterator()
    if hasattr(it, '__len__'):
      if len(it):
        self._logger.info('Inferring {} at epoch {}'.format(
          self.model.name, self.last_epoch))
      else:
        return
    # use original images in inferring
    self.model.to_eval()
    for items in tqdm.tqdm(it, 'Infer', ascii=True):
      feature = items[0]
      name = items[2]
      with torch.no_grad():
        self.fn_infer_each_step(None, feature, name)

  def fn_infer_each_step(self, label=None, feature=None, name=None, post=None):
    v = self.v
    origin_feat = feature
    for fn in v.feature_callbacks:
      feature = fn(feature, name=name)
    with torch.set_grad_enabled(False):
      if v.ensemble:
        # add self-ensemble boosting metric score
        feature_ensemble = _ensemble_expand(feature)
        outputs_ensemble = []
        for f in feature_ensemble:
          f = to_tensor(f, v.cuda)
          y, _ = self.model.eval([f])
          y = [from_tensor(x) for x in y]
          outputs_ensemble.append(y)
        outputs = []
        for i in range(len(outputs_ensemble[0])):
          outputs.append([j[i] for j in outputs_ensemble])
        outputs = _ensemble_reduce_mean(outputs)
      else:
        feature = to_tensor(feature, v.cuda)
        outputs, _ = self.model.eval([feature])
        outputs = [from_tensor(x) for x in outputs]
    for fn in v.output_callbacks:
      outputs = fn(outputs, input=origin_feat, name=name, subdir=v.subdir,
                   mode=v.color_format)
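
_ensemble_expand and _ensemble_reduce_mean are referenced above but not shown. Geometric self-ensemble in SR commonly averages predictions over the eight flip/rotation variants of the input; a sketch under that assumption, for NHWC numpy batches:

import numpy as np

def _ensemble_expand(feature):
    """Sketch: the 8 geometric variants (4 rotations x optional h-flip)
    of an NHWC batch. Assumed to match the helper used above."""
    rots = [np.rot90(feature, k, axes=(1, 2)) for k in range(4)]
    return rots + [np.flip(r, axis=2) for r in rots]

def _ensemble_reduce_mean(outputs):
    """Undo each transform, then average. outputs[i] holds the model's
    i-th output for all 8 variants, in the order produced above."""
    results = []
    for imgs in outputs:
        restored = []
        for k, img in enumerate(imgs):
            if k >= 4:  # this variant was flipped after rotating
                img = np.flip(img, axis=2)
            restored.append(np.rot90(img, -(k % 4), axes=(1, 2)))
        results.append(np.mean(restored, axis=0))
    return results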
def main(*args, **kwargs):
    flags = tf.flags.FLAGS
    check_args(flags)
    opt = Config()
    for key in flags:
        opt.setdefault(key, flags.get_flag_value(key, None))
    opt.steps_per_epoch = opt.num
    # set random seed at first
    np.random.seed(opt.seed)
    # check output dir
    output_dir = Path(flags.save_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    writer = tf.io.TFRecordWriter(
        str(output_dir / "{}.tfrecords".format(opt.dataset)))
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")
    crf_matrix = np.load(opt.crf) if opt.crf else None
    # init loader config
    train_data, _, _ = Run.fetch_datasets(data_config_file, opt)
    train_config, _, _ = Run.init_loader_config(opt)
    loader = QuickLoader(train_data,
                         opt.method,
                         train_config,
                         n_threads=opt.threads,
                         augmentation=opt.augment)
    it = loader.make_one_shot_iterator(opt.memory_limit, shuffle=True)
    with tqdm.tqdm(it, unit='batch', ascii=True) as r:
        for items in r:
            label, feature, names = items[:3]
            # label is usually HR image, feature is usually LR image
            batch_label = np.split(label, label.shape[0])
            batch_feature = np.split(feature, feature.shape[0])
            batch_name = np.split(names, names.shape[0])
            for hr, lr, name in zip(batch_label, batch_feature, batch_name):
                hr = np.squeeze(hr)
                lr = np.squeeze(lr)
                name = np.squeeze(name)
                with io.BytesIO() as fp:
                    Image.fromarray(hr, 'RGB').save(fp, format='png')
                    fp.seek(0)
                    hr_png = fp.read()
                with io.BytesIO() as fp:
                    Image.fromarray(lr, 'RGB').save(fp, format='png')
                    fp.seek(0)
                    lr_png = fp.read()
                lr_post = process(lr, crf_matrix, (opt.sigma[0], opt.sigma[1]))
                with io.BytesIO() as fp:
                    if opt.jpeg_quality:
                        Image.fromarray(lr_post,
                                        'RGB').save(fp,
                                                    format='jpeg',
                                                    quality=opt.jpeg_quality)
                    else:
                        Image.fromarray(lr_post, 'RGB').save(fp, format='png')
                    fp.seek(0)
                    post_png = fp.read()
                label = "{}_{}_{}".format(*name).encode()
                make_tensor_label_records(
                    [hr_png, lr_png, label, post_png],
                    ["image/hr", "image/lr", "name", "image/post"], writer)
class SRTrainer(Env):
    v = Config()

    def query_config(self, config, **kwargs):
        config = Config(config or {})
        config.update(kwargs)
        self.v.epochs = config.epochs or 1  # total epochs
        self.v.batch_shape = config.batch_shape or [1, -1, -1, -1]
        self.v.steps = config.steps or 200
        self.v.val_steps = config.val_steps or -1
        self.v.lr = config.lr or 1e-4  # learning rate
        self.v.lr_schedule = config.lr_schedule
        self.v.memory_limit = config.memory_limit
        self.v.inference_results_hooks = config.inference_results_hooks or []
        self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
        self.v.traced_val = config.traced_val
        self.v.ensemble = config.ensemble
        self.v.cuda = config.cuda
        self.v.map_location = 'cuda:0' if config.cuda and torch.cuda.is_available() else 'cpu'
        self.v.caching = config.caching
        return self.v

    def fit_init(self) -> bool:
        v = self.v
        v.epoch = self._restore()
        if v.epoch >= v.epochs:
            LOG.info(f'Found pre-trained epoch {v.epoch}>=target {v.epochs},'
                     ' quit fitting.')
            return False
        LOG.info(f'Fitting: {self.model.name.upper()}')
        if self._logd:
            v.writer = Summarizer(str(self._logd), self.model.name)
        return True

    def fit_close(self):
        # flush all pending summaries to disk
        if isinstance(self.v.writer, Summarizer):
            self.v.writer.close()
        LOG.info(f'Training {self.model.name.upper()} finished.')

    def fit(self, loaders, config, **kwargs):
        v = self.query_config(config, **kwargs)
        v.train_loader, v.val_loader = loaders
        if not self.fit_init():
            return
        mem = v.memory_limit
        for epoch in range(self.last_epoch + 1, v.epochs + 1):
            v.epoch = epoch
            train_iter = v.train_loader.make_one_shot_iterator(
                v.batch_shape,
                v.steps,
                shuffle=True,
                memory_limit=mem,
                caching=v.caching)
            v.train_loader.prefetch(shuffle=True, memory_usage=mem)
            v.avg_meas = {}
            if v.lr_schedule and callable(v.lr_schedule):
                v.lr = v.lr_schedule(steps=v.epoch)
            LOG.info(f"| Epoch: {v.epoch}/{v.epochs} | LR: {v.lr:.2g} |")
            with tqdm.tqdm(train_iter, unit='batch', ascii=True) as r:
                self.model.to_train()
                for items in r:
                    self.fn_train_each_step(items)
                    r.set_postfix(v.loss)
            for _k, _v in v.avg_meas.items():
                _v = np.mean(_v)
                if isinstance(self.v.writer, Summarizer):
                    v.writer.scalar(_k, _v, step=v.epoch, collection='train')
                LOG.info(f"| Epoch average {_k} = {_v:.6f} |")
            if v.epoch % v.validate_every_n_epoch == 0 and v.val_loader:
                # Hard-coded memory limitation for validating
                self.benchmark(v.val_loader, v, memory_limit='1GB')
            self._save_model(v.epoch)
        self.fit_close()

    def fn_train_each_step(self, pack):
        v = self.v
        feature = to_tensor(pack['lr'], v.cuda)
        label = to_tensor(pack['hr'], v.cuda)
        loss = self.model.train([feature], [label], v.lr)
        for _k, _v in loss.items():
            v.avg_meas[_k] = \
              v.avg_meas[_k] + [_v] if v.avg_meas.get(_k) else [_v]
            loss[_k] = '{:08.5f}'.format(_v)
        v.loss = loss

    def benchmark(self, loader, config, **kwargs):
        """Benchmark/validate the model.

    Args:
        loader: a loader for enumerating LR images
        config: benchmark configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
        v = self.query_config(config, **kwargs)
        self._restore(config.epoch, v.map_location)
        v.mean_metrics = {}
        v.loader = loader
        it = v.loader.make_one_shot_iterator(v.batch_shape,
                                             v.val_steps,
                                             shuffle=not v.traced_val,
                                             memory_limit=v.memory_limit,
                                             caching=v.caching)
        self.model.to_eval()
        for items in tqdm.tqdm(it, 'Test', ascii=True):
            with torch.no_grad():
                self.fn_benchmark_each_step(items)
        log_message = str()
        for _k, _v in v.mean_metrics.items():
            _v = np.mean(_v)
            if isinstance(self.v.writer, Summarizer):
                v.writer.scalar(_k, _v, step=v.epoch, collection='eval')
            log_message += f"{_k}: {_v:.6f}, "
        log_message = log_message[:-2] + "."
        LOG.info(log_message)

    def fn_benchmark_each_step(self, pack):
        v = self.v
        feature = to_tensor(pack['lr'], v.cuda)
        label = to_tensor(pack['hr'], v.cuda)
        with torch.set_grad_enabled(False):
            outputs, metrics = self.model.eval([feature], [label],
                                               epoch=v.epoch)
        for _k, _v in metrics.items():
            if _k not in v.mean_metrics:
                v.mean_metrics[_k] = []
            v.mean_metrics[_k] += [_v]
        outputs = [from_tensor(x) for x in outputs]
        for fn in v.inference_results_hooks:
            outputs = fn(outputs, names=pack['name'])
            if outputs is None:
                break

    def infer(self, loader, config, **kwargs):
        """Infer SR images.

    Args:
        loader: a loader for enumerating LR images
        config: inferring configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
        v = self.query_config(config, **kwargs)
        self._restore(config.epoch, v.map_location)
        it = loader.make_one_shot_iterator(v.batch_shape, -1)
        if hasattr(it, '__len__'):
            if len(it) == 0:
                return
            LOG.info(f"Inferring {self.model.name} at epoch {self.last_epoch}")
        # use original images in inferring
        self.model.to_eval()
        for items in tqdm.tqdm(it, 'Infer', ascii=True):
            with torch.no_grad():
                self.fn_infer_each_step(items)

    def fn_infer_each_step(self, pack):
        v = self.v
        with torch.set_grad_enabled(False):
            if v.ensemble:
                # add self-ensemble boosting metric score
                feature_ensemble = Ensembler.expand(pack['lr'])
                outputs_ensemble = []
                for f in feature_ensemble:
                    f = to_tensor(f, v.cuda)
                    y, _ = self.model.eval([f])
                    y = [from_tensor(x) for x in y]
                    outputs_ensemble.append(y)
                outputs = []
                for i in range(len(outputs_ensemble[0])):
                    outputs.append([j[i] for j in outputs_ensemble])
                outputs = Ensembler.merge(outputs)
            else:
                feature = to_tensor(pack['lr'], v.cuda)
                outputs, _ = self.model.eval([feature])
                outputs = [from_tensor(x) for x in outputs]
        for fn in v.inference_results_hooks:
            outputs = fn(outputs, names=pack['name'])
            if outputs is None:
                break
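
to_tensor and from_tensor appear throughout both SRTrainer variants but are not defined here. A plausible minimal sketch (assuming NHWC numpy batches and NCHW torch tensors; the real helpers may differ in layout or dtype handling):

import numpy as np
import torch

def to_tensor(x, cuda=False):
    """Sketch: NHWC numpy batch -> NCHW float tensor, optionally on GPU."""
    t = torch.from_numpy(np.ascontiguousarray(x)).float().permute(0, 3, 1, 2)
    return t.cuda() if cuda and torch.cuda.is_available() else t

def from_tensor(t):
    """Sketch: NCHW tensor -> NHWC numpy batch."""
    return t.detach().cpu().permute(0, 2, 3, 1).numpy()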