def build_graph(self):
    """Build the SRFEAT generator, its two discriminators, losses and train ops.

    Generator: a PReLU 9x9 conv, ``self.g_layers`` residual blocks (PReLU,
    batch-norm), long-range skips that project every intermediate residual
    output with a 1x1 conv and add it to the last block's output, then
    ``self.upscale`` and a final tanh 9x9 conv. The denormalised result is
    appended to ``self.outputs``.

    Adversarial branches: ``self.D`` scores normalised images and
    ``self.DF`` scores VGG feature maps taken at ``self.vgg_layer``.

    Attributes written:
        self.outputs: gains the denormalised super-resolved image.
        self.loss: ``[opt_i, opt_d, opt_g]``. ``opt_i`` minimises the
            normalised-space pixel MSE and ``opt_g`` minimises
            ``gan_weight * (image GAN + feature GAN) + vgg_weight * VGG MSE``,
            both over the generator's variables (``self.name`` scope).
            ``opt_d`` minimises the summed image/feature discriminator
            losses over the ``Critic`` and ``DF`` variables.
        self.train_metric: ``g_loss``, ``d_loss``, ``vgg_loss``, ``loss``.
        self.metrics: mean ``psnr`` and ``ssim`` computed with max_val=255.
    """
    super(SRFEAT, self).build_graph()
    # The last label is the HR target for every loss and metric below; the
    # VGG branch previously read label[0]/outputs[0] instead, which diverges
    # from the rest of the graph whenever those lists hold more than one item.
    hr_image = self.label[-1]
    inputs_norm = _normalize(self.inputs_preproc[-1])
    label_norm = _normalize(hr_image)
    with tf.variable_scope(self.name):
        shallow_feature = self.prelu_conv2d(inputs_norm, self.F, 9)
        x = [shallow_feature]
        for _ in range(self.g_layers):
            x.append(self.resblock(
                x[-1], self.F, 3, activation='prelu', use_batchnorm=True))
        # Long-range skips: add a 1x1 projection of each intermediate
        # residual output to the final one. Rebinding (rather than an
        # in-place add) makes it explicit that x[-1] is left untouched.
        bottleneck = x[-1]
        for t in x[1:-1]:
            bottleneck = bottleneck + self.conv2d(t, self.F, 1)
        sr = self.upscale(bottleneck, direct_output=False, activator=prelu)
        sr = self.tanh_conv2d(sr, self.channel, 9)
        self.outputs.append(_denormalize(sr))
    # The image produced just above (not whatever sits at outputs[0]).
    sr_image = self.outputs[-1]

    disc_real = self.D(label_norm)
    disc_fake = self.D(sr)
    # Index 0: features of the SR image, index 1: features of the HR label.
    vgg_features = [self.vgg(sr_image, self.vgg_layer),
                    self.vgg(hr_image, self.vgg_layer)]
    vgg_fake = self.DF(vgg_features[0])
    vgg_real = self.DF(vgg_features[1])

    with tf.name_scope('Loss'):
        loss_gen, loss_disc = loss_bce_gan(disc_real, disc_fake)
        vgg_loss_g, vgg_loss_d = loss_bce_gan(vgg_real, vgg_fake)
        # Pixel MSE in normalised space; only the opt_i step below uses it.
        mse = tf.losses.mean_squared_error(label_norm, sr)
        loss_d = loss_disc + vgg_loss_d
        loss_g = loss_gen + vgg_loss_g
        loss_vgg = tf.losses.mean_squared_error(*vgg_features)
        loss = tf.stack([loss_g, loss_vgg])
        loss = tf.reduce_sum(loss * [self.gan_weight, self.vgg_weight])

        var_g = tf.trainable_variables(self.name)
        var_d = tf.trainable_variables('Critic')
        var_df = tf.trainable_variables('DF')
        # Ops in the UPDATE_OPS collection (e.g. moving statistics of the
        # batch-normalised resblocks) run before every optimizer step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt_i = tf.train.AdamOptimizer(self.learning_rate).minimize(
                mse, self.global_steps, var_list=var_g)
            opt_g = tf.train.AdamOptimizer(self.learning_rate).minimize(
                loss, self.global_steps, var_list=var_g)
            opt_d = tf.train.AdamOptimizer(self.learning_rate).minimize(
                loss_d, var_list=var_d + var_df)
        self.loss = [opt_i, opt_d, opt_g]

    self.train_metric['g_loss'] = loss_g
    self.train_metric['d_loss'] = loss_d
    self.train_metric['vgg_loss'] = loss_vgg
    self.train_metric['loss'] = loss
    self.metrics['psnr'] = tf.reduce_mean(
        tf.image.psnr(hr_image, sr_image, 255))
    self.metrics['ssim'] = tf.reduce_mean(
        tf.image.ssim(hr_image, sr_image, 255))
