Example #1
 def __init__(self,
              name='drsr_v2',
              noise_config=None,
              weights=(1, 10, 1e-5),
              level=1,
              mean_shift=(0, 0, 0),
              arch=None,
              auto_shift=None,
              **kwargs):
     super(DRSR, self).__init__(**kwargs)
     self.name = name
     self.noise = Config(scale=0, offset=0, penalty=0.5, max=0, layers=7)
     if isinstance(noise_config, (dict, Config)):
         self.noise.update(**noise_config)
         self.noise.crf = np.load(self.noise.crf)
         self.noise.offset = to_list(self.noise.offset, 4)
         self.noise.offset = [x / 255 for x in self.noise.offset]
         self.noise.max /= 255
     self.weights = weights
     self.level = level
     if mean_shift is not None:
         self.norm = partial(_normalize, shift=mean_shift)
         self.denorm = partial(_denormalize, shift=mean_shift)
     self.arch = arch
     self.auto = auto_shift
     self.to_sum = []
Example #2
def dummy_test_config():
    d = Config(a=1, b=2)
    d.update(a=2, b=3)
    d.a = 9
    d.update(Config(b=6, f=5))
    d.pop('b')
    print(d)
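
For context, `Config` is used throughout these examples as an attribute-accessible dict (an EasyDict-style container). The stand-in below is a minimal sketch assuming only the behavior the examples rely on (attribute access, update() with dicts or keyword arguments, pop(), and None for missing keys); it is not the library's actual implementation.

# Minimal stand-in for the Config container used above (an assumption,
# not the real Util.Config.Config).
class Config(dict):
    def __getattr__(self, key):
        return self.get(key)          # missing keys read as None

    def __setattr__(self, key, value):
        self[key] = value             # attribute writes go into the dict

d = Config(a=1, b=2)
d.update(a=2, b=3)                    # {'a': 2, 'b': 3}
d.a = 9
d.update(Config(b=6, f=5))            # {'a': 9, 'b': 6, 'f': 5}
d.pop('b')
print(d)                              # {'a': 9, 'f': 5}
print(d.b)                            # None, matching assertIsNone(d.b) in a later example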
Example #3
 def __init__(self,
              name='drsr',
              n_cb=4,
              n_crb=4,
              noise_config=None,
              weights=(1, 0.5, 0.05, 1e-3),
              finetune=2000,
              mean_shift=False,
              **kwargs):
     super(DRSR, self).__init__(**kwargs)
     self.name = name
     self.n_cb = n_cb
     self.n_crb = n_crb
     self.weights = weights
     self.finetune = finetune
     self.mean_shift = mean_shift
     self.noise = Config(scale=0, offset=0, penalty=0.7, max=0.2, layers=7)
     if isinstance(noise_config, (dict, Config)):
         self.noise.update(**noise_config)
         if self.noise.type == 'crf':
             self.noise.crf = np.load(self.noise.crf)
         self.noise.offset /= 255
         self.noise.max /= 255
     if 'tfrecords' in kwargs:
         self.tfr = kwargs['tfrecords']
         self._trainer = DrTrainer
Example #4
 def test_config(self):
     d = Config(a=1, b=2)
     self.assertTrue(hasattr(d, 'a'))
     self.assertTrue(hasattr(d, 'b'))
     self.assertTrue(hasattr(d, 'non-exist'))
     self.assertIs(d.a, 1)
     self.assertIs(d.b, 2)
     d.update(a=2, b=3)
     self.assertIs(d.a, 2)
     self.assertIs(d.b, 3)
     d.a = 9
     self.assertIs(d.a, 9)
     d.update(Config(b=6, f=5))
     self.assertIs(d.b, 6)
     self.assertIs(d.f, 5)
     d.pop('b')
     self.assertIsNone(d.b)
Example #5
 def query_config(self, config, **kwargs) -> Config:
   config = Config(config or {})
   config.update(kwargs)  # override parameters
   self.v.epoch = config.epoch  # current epoch
   self.v.epochs = config.epochs or 1  # total epochs
   self.v.lr = config.lr or 1e-4  # learning rate
   self.v.batch_shape = config.batch_shape or [1, -1, -1, -1]
   self.v.steps = config.steps or 200
   self.v.val_steps = config.val_steps or -1
   self.v.lr_schedule = config.lr_schedule
   self.v.memory_limit = config.memory_limit
   self.v.inference_results_hooks = config.inference_results_hooks or []
   self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
   self.v.traced_val = config.traced_val
   self.v.ensemble = config.ensemble
   self.v.cuda = config.cuda
   return self.v
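
A hypothetical call illustrating the precedence implemented above: keyword arguments override fields of the passed-in config, and unset fields fall back to the defaults chosen here. `trainer` stands for an instance of the trainer whose query_config is excerpted above; the values are made up.

v = trainer.query_config(Config(epochs=50, lr=1e-3), lr=5e-4)
assert v.lr == 5e-4      # kwargs win over the config object
assert v.epochs == 50    # taken from the config
assert v.steps == 200    # neither given, so the default above applies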
Example #6
 def __init__(self,
              name='gan',
              patch_size=32,
              z_dim=128,
              init_filter=512,
              linear=False,
              norm_g=None,
              norm_d=None,
              use_bias=False,
              optimizer=None,
              arch=None,
              nd_iter=1,
              **kwargs):
     super(GAN, self).__init__(**kwargs)
     self.name = name
     self._trainer = GanTrainer
     self.output_size = patch_size
     self.z_dim = z_dim
     self.init_filter = init_filter
     self.linear = linear
     self.bias = use_bias
     self.nd_iter = nd_iter
     if isinstance(norm_g, str):
         self.bn = np.any([word in norm_g for word in ('bn', 'batch')])
         self.sn = np.any([word in norm_g for word in ('sn', 'spectral')])
     self.d_outputs = []  # (real, fake)
     self.g_outputs = []  # (real, fake)
     # monitor probability of being real and fake
     self.p_fake = None
     self.p_real = None
     self.opt = optimizer
     if self.opt is None:
         self.opt = Config(name='adam')
     if arch is None or arch == 'dcgan':
         self.G = self.dcgan_g
         self.D = Discriminator.dcgan_d(
             self, [patch_size, patch_size, self.channel],
             norm=norm_d,
             name_or_scope='D')
     elif arch == 'resnet':
         self.G = self.resnet_g
         self.D = Discriminator.resnet_d(
             self, [patch_size, patch_size, self.channel],
             times_pooling=4,
             norm=norm_d,
             name_or_scope='D')
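
Note that `norm_g` is parsed by substring matching, so a single string can switch on both batch and spectral normalization for the generator. A quick standalone check (the value 'bn+sn' is illustrative, not necessarily one the project uses):

import numpy as np

norm_g = 'bn+sn'
bn = np.any([word in norm_g for word in ('bn', 'batch')])
sn = np.any([word in norm_g for word in ('sn', 'spectral')])
print(bn, sn)  # True True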
Example #7
def main():
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    overwrite_from_env(opt)
    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        if opt.parameter:
            model_config_file = Path(opt.parameter)
        else:
            model_config_file = Path(f'par/{BACKEND}/{opt.model}.{_ext}')
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))
    # get model parameters from pre-defined YAML file
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)
    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment
    root = Path(root)

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[t.upper()]
                      for t in opt.test] if opt.test else []
    except KeyError:
        test_datas = [Config(test=Config(lr=Dataset(*opt.test)), name='infer')]
        if opt.video:
            test_datas[0].test.lr.use_like_video_()
    # enter model executor environment
    with model.get_executor(root) as t:
        for data in test_datas:
            run_benchmark = data.test.hr is not None
            if run_benchmark:
                ld = Loader(data.test.hr,
                            data.test.lr,
                            opt.scale,
                            threads=opt.threads)
            else:
                ld = Loader(data.test.hr, data.test.lr, threads=opt.threads)
            if opt.channel == 1:
                # convert data color space to grayscale
                ld.set_color_space('hr', 'L')
                ld.set_color_space('lr', 'L')
            config = t.query_config(opt)
            config.inference_results_hooks = [
                save_inference_images(root / data.name, opt.output_index,
                                      opt.auto_rename)
            ]
            if run_benchmark:
                t.benchmark(ld, config)
            else:
                t.infer(ld, config)
        if opt.export:
            t.export(opt.export)
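
The flag-to-option transfer at the top of main() copies every parsed argparse flag into the Config object via Namespace._get_kwargs(), while unknown arguments are kept aside for suppress_opt_by_args. Below is a self-contained sketch of that pattern; the flag names are illustrative and a plain dict stands in for Config.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='srcnn')
parser.add_argument('--cuda', action='store_true')
flags, args = parser.parse_known_args(['--model', 'drsr', '--unknown_flag', '1'])

opt = {}
for pair in flags._get_kwargs():   # sorted (name, value) pairs, e.g. [('cuda', False), ('model', 'drsr')]
    opt.setdefault(*pair)
print(opt)   # {'cuda': False, 'model': 'drsr'}
print(args)  # ['--unknown_flag', '1'] -> later matched against model parameters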
Example #8
def main():
    flags, args = parser.parse_known_args()
    opt = Config()  # An EasyDict object
    # overwrite flag values into opt object
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    # fetch dataset descriptions
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")
    for _ext in ('json', 'yaml', 'yml'):  # for compat
        if opt.parameter:
            model_config_file = Path(opt.parameter)
        else:
            model_config_file = Path(f'par/{BACKEND}/{opt.model}.{_ext}')
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))
    # get model parameters from pre-defined YAML file
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)
    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment

    dataset = load_datasets(data_config_file, opt.dataset)
    # construct data loader for training
    lt = Loader(dataset.train.hr,
                dataset.train.lr,
                opt.scale,
                threads=opt.threads)
    lt.image_augmentation()
    # construct data loader for validating
    lv = None
    if dataset.val is not None:
        lv = Loader(dataset.val.hr,
                    dataset.val.lr,
                    opt.scale,
                    threads=opt.threads)
    lt.cropper(RandomCrop(opt.scale))
    if opt.traced_val and lv is not None:
        lv.cropper(CenterCrop(opt.scale))
    elif lv is not None:
        lv.cropper(RandomCrop(opt.scale))
    if opt.channel == 1:
        # convert data color space to grayscale
        lt.set_color_space('hr', 'L')
        lt.set_color_space('lr', 'L')
        if lv is not None:
            lv.set_color_space('hr', 'L')
            lv.set_color_space('lr', 'L')
    # enter model executor environment
    with model.get_executor(root) as t:
        config = t.query_config(opt)
        if opt.lr_decay:
            config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
        t.fit([lt, lv], config)
        if opt.export:
            t.export(opt.export)
Example #9
class DRSR(SuperResolution):
    def __init__(self,
                 name='drsr',
                 n_cb=4,
                 n_crb=4,
                 noise_config=None,
                 weights=(1, 0.5, 0.05, 1e-3),
                 finetune=2000,
                 mean_shift=False,
                 **kwargs):
        super(DRSR, self).__init__(**kwargs)
        self.name = name
        self.n_cb = n_cb
        self.n_crb = n_crb
        self.weights = weights
        self.finetune = finetune
        self.mean_shift = mean_shift
        self.noise = Config(scale=0, offset=0, penalty=0.7, max=0.2, layers=7)
        if isinstance(noise_config, (dict, Config)):
            self.noise.update(**noise_config)
            if self.noise.type == 'crf':
                self.noise.crf = np.load(self.noise.crf)
            self.noise.offset /= 255
            self.noise.max /= 255
        if 'tfrecords' in kwargs:
            self.tfr = kwargs['tfrecords']
            self._trainer = DrTrainer

    def display(self):
        # stats = tf.profiler.profile()
        # LOG.info("Total parameters: {}".format(stats.total_parameters))
        LOG.info("Noisy scaling {}, bias sigma {}".format(
            self.noise.scale, self.noise.offset))
        LOG.info("Using {}".format(self.trainer))

    def _dncnn(self, inputs):
        n = self.noise
        with tf.variable_scope('Dncnn'):
            x = inputs
            for _ in range(6):
                x = self.bn_relu_conv2d(x, 64, 3)
            x = self.conv2d(x, self.channel, 3)
        return x

    def cascade_block(self,
                      inputs,
                      noise,
                      filters=64,
                      depth=4,
                      scope=None,
                      reuse=None):
        def _noise_condition(nc_inputs, layers=2):
            with tf.variable_scope(None, 'NCL'):
                t = noise
                for _ in range(layers - 1):
                    t = self.relu_conv2d(t, 64, 3)
                t = self.conv2d(t, 64, 3)
                gamma = tf.reduce_mean(t, [1, 2], keepdims=True)
                t = noise
                for _ in range(layers - 1):
                    t = self.relu_conv2d(t, 64, 3)
                beta = self.conv2d(t, 64, 3)
            return nc_inputs * gamma + beta

        def _cond_resblock(cr_inputs, kernel_size):
            with tf.variable_scope(None, 'CRB'):
                pre_inputs = cr_inputs
                cr_inputs = self.relu_conv2d(cr_inputs, filters, kernel_size)
                cr_inputs = _noise_condition(cr_inputs)
                cr_inputs = self.relu_conv2d(cr_inputs, filters, kernel_size)
                cr_inputs = _noise_condition(cr_inputs)
                return pre_inputs + cr_inputs

        with tf.variable_scope(scope, 'CB', reuse=reuse):
            feat = [inputs]
            for i in range(depth):
                x = _cond_resblock(inputs, 3)
                feat.append(x)
                inputs = self.conv2d(tf.concat(feat, axis=-1),
                                     filters,
                                     1,
                                     kernel_initializer='he_uniform')
            # inputs = self.conv2d(inputs, filters, 3)
            return inputs

    def _upsample(self, inputs, noise):
        x = [self.conv2d(inputs, 64, 7)]
        for i in range(self.n_cb):
            x += [self.cascade_block(x[i], noise, depth=self.n_crb)]
        # bottleneck
        df = [
            self.conv2d(n, 32, 1, kernel_initializer='he_uniform')
            for n in x[:-1]
        ]
        df.append(x[-1])
        summary_tensor_image(x[-1], 'last_before_bn')
        bottleneck = tf.concat(df, axis=-1, name='bottleneck')
        sr = self.upscale(bottleneck, direct_output=False)
        summary_tensor_image(sr, 'after_bn')
        sr = self.conv2d(sr, self.channel, 3)
        return sr, x

    def _unet(self, inputs, noise):
        with tf.variable_scope('Unet'):
            x0 = self.conv2d(inputs, 64, 7)
            x1 = self.cascade_block(x0, noise, depth=self.n_crb)
            x1s = tf.layers.average_pooling2d(x1, 2, 2)
            n1s = tf.layers.average_pooling2d(noise, 2, 2)
            x2 = self.cascade_block(x1s, n1s, depth=self.n_crb)
            x2s = tf.layers.average_pooling2d(x2, 2, 2)
            n2s = tf.layers.average_pooling2d(noise, 4, 4)
            x3 = self.cascade_block(x2s, n2s, depth=self.n_crb)
            x3u = self.deconv2d(x3, 64, 3, strides=2)
            x3u1 = tf.concat([x3u, x1s], -1)
            x3u2 = self.conv2d(x3u1, 64, 3)
            x4 = self.cascade_block(x3u2, n1s, depth=self.n_crb)
            x4u = self.deconv2d(x4, 64, 3, strides=2)
            x4u1 = tf.concat([x4u, x0], -1)
            x4u2 = self.conv2d(x4u1, 64, 3)
            x5 = self.conv2d(x4u2, self.channel, 3)
        return x5, None

    def _get_noise(self, inputs):
        n = self.noise
        if n.type == 'gaussian':
            sigma = tf.random_uniform([], maxval=n.max)
            noise = tf.random_normal(tf.shape(inputs), stddev=sigma)
            img = inputs + noise
            return img, noise
        elif n.type == 'crf':
            crf = tf.convert_to_tensor(n.crf['crf'])
            icrf = tf.convert_to_tensor(n.crf['icrf'])
            i = tf.random_uniform([], 0, crf.shape[0], dtype=tf.int32)
            irr = Noise.tf_camera_response_function(inputs, icrf[i], max_val=1)
            noise = Noise.tf_gaussian_poisson_noise(irr, max_c=n.max)
            img = Noise.tf_camera_response_function(irr + noise,
                                                    crf[i],
                                                    max_val=1)
            return img, img - inputs
        else:
            raise TypeError(n.type)

    def build_graph(self):
        super(DRSR, self).build_graph()
        inputs_norm = _normalize(self.inputs_preproc[-1])
        labels_norm = _normalize(self.label[-1])
        if self.mean_shift:
            inputs_norm -= _MEAN / 255
            labels_norm -= _MEAN / 255
        n = self.noise
        inputs_noise, noise = self._get_noise(inputs_norm)
        nn = self._upsample
        with tf.variable_scope('Offset'):
            x = inputs_norm
            for _ in range(n.layers):
                x = self.relu_conv2d(
                    x,
                    64,
                    3,
                    kernel_initializer=tf.initializers.random_normal(
                        stddev=0.01))
            offset = self.conv2d(
                x,
                self.channel,
                3,
                kernel_initializer=tf.initializers.random_normal(stddev=0.01))
            offset *= Noise.tf_gaussian_noise(offset, n.offset2)

        with tf.variable_scope(self.name):
            zero = self._dncnn(inputs_norm)
            zero_shift = zero + offset * n.scale + \
                         Noise.tf_gaussian_noise(zero, n.offset)
            clean = nn(inputs_norm, zero_shift)
        with tf.variable_scope(self.name, reuse=True):
            noisy = self._dncnn(inputs_noise)
            dirty = nn(inputs_noise, noisy)
        if self.finetune == -1:
            with tf.variable_scope(self.name, reuse=True):
                s = 2
                inputs_s2 = tf.layers.average_pooling2d(inputs_norm, s, s)
                zero_s2 = self._dncnn(inputs_s2)
                zero_shift_s2 = zero_s2 + Noise.tf_gaussian_noise(
                    zero_s2, n.offset)
                clean_s2 = nn(inputs_s2, zero_shift_s2)
                noise_s2 = inputs_norm - clean_s2[0]
            with tf.variable_scope('Fine'):
                x = self.conv2d(inputs_norm, 64, 3)
                x = self.cascade_block(x, noise_s2, depth=6)
                x = self.conv2d(x, self.channel, 3)
                clean_fine = [x, x]
            self.outputs.append(_denormalize(clean_s2[0]))
            self.outputs.append(_denormalize(clean_fine[0]))
        else:
            self.outputs.append(_denormalize(tf.abs(zero)))
            self.outputs.append(_denormalize(clean[0]))

        if self.mean_shift:
            self.outputs = [x + _MEAN for x in self.outputs]

        def loss1():
            l1_with_noise = tf.losses.absolute_difference(
                dirty[0], labels_norm)
            l1_fine_tune = tf.losses.absolute_difference(clean[0], labels_norm)
            penalty = tf.clip_by_value(2 * tf.ceil(tf.nn.relu(noisy - noise)),
                                       0, 1)
            penalty = tf.abs(self.noise.penalty - penalty)
            noise_identity = penalty * tf.squared_difference(noisy, noise)
            noise_identity = tf.reduce_mean(noise_identity)
            noise_tv = tf.reduce_mean(tf.image.total_variation(noisy))
            # tv clamp
            l_tv_max = tf.nn.relu(noise_tv - 1000)**2
            l_tv_min = tf.nn.relu(100 - noise_tv)**2
            noise_tv += l_tv_max + l_tv_min
            loss = tf.stack([l1_with_noise, noise_identity, noise_tv])
            loss *= self.weights[:-1]
            loss = tf.reduce_sum(loss)
            self.train_metric['l1/noisy'] = l1_with_noise
            self.train_metric['l1/finet'] = l1_fine_tune
            self.train_metric['ni'] = noise_identity
            self.train_metric['nt'] = noise_tv

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            var_g = tf.trainable_variables(self.name)
            var_o = tf.trainable_variables('Offset')
            with tf.control_dependencies(update_ops):
                op = tf.train.AdamOptimizer(self.learning_rate, 0.9)
                op = op.minimize(loss, self.global_steps, var_list=var_g)
                self.loss.append(op)
                op = tf.train.AdamOptimizer(self.learning_rate, 0.9)
                op = op.minimize(l1_fine_tune,
                                 self.global_steps,
                                 var_list=var_o)
                self.loss.append(op)

        def loss2():
            l1_clean = tf.losses.mean_squared_error(clean[0], labels_norm)
            var_g = tf.trainable_variables(self.name)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                op = tf.train.AdamOptimizer(self.learning_rate, 0.9)
                op = op.minimize(l1_clean, self.global_steps, var_list=var_g)
                self.loss += [op, op]
            self.train_metric['l1/tune'] = l1_clean

        def loss3():
            l1_clean = tf.losses.mean_squared_error(clean_fine[0], labels_norm)
            var_f = tf.trainable_variables('Fine')
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                op = tf.train.AdamOptimizer(self.learning_rate, 0.9)
                op = op.minimize(l1_clean, self.global_steps, var_list=var_f)
                self.loss += [op, op]
            self.train_metric['l1/tune'] = l1_clean
            tf.summary.image('hr/coarse', _clip(self.outputs[-2]))

        with tf.name_scope('Loss'):
            if self.finetune == -1:
                loss3()
            elif 'DrTrainer' in str(self.trainer):
                loss2()
            else:
                loss1()
        self.metrics['psnr1'] = tf.reduce_mean(
            tf.image.psnr(self.label[-1], self.outputs[-1], max_val=255))
        tf.summary.image('noisy/zero', zero)

    def build_loss(self):
        pass

    def build_summary(self):
        super(DRSR, self).build_summary()
        tf.summary.image('lr/input', self.inputs[-1])
        tf.summary.image('hr/fine', _clip(self.outputs[-1]))
        tf.summary.image('hr/label', _clip(self.label[0]))

    def build_saver(self):
        var_g = tf.global_variables(self.name)
        steps = [self.global_steps]
        loss = tf.global_variables('Loss')
        self.savers.update(drsr_g=tf.train.Saver(var_g, max_to_keep=1),
                           misc=tf.train.Saver(steps + loss, max_to_keep=1))
        if self.finetune == -1:
            var_f = tf.global_variables('Fine')
            self.savers.update(drsr_f=tf.train.Saver(var_f, max_to_keep=1))

    def train_batch(self, feature, label, learning_rate=1e-4, **kwargs):
        epochs = kwargs.get('epochs')
        if epochs < self.finetune:
            loss = self.loss[0]
        else:
            loss = self.loss[1]
        return super(DRSR, self).train_batch(feature,
                                             label,
                                             learning_rate,
                                             loss=loss)
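
The "tv clamp" inside loss1() adds a quadratic penalty whenever the mean total variation of the estimated noise leaves the band [100, 1000], discouraging the noise map from collapsing to a constant or turning into pure high-frequency texture. A minimal numeric sketch of that clamp (a paraphrase for clarity, using the band edges hard-coded above):

import numpy as np

def tv_clamp(tv, lo=100.0, hi=1000.0):
    # quadratic penalty outside [lo, hi], nothing added inside the band
    return tv + np.maximum(tv - hi, 0.0) ** 2 + np.maximum(lo - tv, 0.0) ** 2

print(tv_clamp(500.0))   # 500.0   (inside the band, unchanged)
print(tv_clamp(1100.0))  # 11100.0 (100 above the band, squared penalty kicks in)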
Example #10
class VSR(Trainer):
  """Default trainer for task SISR or VSR"""
  v = Config()  # local variables
  """=======================================
      components, sub-functions, helpers
     =======================================
  """

  def query_config(self, config, **kwargs) -> Config:
    config = Config(config or {})
    config.update(kwargs)  # override parameters
    self.v.epoch = config.epoch  # current epoch
    self.v.epochs = config.epochs or 1  # total epochs
    self.v.lr = config.lr or 1e-4  # learning rate
    self.v.batch_shape = config.batch_shape or [1, -1, -1, -1]
    self.v.steps = config.steps or 200
    self.v.val_steps = config.val_steps or -1
    self.v.lr_schedule = config.lr_schedule
    self.v.memory_limit = config.memory_limit
    self.v.inference_results_hooks = config.inference_results_hooks or []
    self.v.validate_every_n_epoch = config.validate_every_n_epoch or 1
    self.v.traced_val = config.traced_val
    self.v.ensemble = config.ensemble
    self.v.cuda = config.cuda
    self.v.caching = config.caching_dataset
    return self.v

  def fit_init(self) -> bool:
    v = self.v
    v.sess = self._restore()
    if self.last_epoch >= v.epochs:
      LOG.info(f'Found pre-trained epoch {v.epoch}>=target {v.epochs},'
               ' quit fitting.')
      return False
    LOG.info('Fitting: {}'.format(self.model.name.upper()))
    v.summary_writer = tf.summary.FileWriter(
        str(self._logd), graph=tf.get_default_graph())
    v.global_step = self.model.global_steps.eval()
    return True

  def fit_close(self):
    # flush all pending summaries to disk
    if self.v.summary_writer:
      self.v.summary_writer.close()
    LOG.info(f'Training {self.model.name.upper()} finished.')

  def fn_train_each_epoch(self):
    v = self.v
    mem = v.memory_limit
    train_iter = v.train_loader.make_one_shot_iterator(v.batch_shape,
                                                       v.steps,
                                                       shuffle=True,
                                                       memory_limit=mem,
                                                       caching=v.caching)
    v.train_loader.prefetch(v.memory_limit)
    v.avg_meas = {}
    if v.lr_schedule and callable(v.lr_schedule):
      v.lr = v.lr_schedule(steps=v.global_step)
    LOG.info(f"| Epoch: {v.epoch}/{v.epochs} | LR: {v.lr:.2g} |")
    with tqdm.tqdm(train_iter, unit='batch', ascii=True) as r:
      for items in r:
        self.fn_train_each_step(items)
        r.set_postfix(v.loss)
    for _k, _v in v.avg_meas.items():
      LOG.info(f"| Epoch average {_k} = {np.mean(_v):.6f} |")
    if v.epoch % v.validate_every_n_epoch == 0 and v.val_loader:
      self.benchmark(v.val_loader, v, epoch=v.epoch, memory_limit='1GB')
      v.summary_writer.add_summary(self.model.summary(), v.global_step)
    self._save_model(v.sess, v.epoch)

  def fn_train_each_step(self, pack):
    v = self.v
    loss = self.model.train_batch(pack['lr'], pack['hr'], learning_rate=v.lr,
                                  epochs=v.epoch)
    v.global_step = self.model.global_steps.eval()
    for _k, _v in loss.items():
      v.avg_meas[_k] = \
        v.avg_meas[_k] + [_v] if v.avg_meas.get(_k) else [_v]
      loss[_k] = '{:08.5f}'.format(_v)
    v.loss = loss

  def fn_infer_each_step(self, pack):
    v = self.v
    if v.ensemble:
      # add self-ensemble boosting metric score
      feature_ensemble = _ensemble_expand(pack['lr'])
      outputs_ensemble = []
      for f in feature_ensemble:
        y, _ = self.model.test_batch(f, None)
        outputs_ensemble.append(y)
      outputs = []
      for i in range(len(outputs_ensemble[0])):
        outputs.append([j[i] for j in outputs_ensemble])
      outputs = _ensemble_reduce_mean(outputs)
    else:
      outputs, _ = self.model.test_batch(pack['lr'], None)
    for fn in v.inference_results_hooks:
      outputs = fn(outputs, names=pack['name'])
      if outputs is None:
        break

  def fn_benchmark_each_step(self, pack):
    v = self.v
    outputs, metrics = self.model.test_batch(pack['lr'], pack['hr'],
                                             epochs=v.epoch)
    for _k, _v in metrics.items():
      if _k not in v.mean_metrics:
        v.mean_metrics[_k] = []
      v.mean_metrics[_k] += [_v]
    for fn in v.inference_results_hooks:
      outputs = fn(outputs, names=pack['name'])
      if outputs is None:
        break

  def fn_benchmark_body(self):
    v = self.v
    it = v.loader.make_one_shot_iterator(v.batch_shape, v.val_steps,
                                         shuffle=not v.traced_val,
                                         memory_limit=v.memory_limit,
                                         caching=v.caching)
    for items in tqdm.tqdm(it, 'Test', ascii=True):
      self.fn_benchmark_each_step(items)

  """=======================================
      Interface: fit, benchmark, infer
     =======================================
  """

  def fit(self, loaders, config, **kwargs):
    """Fit the model.

    Args:
        loaders: a tuple of two loaders; the first is used for training
          and the second for validation.
        config: fitting configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
    v = self.query_config(config, **kwargs)
    v.train_loader, v.val_loader = loaders
    if not self.fit_init():
      return
    for epoch in range(self.last_epoch + 1, v.epochs + 1):
      v.epoch = epoch
      self.fn_train_each_epoch()
    self.fit_close()

  def infer(self, loader, config, **kwargs):
    """Infer SR images.

    Args:
        loader: a loader for enumerating LR images
        config: inference configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
    v = self.query_config(config, **kwargs)
    self._restore()
    it = loader.make_one_shot_iterator(v.batch_shape, -1)
    if hasattr(it, '__len__'):
      if len(it):
        LOG.info('Inferring {} at epoch {}'.format(
            self.model.name, self.last_epoch))
      else:
        return
    # use original images for inference
    for items in tqdm.tqdm(it, 'Infer', ascii=True):
      self.fn_infer_each_step(items)

  def benchmark(self, loader, config, **kwargs):
    """Benchmark/validate the model.

    Args:
        loader: a loader for enumerating LR images
        config: benchmark configuration, an instance of `Util.Config.Config`
        kwargs: additional arguments to override the same ones in config.
    """
    v = self.query_config(config, **kwargs)
    self._restore()
    v.mean_metrics = {}
    v.loader = loader
    self.fn_benchmark_body()
    log_message = str()
    for _k, _v in v.mean_metrics.items():
      _v = np.mean(_v)
      log_message += f"{_k}: {_v:.6f}, "
    log_message = log_message[:-2] + "."
    LOG.info(log_message)
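
In fn_infer_each_step, self-ensemble expands each LR input into several geometric variants, runs the model on each, and averages the inverse-transformed outputs. The helpers _ensemble_expand and _ensemble_reduce_mean are not shown here; the sketch below assumes the usual 8-way flip/rotation ensemble and works on single HWC numpy arrays.

import numpy as np

def ensemble_expand(img):
    # 8 geometric variants: 4 rotations, each with and without a horizontal flip
    variants = []
    for rot in range(4):
        r = np.rot90(img, rot)
        variants += [r, np.fliplr(r)]
    return variants

def ensemble_reduce_mean(outputs):
    # undo the transform of each variant, then average
    restored = []
    for i, out in enumerate(outputs):
        rot, flipped = divmod(i, 2)
        o = np.fliplr(out) if flipped else out
        restored.append(np.rot90(o, -rot))
    return np.mean(restored, axis=0)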
Example #11
class DRSR(SuperResolution):
  def __init__(self, name='drsr_v2', noise_config=None, weights=(1, 10, 1e-5),
               level=1, mean_shift=(0, 0, 0), arch=None, auto_shift=None,
               **kwargs):
    super(DRSR, self).__init__(**kwargs)
    self.name = name
    self.noise = Config(scale=0, offset=0, penalty=0.5, max=0, layers=7)
    if isinstance(noise_config, (dict, Config)):
      self.noise.update(**noise_config)
      self.noise.crf = np.load(self.noise.crf)
      self.noise.offset = to_list(self.noise.offset, 4)
      self.noise.offset = [x / 255 for x in self.noise.offset]
      self.noise.max /= 255
    self.weights = weights
    self.level = level
    if mean_shift is not None:
      self.norm = partial(_normalize, shift=mean_shift)
      self.denorm = partial(_denormalize, shift=mean_shift)
    self.arch = arch
    self.auto = auto_shift
    self.to_sum = []

  def display(self):
    LOG.info(str(self.noise))

  def noise_cond(self, inputs, noise, layers, scope='NCL'):
    with tf.variable_scope(None, scope):
      x = noise
      c = inputs.shape[-1]
      for _ in range(layers - 1):
        x = self.prelu_conv2d(x, 64, 3)
      x = self.conv2d(x, c, 3)
      gamma = tf.nn.sigmoid(x)
      x = noise
      for _ in range(layers - 1):
        x = self.prelu_conv2d(x, 64, 3)
      beta = self.conv2d(x, c, 3)
      return inputs * gamma + beta

  def cond_rb(self, inputs, noise, scope='CRB'):
    with tf.variable_scope(None, scope):
      x = self.prelu_conv2d(inputs, 64, 3)
      x = self.conv2d(x, 64, 3)
      x = self.noise_cond(x, noise, 3)
      if inputs.shape[-1] != x.shape[-1]:
        sc = self.conv2d(inputs, x.shape[-1], 1,
                         kernel_initializer='he_uniform')
      else:
        sc = inputs
      return sc + x

  def cond_rdb(self, inputs, noise, scope='CRDB'):
    with tf.variable_scope(None, scope):
      x0 = self.prelu_conv2d(inputs, 64, 3)
      x1 = self.prelu_conv2d(tf.concat([inputs, x0], -1), 64, 3)
      x2 = self.conv2d(tf.concat([inputs, x0, x1], -1), 64, 3)
      x = self.noise_cond(x2, noise, 3)
      if inputs.shape[-1] != x.shape[-1]:
        sc = self.conv2d(inputs, x.shape[-1], 1,
                         kernel_initializer='he_uniform')
      else:
        sc = inputs
      return sc + x

  def noise_estimate(self, inputs, scope='NoiseEstimator', reuse=None):
    n = self.noise
    with tf.variable_scope(None, scope, reuse=reuse):
      x = inputs
      for _ in range(n.layers):
        x = self.leaky_conv2d(x, 64, 3)
      x = self.conv2d(x, self.channel, 3)
      return x

  def noise_shift(self, inputs, layers, scope='NoiseShift', reuse=None):
    n = self.noise
    with tf.variable_scope(None, scope, reuse=reuse):
      x = inputs
      for _ in range(layers):
        x = self.leaky_conv2d(x, 64, 3)
      x = self.conv2d(x, self.channel, 3, activation=tf.nn.sigmoid)
      return x * Noise.tf_gaussian_noise(inputs, n.max)

  def local_net(self, inputs, noise, depth=4, scope='LC'):
    with tf.variable_scope(None, scope):
      fl = [inputs]
      x = inputs
      for i in range(depth):
        x = self.cond_rb(x, noise)
        fl.append(x)
        x = tf.concat(fl, axis=-1)
        x = self.conv2d(x, 64, 1, kernel_initializer='he_uniform')
      return x

  def local_net2(self, inputs, noise, depth=4, scope='LC'):
    with tf.variable_scope(None, scope):
      fl = [inputs]
      x = inputs
      for i in range(depth):
        x = self.cond_rdb(x, noise)
        fl.append(x)
        x = tf.concat(fl, axis=-1)
        x = self.conv2d(x, 64, 1, kernel_initializer='he_uniform')
      return x

  def global_net(self, inputs, noise, depth=4, scope='GC', reuse=None):
    with tf.variable_scope(None, scope, reuse=reuse):
      fl = [inputs]
      x = inputs
      for i in range(depth):
        if self.arch == 'concat':
          x = cascade_rdn(self, x, depth=3, use_ca=True)
        elif self.arch == 'crb':
          x = self.local_net(x, noise[i], 4)
        else:
          x = self.local_net2(x, noise[i], 3)
        if self.arch != 'crdb':
          fl.append(x)
          x = tf.concat(fl, axis=-1)
          x = self.conv2d(x, 64, 1, kernel_initializer='he_uniform')
      self.to_sum += fl
      if self.arch == 'crdb':
        x += inputs
      if self.auto:
        sr = self.upscale(x, direct_output=False, scale=4)
      else:
        sr = self.upscale(x, direct_output=False)
      sr = self.conv2d(sr, self.channel, 3)
      return sr, x

  def gen_noise(self, inputs, ntype, max1=0.06, max2=0.16):
    with tf.name_scope('GenNoise'):
      n = self.noise
      if ntype == 'gaussian':
        noise = Noise.tf_gaussian_noise(inputs, sigma_max=max1,
                                        channel_wise=False)
        return noise
      elif ntype == 'crf':
        crf = tf.convert_to_tensor(n.crf['crf'])
        icrf = tf.convert_to_tensor(n.crf['icrf'])
        i = tf.random_uniform([], 0, crf.shape[0], dtype=tf.int32)
        irr = Noise.tf_camera_response_function(inputs, icrf[i], max_val=1)
        noise = Noise.tf_gaussian_poisson_noise(irr, max_c=max1, max_s=max2)
        img = Noise.tf_camera_response_function(irr + noise, crf[i], max_val=1)
        return img - inputs
      else:
        raise TypeError(ntype)

  def net(self, inputs, level, scale=1, shift=(0, 0, 0, 0), reuse=None):
    with tf.variable_scope(self.name, reuse=reuse):
      level_outputs = []
      level_noise = []
      level_inputs = []
      for i in range(1, level + 1):
        with tf.variable_scope(f'Level{i:1d}'):
          noise_hyp = self.noise_estimate(inputs) * scale + \
                      Noise.tf_gaussian_noise(inputs, self.noise.offset[0])
          level_noise.append(noise_hyp)
          noise_hyp = [noise_hyp + shift[0],
                       noise_hyp + shift[1],
                       noise_hyp + shift[2],
                       noise_hyp + shift[3]]
          if i == 1:
            if self.arch == 'concat':
              inputs = tf.concat([inputs, noise_hyp[0]], axis=-1)
            entry = self.conv2d(inputs, 64, 3)
            entry = self.conv2d(entry, 64, 3)
            level_inputs.append(entry)
          y = self.global_net(level_inputs[-1], noise_hyp, 4)
          level_outputs.append(y[0])
          level_inputs.append(y[1])
      return level_noise, level_outputs

  def build_graph(self):
    super(DRSR, self).build_graph()
    inputs_norm = self.norm(self.inputs_preproc[-1])
    labels_norm = self.norm(self.label[-1])
    n = self.noise
    if n.valid:
      LOG.info("adding noise")
      awgn = self.gen_noise(inputs_norm, 'gaussian', n.max)
      gp = self.gen_noise(inputs_norm, 'crf', 5 / 255, n.max)
    else:
      awgn = gp = tf.zeros_like(inputs_norm)

    if self.level == 1:
      noise = awgn
    elif self.level == 2:
      noise = gp
    else:
      raise NotImplementedError("Unknown level!")
    with tf.variable_scope('Offset'):
      shift = []
      if not self.auto:
        for i in range(4):
          shift.append(Noise.tf_gaussian_noise(inputs_norm, n.offset[i]))
        var_shift = []
      else:
        for i in range(4):
          shift.append(self.noise_shift(inputs_norm, 8, f'NoiseShift_{i}'))
        var_shift = tf.trainable_variables('Offset')

    noise_hyp, outputs = self.net(inputs_norm + noise, 1, n.scale, shift)
    self.outputs += [tf.abs(x * 255) for x in noise_hyp + shift]
    self.outputs += [self.denorm(x) for x in outputs]

    l1_image = tf.losses.absolute_difference(outputs[-1], labels_norm)
    noise_abs_diff = tf.abs(noise_hyp[-1]) - tf.abs(noise)
    # 1: over estimated; 0: under estimated
    penalty = tf.ceil(tf.clip_by_value(noise_abs_diff, 0, 1))
    # 1 - n: over estimated; n: under estimated
    penalty = tf.abs(n.penalty - penalty)
    noise_error = penalty * tf.squared_difference(noise_hyp[-1], noise)
    l2_noise = tf.reduce_mean(noise_error)

    # tv clamp
    tv = tf.reduce_mean(tf.image.total_variation(noise_hyp[-1]))
    l_tv_max = tf.nn.relu(tv - 1000) ** 2
    l_tv_min = tf.nn.relu(200 - tv) ** 2
    tv = tv + l_tv_max + l_tv_min

    def loss_fn1():
      w = self.weights
      loss = l1_image * w[0] + l2_noise * w[1] + tv * w[2]
      var_to_opt = tf.trainable_variables(self.name + f"/Level1")
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        op = tf.train.AdamOptimizer(self.learning_rate, 0.9)
        op = op.minimize(loss, self.global_steps, var_list=var_to_opt)
        self.loss.append(op)

      self.train_metric['mae'] = l1_image
      self.train_metric['noise_error'] = l2_noise
      self.train_metric['tv'] = tv
      self.to_sum += noise_hyp

    def loss_fn2():
      w = self.weights
      tv_noise = [tf.reduce_mean(tf.image.total_variation(x)) for x in shift]
      tv_noise = tf.add_n(tv_noise) / 4
      tv_max = tf.nn.relu(tv_noise - 1000) ** 2
      tv_min = tf.nn.relu(200 - tv_noise) ** 2
      tv_noise += tv_max + tv_min
      loss = l1_image * w[0] + tv_noise * w[2]
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        op = tf.train.AdamOptimizer(self.learning_rate, 0.9)
        op = op.minimize(loss, self.global_steps, var_list=var_shift)
        self.loss.append(op)
      self.train_metric['mae'] = l1_image
      self.train_metric['tv'] = tv_noise

    with tf.name_scope('Loss'):
      if not self.auto:
        loss_fn1()
      else:
        loss_fn2()
      self.metrics['psnr'] = tf.reduce_mean(
        tf.image.psnr(self.label[-1], self.outputs[-1], max_val=255))
      self.metrics['ssim'] = tf.reduce_mean(
        tf.image.ssim(self.label[-1], self.outputs[-1], max_val=255))

  def build_loss(self):
    pass

  def build_summary(self):
    super(DRSR, self).build_summary()
    # tf.summary.image('lr/input', self.inputs[-1])
    tf.summary.image(f'hr/fine_1', clip_image(self.outputs[-1]))
    tf.summary.image('hr/label', clip_image(self.label[0]))

  def build_saver(self):
    var_misc = tf.global_variables('Loss') + [self.global_steps]
    self.savers.update(misc=tf.train.Saver(var_misc, max_to_keep=1))
    var_g = tf.global_variables(self.name + f"/Level1")
    self.savers.update({
      f"level_1": tf.train.Saver(var_g, max_to_keep=1)
    })
    if self.auto:
      self.savers.update(shift=tf.train.Saver(
        tf.global_variables('Offset'),
        max_to_keep=1))
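
The noise loss in build_graph is asymmetric: where the estimator over-shoots the true noise it is weighted by 1 - penalty, and where it under-shoots it is weighted by penalty, so a penalty above 0.5 punishes under-estimation harder. A numpy paraphrase of that term (for clarity only; `penalty` corresponds to self.noise.penalty):

import numpy as np

def asymmetric_noise_loss(noise_est, noise_true, penalty=0.5):
    over = np.ceil(np.clip(np.abs(noise_est) - np.abs(noise_true), 0, 1))  # 1 where over-estimated
    weight = np.abs(penalty - over)   # penalty if under-estimated, 1 - penalty if over-estimated
    return np.mean(weight * (noise_est - noise_true) ** 2)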