Example #1
    def build_graph(self):
        super(SRFEAT, self).build_graph()
        inputs_norm = _normalize(self.inputs_preproc[-1])
        label_norm = _normalize(self.label[-1])
        with tf.variable_scope(self.name):
            # Shallow feature extraction with a large 9x9 receptive field.
            shallow_feature = self.prelu_conv2d(inputs_norm, self.F, 9)
            x = [shallow_feature]
            # Stack of residual blocks; keep every intermediate output.
            for _ in range(self.g_layers):
                x.append(
                    self.resblock(x[-1],
                                  self.F,
                                  3,
                                  activation='prelu',
                                  use_batchnorm=True))
            bottleneck = x[-1]
            # Long-range skips: fuse each intermediate residual output into
            # the bottleneck through a 1x1 convolution.
            for t in x[1:-1]:
                bottleneck += self.conv2d(t, self.F, 1)
            sr = self.upscale(bottleneck, direct_output=False, activator=prelu)
            # Final 9x9 conv with tanh keeps the output in the normalized range.
            sr = self.tanh_conv2d(sr, self.channel, 9)
            self.outputs.append(_denormalize(sr))

        # Image discriminator D on pixels; feature discriminator DF on VGG maps.
        disc_real = self.D(label_norm)
        disc_fake = self.D(sr)
        vgg_features = [self.vgg(self.outputs[0], self.vgg_layer)]
        vgg_features += [self.vgg(self.label[0], self.vgg_layer)]
        vgg_fake = self.DF(vgg_features[0])
        vgg_real = self.DF(vgg_features[1])

        with tf.name_scope('Loss'):
            # Adversarial losses from both the image discriminator and the
            # VGG-feature discriminator, plus a perceptual (VGG) MSE term.
            loss_gen, loss_disc = loss_bce_gan(disc_real, disc_fake)
            vgg_loss_g, vgg_loss_d = loss_bce_gan(vgg_real, vgg_fake)
            mse = tf.losses.mean_squared_error(label_norm, sr)
            loss_d = loss_disc + vgg_loss_d
            loss_g = loss_gen + vgg_loss_g
            loss_vgg = tf.losses.mean_squared_error(*vgg_features)
            # Weighted sum of the adversarial and perceptual generator terms.
            loss = tf.stack([loss_g, loss_vgg])
            loss = tf.reduce_sum(loss * [self.gan_weight, self.vgg_weight])

            var_g = tf.trainable_variables(self.name)
            var_d = tf.trainable_variables('Critic')
            var_df = tf.trainable_variables('DF')
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # Run batch-norm moving-average updates before every train step.
            with tf.control_dependencies(update_ops):
                # opt_i pre-trains the generator on MSE alone; opt_g and opt_d
                # drive the adversarial phase.
                opt_i = tf.train.AdamOptimizer(self.learning_rate).minimize(
                    mse, self.global_steps, var_list=var_g)
                opt_g = tf.train.AdamOptimizer(self.learning_rate).minimize(
                    loss, self.global_steps, var_list=var_g)
                opt_d = tf.train.AdamOptimizer(self.learning_rate).minimize(
                    loss_d, var_list=var_d + var_df)
                self.loss = [opt_i, opt_d, opt_g]

        self.train_metric['g_loss'] = loss_g
        self.train_metric['d_loss'] = loss_d
        self.train_metric['vgg_loss'] = loss_vgg
        self.train_metric['loss'] = loss
        self.metrics['psnr'] = tf.reduce_mean(
            tf.image.psnr(self.label[-1], self.outputs[-1], 255))
        self.metrics['ssim'] = tf.reduce_mean(
            tf.image.ssim(self.label[-1], self.outputs[-1], 255))
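
All four examples call loss_bce_gan to obtain a (generator, discriminator) loss pair from raw discriminator logits. A minimal sketch of what such a helper could look like, assuming the standard non-saturating binary-cross-entropy GAN loss (the repository's actual implementation may differ):

    import tensorflow as tf

    def loss_bce_gan(disc_real, disc_fake):
        """Sketch: return (generator_loss, discriminator_loss) from logits."""
        # Discriminator: push real logits toward 1, fake logits toward 0.
        d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(disc_real), logits=disc_real))
        d_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(disc_fake), logits=disc_fake))
        # Generator (non-saturating form): push fake logits toward 1.
        g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(disc_fake), logits=disc_fake))
        return g_loss, d_loss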
Example #2
    def build_loss(self):
        with tf.name_scope('Loss'):
            g_loss, d_loss = loss_bce_gan(*self.d_outputs)
            # Reuse existing variables while the critic is re-applied to
            # the interpolated samples inside the penalty.
            with tf.variable_scope(self.name, reuse=True):
                gp = gradient_penalty(*self.g_outputs, self.D, lamb=10)
            d_loss += gp
            self._build_loss(g_loss, d_loss)
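
gradient_penalty here looks like the WGAN-GP penalty; the call unpacks self.g_outputs before passing the critic self.D. A hedged sketch under the assumption that self.g_outputs unpacks to (real, fake) images:

    import tensorflow as tf

    def gradient_penalty(real, fake, critic, lamb=10):
        # Sample a random point on the line between each real/fake pair.
        eps = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], 0., 1.)
        interp = real + eps * (fake - real)
        # Penalize the critic's gradient norm for deviating from 1 (WGAN-GP).
        grad = tf.gradients(critic(interp), [interp])[0]
        norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-12)
        return lamb * tf.reduce_mean(tf.square(norm - 1.))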
Example #3
    def build_graph(self):
        super(SRGAN, self).build_graph()
        with tf.variable_scope(self.name):
            inputs_norm = self._normalize(self.inputs_preproc[-1])
            shallow_feature = self.prelu_conv2d(inputs_norm, 64, 9)
            x = shallow_feature
            for _ in range(self.g_layers):
                x = self.resblock(x,
                                  64,
                                  3,
                                  activation='prelu',
                                  use_batchnorm=True)
            x = self.bn_conv2d(x, 64, 3)
            # Global residual connection back to the shallow feature.
            x += shallow_feature
            x = self.conv2d(x, 256, 3)
            sr = self.upscale(x, direct_output=False, activator=prelu)
            sr = self.tanh_conv2d(sr, self.channel, 9)
            self.outputs.append(self._denormalize(sr))

        label_norm = self._normalize(self.label[-1])
        disc_real = self.D(label_norm)
        disc_fake = self.D(sr)

        with tf.name_scope('Loss'):
            loss_gen, loss_disc = loss_bce_gan(disc_real, disc_fake)
            mse = tf.losses.mean_squared_error(label_norm, sr)
            reg = tf.losses.get_regularization_losses()

            # Content (MSE) + adversarial terms, plus any regularizers.
            loss = tf.add_n(
                [mse * self.mse_weight, loss_gen * self.gan_weight] + reg)
            if self.use_vgg:
                # Optional perceptual loss on VGG feature maps; the third
                # argument of mean_squared_error is its weight.
                vgg_real = self.vgg(self.label[-1], self.vgg_layer)
                vgg_fake = self.vgg(self.outputs[-1], self.vgg_layer)
                loss_vgg = tf.losses.mean_squared_error(
                    vgg_real, vgg_fake, self.vgg_weight)
                loss += loss_vgg

            var_g = tf.trainable_variables(self.name)
            var_d = tf.trainable_variables('Critic')
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                opt_i = tf.train.AdamOptimizer(self.learning_rate).minimize(
                    mse, self.global_steps, var_list=var_g)
                opt_g = tf.train.AdamOptimizer(self.learning_rate).minimize(
                    loss, self.global_steps, var_list=var_g)
                opt_d = tf.train.AdamOptimizer(self.learning_rate).minimize(
                    loss_disc, var_list=var_d)
                self.loss = [opt_i, opt_d, opt_g]

        self.train_metric['g_loss'] = loss_gen
        self.train_metric['d_loss'] = loss_disc
        self.train_metric['loss'] = loss
        self.metrics['mse'] = mse
        self.metrics['psnr'] = tf.reduce_mean(
            tf.image.psnr(self.label[-1], self.outputs[-1], 255))
        self.metrics['ssim'] = tf.reduce_mean(
            tf.image.ssim(self.label[-1], self.outputs[-1], 255))
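
_normalize and _denormalize are not shown in these snippets. Given the tanh output layer and the metrics computed with max_val=255, a plausible (assumed) mapping between [0, 255] pixels and [-1, 1] features:

    def _normalize(self, img):
        # Assumed: [0, 255] pixel range -> [-1, 1], matching the tanh output.
        return img / 127.5 - 1

    def _denormalize(self, img):
        # Assumed inverse: [-1, 1] -> [0, 255] for PSNR/SSIM with max_val=255.
        return (img + 1) * 127.5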
Example #4
    def build_loss(self):
        with tf.name_scope('Loss'):
            g_loss, d_loss = loss_bce_gan(*self.d_outputs)
            self._build_loss(g_loss, d_loss)
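
_build_loss belongs to the base class and is not shown here. Judging from the optimizer blocks in Examples #1 and #3, it plausibly wires the two losses to Adam optimizers over the generator and critic variables; the following is an assumption, not the repository's code:

    def _build_loss(self, g_loss, d_loss):
        # Assumed to mirror the optimizer setup of Examples #1 and #3.
        var_g = tf.trainable_variables(self.name)
        var_d = tf.trainable_variables('Critic')
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt_g = tf.train.AdamOptimizer(self.learning_rate).minimize(
                g_loss, self.global_steps, var_list=var_g)
            opt_d = tf.train.AdamOptimizer(self.learning_rate).minimize(
                d_loss, var_list=var_d)
            self.loss = [opt_d, opt_g]
        self.train_metric['g_loss'] = g_loss
        self.train_metric['d_loss'] = d_loss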