def build_loss(self):
    """Create BCE GAN losses, add a gradient penalty to the critic side, register both.

    ``self.d_outputs`` is unpacked into ``loss_bce_gan``. ``self.g_outputs``
    is unpacked, followed by ``self.D``, into ``gradient_penalty`` with a
    fixed weight ``lamb=10``. The generator loss and the penalised
    discriminator loss are handed to ``self._build_loss``.
    """
    with tf.name_scope('Loss'):
        gen_loss, disc_loss = loss_bce_gan(*self.d_outputs)
        # Re-open the model's variable scope with reuse=True so that any
        # variable lookups made while evaluating self.D for the penalty reuse
        # existing variables rather than creating new ones (presumably D was
        # first built under self.name - TODO confirm).
        with tf.variable_scope(self.name, reuse=True):
            penalty = gradient_penalty(*self.g_outputs, self.D, lamb=10)
        self._build_loss(gen_loss, disc_loss + penalty)
def build_graph(self):
    """Assemble the SRGAN generator, image critic, losses and training ops.

    Generator: a PReLU 9x9 conv with 64 filters, ``self.g_layers`` residual
    blocks (PReLU, batch-norm), a BN 3x3 conv whose output is added back to
    the first conv (global skip), a 256-filter 3x3 conv, ``self.upscale`` and
    a final tanh 9x9 conv. The denormalised result is appended to
    ``self.outputs``.

    Generator objective: ``mse_weight * pixel MSE + gan_weight * adversarial
    loss + regularisation losses``, plus a ``vgg_weight``-weighted VGG
    feature MSE when ``self.use_vgg`` is truthy.

    Attributes written:
        self.loss: ``[opt_i, opt_d, opt_g]``. These are a pixel-MSE-only Adam
            step and a full-objective Adam step (both over the ``self.name``
            variables and both advancing ``self.global_steps``), and a critic
            Adam step over the ``Critic`` variables.
        self.train_metric: ``g_loss``, ``d_loss``, ``loss``.
        self.metrics: ``mse`` plus mean ``psnr``/``ssim`` with max_val=255.
    """
    super(SRGAN, self).build_graph()
    with tf.variable_scope(self.name):
        lr_norm = self._normalize(self.inputs_preproc[-1])
        entry = self.prelu_conv2d(lr_norm, 64, 9)
        trunk = entry
        for _ in range(self.g_layers):
            trunk = self.resblock(
                trunk, 64, 3, activation='prelu', use_batchnorm=True)
        # Global residual connection spanning the whole residual trunk.
        trunk = self.bn_conv2d(trunk, 64, 3) + entry
        trunk = self.conv2d(trunk, 256, 3)
        sr_norm = self.tanh_conv2d(
            self.upscale(trunk, direct_output=False, activator=prelu),
            self.channel, 9)
        self.outputs.append(self._denormalize(sr_norm))

    hr_norm = self._normalize(self.label[-1])
    score_real = self.D(hr_norm)
    score_fake = self.D(sr_norm)

    with tf.name_scope('Loss'):
        adv_g, adv_d = loss_bce_gan(score_real, score_fake)
        pixel_mse = tf.losses.mean_squared_error(hr_norm, sr_norm)
        terms = [pixel_mse * self.mse_weight, adv_g * self.gan_weight]
        terms += tf.losses.get_regularization_losses()
        total = tf.add_n(terms)
        if self.use_vgg:
            feat_hr = self.vgg(self.label[-1], self.vgg_layer)
            feat_sr = self.vgg(self.outputs[-1], self.vgg_layer)
            # The weights argument scales the VGG feature MSE by vgg_weight.
            total += tf.losses.mean_squared_error(
                labels=feat_hr, predictions=feat_sr, weights=self.vgg_weight)

        gen_vars = tf.trainable_variables(self.name)
        critic_vars = tf.trainable_variables('Critic')
        # Every train op waits on the UPDATE_OPS collection (e.g. moving
        # statistics of the batch-normalised layers).
        with tf.control_dependencies(
                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            pretrain_op = tf.train.AdamOptimizer(self.learning_rate).minimize(
                pixel_mse, global_step=self.global_steps, var_list=gen_vars)
            gen_op = tf.train.AdamOptimizer(self.learning_rate).minimize(
                total, global_step=self.global_steps, var_list=gen_vars)
            critic_op = tf.train.AdamOptimizer(self.learning_rate).minimize(
                adv_d, var_list=critic_vars)
        self.loss = [pretrain_op, critic_op, gen_op]

    self.train_metric['g_loss'] = adv_g
    self.train_metric['d_loss'] = adv_d
    self.train_metric['loss'] = total
    self.metrics['mse'] = pixel_mse
    self.metrics['psnr'] = tf.reduce_mean(
        tf.image.psnr(self.label[-1], self.outputs[-1], 255))
    self.metrics['ssim'] = tf.reduce_mean(
        tf.image.ssim(self.label[-1], self.outputs[-1], 255))
def build_loss(self):
    """Create plain BCE GAN losses from ``self.d_outputs`` and register them.

    Nothing is added on top: the generator and discriminator losses returned
    by ``loss_bce_gan`` go straight to ``self._build_loss``.
    """
    with tf.name_scope('Loss'):
        generator_loss, discriminator_loss = loss_bce_gan(*self.d_outputs)
        self._build_loss(generator_loss, discriminator_loss)