def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
    """Build the inference graph and return the completed image batch.

    Args:
        FLAGS: configuration object. This method reads ``guided``,
            ``edge_threshold``, ``padding`` and ``viz_max_out`` and forwards
            the object to the mask helpers.
        batch_data: image batch, or an ``(images, edges)`` pair when
            ``FLAGS.guided`` is set.
        bbox: forwarded unchanged to ``bbox2mask``.
            NOTE(review): the default ``None`` is passed straight through;
            confirm ``bbox2mask`` accepts it or always pass a bbox.
        name: prefix for the image summary name.

    Returns:
        The second-stage prediction inside the mask combined with the
        unmodified input outside the mask.
    """
    if FLAGS.guided:
        batch_data, edge = batch_data
        # Keep only the first channel and binarise it against the threshold.
        edge = edge[:, :, :, 0:1] / 255.
        edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)

    # Mask helpers take FLAGS first, matching their other call sites in this
    # file. The final mask is the logical OR of both masks as float 0/1.
    regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
    irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
    mask = tf.cast(
        tf.logical_or(
            tf.cast(irregular_mask, tf.bool),
            tf.cast(regular_mask, tf.bool),
        ),
        tf.float32
    )

    # assumes pixel values in [0, 255], mapped to [-1, 1] -- TODO confirm
    batch_pos = batch_data / 127.5 - 1.
    # Zero out every pixel where mask == 1.
    batch_incomplete = batch_pos*(1.-mask)
    if FLAGS.guided:
        # Only edges inside the masked region are given to the network.
        edge = edge * mask
        xin = tf.concat([batch_incomplete, edge], axis=3)
    else:
        xin = batch_incomplete

    # inpaint (variables are reused, network runs in non-training mode)
    x1, x2, offset_flow = self.build_inpaint_net(
        xin, mask, reuse=True,
        training=False, padding=FLAGS.padding)
    batch_predicted = x2

    # apply mask and reconstruct
    batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)

    # global image visualization: tensors are concatenated along axis 2
    if FLAGS.guided:
        viz_img = [
            batch_pos,
            batch_incomplete + edge,
            batch_complete]
    else:
        viz_img = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4,
                   func=tf.image.resize_bilinear))
    images_summary(
        tf.concat(viz_img, axis=2),
        name+'_raw_incomplete_complete', FLAGS.viz_max_out)
    return batch_complete
def train_step(iter_idx):
    """Run one optimisation step for the generator and the discriminator.

    Reads ``train_iter``, ``FLAGS``, ``G``, ``D``, ``g_optimizer``,
    ``d_optimizer`` and ``train_summary_writer`` from the enclosing scope.

    Args:
        iter_idx: current iteration index; used as the summary step and to
            decide whether summaries are written on this call.

    Returns:
        None. The step updates ``G`` and ``D`` through their optimizers and
        may write TensorBoard summaries.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        batch_pos = next(train_iter)

        # Logical OR of the bbox mask and the brush-stroke mask.
        bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32)  # mask is 0-1

        # Zero out the pixels where mask == 1 and feed the result to G.
        batch_incomplete = batch_pos * (1. - mask)
        xin = batch_incomplete
        x1, x2, offset_flow = G(xin, mask, training=True)
        batch_predicted = x2

        # Prediction inside the mask, original input outside it.
        batch_complete = (
            batch_predicted * mask + batch_incomplete * (1. - mask))

        losses = {}
        # L1 reconstruction loss on both generator outputs.
        losses['ae_loss'] = FLAGS['l1_loss_alpha'] * \
            tf.reduce_mean(input_tensor=tf.abs(batch_pos - x1))
        losses['ae_loss'] += FLAGS['l1_loss_alpha'] * \
            tf.reduce_mean(input_tensor=tf.abs(batch_pos - x2))

        # Discriminator input: real and completed images stacked on the
        # batch axis, with the tiled mask appended on the channel axis.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        batch_pos_neg = tf.concat([
            batch_pos_neg,
            tf.tile(mask, [FLAGS['batch_size'] * 2, 1, 1, 1])
        ], axis=3)

        # SNGAN
        pos_neg = D(batch_pos_neg, training=True)
        pos, neg = tf.split(pos_neg, 2)
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        losses['d_loss'] = d_loss
        # Generator objective: weighted adversarial term plus L1 term.
        losses['g_loss'] = FLAGS['gan_loss_alpha'] * g_loss
        losses['g_loss'] += losses['ae_loss']

    gradients_of_generator = gen_tape.gradient(
        losses['g_loss'], G.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        losses['d_loss'], D.trainable_variables)

    g_optimizer.apply_gradients(
        zip(gradients_of_generator, G.trainable_variables))
    d_optimizer.apply_gradients(
        zip(gradients_of_discriminator, D.trainable_variables))

    # tensorboard
    if iter_idx > 0 and iter_idx % FLAGS['viz_max_out'] == 0:
        with train_summary_writer.as_default():
            tf.summary.scalar('g_loss', losses['g_loss'], step=iter_idx)
            tf.summary.scalar('d_loss', losses['d_loss'], step=iter_idx)
            # Slice instead of reshaping to a hard-coded 256x256x3 so the
            # first completed image is logged intact at any resolution.
            img = batch_complete[:1]
            tf.summary.image("train", img, step=iter_idx)
def build_graph_with_losses(
        self, FLAGS, batch_data, training=True, summary=False,
        reuse=False):
    """Build the training graph and collect generator/discriminator losses.

    Args:
        FLAGS: configuration object. This method reads ``guided``,
            ``edge_threshold``, ``padding``, ``l1_loss_alpha``,
            ``viz_max_out``, ``gan_with_mask``, ``batch_size``, ``gan``,
            ``gan_loss_alpha`` and ``ae_loss``.
        batch_data: image batch, or an ``(images, edges)`` pair when
            ``FLAGS.guided`` is set.
        training: forwarded to the inpaint network and the discriminator.
        summary: when true, scalar, image and gradient summaries are added.
        reuse: forwarded to the inpaint network and the discriminator.

    Returns:
        ``(g_vars, d_vars, losses)``: the trainable variables collected
        under the ``'inpaint_net'`` and ``'discriminator'`` scopes, and a
        dict holding ``'ae_loss'``, ``'g_loss'`` and ``'d_loss'``.

    Raises:
        NotImplementedError: if ``FLAGS.gan`` is not ``'sngan'``.
    """
    guided = FLAGS.guided
    if guided:
        batch_data, edge = batch_data
        # First channel only, binarised against the configured threshold.
        first_channel = edge[:, :, :, 0:1] / 255.
        edge = tf.cast(first_channel > FLAGS.edge_threshold, tf.float32)

    # assumes pixel values in [0, 255], mapped to [-1, 1] -- TODO confirm
    real = batch_data / 127.5 - 1.

    # Mask generation; a value of 1 marks a masked point.
    bbox = random_bbox(FLAGS)
    box_mask = bbox2mask(FLAGS, bbox, name='mask_c')
    stroke_mask = brush_stroke_mask(FLAGS, name='mask_c')
    stroke_bool = tf.cast(stroke_mask, tf.bool)
    box_bool = tf.cast(box_mask, tf.bool)
    mask = tf.cast(tf.logical_or(stroke_bool, box_bool), tf.float32)

    incomplete = real*(1.-mask)
    if guided:
        # Only edges inside the masked region are given to the network.
        edge = edge * mask
        net_input = tf.concat([incomplete, edge], axis=3)
    else:
        net_input = incomplete

    x1, x2, offset_flow = self.build_inpaint_net(
        net_input, mask, reuse=reuse, training=training,
        padding=FLAGS.padding)
    predicted = x2

    # Prediction inside the mask, original input outside it.
    complete = predicted*mask + incomplete*(1.-mask)

    # L1 reconstruction loss on both network outputs.
    l1_first = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(real - x1))
    l1_second = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(real - x2))
    losses = {'ae_loss': l1_first + l1_second}

    if summary:
        scalar_summary('losses/ae_loss', losses['ae_loss'])
        middle = incomplete + edge if guided else incomplete
        panels = [real, middle, complete]
        if offset_flow is not None:
            panels.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_bilinear))
        images_summary(
            tf.concat(panels, axis=2),
            'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

    # Discriminator input: real and completed images stacked on the batch
    # axis, with optional conditioning appended on the channel axis.
    disc_input = tf.concat([real, complete], axis=0)
    if FLAGS.gan_with_mask:
        tiled_mask = tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])
        disc_input = tf.concat([disc_input, tiled_mask], axis=3)
    if guided:
        # conditional GANs
        tiled_edge = tf.tile(edge, [2, 1, 1, 1])
        disc_input = tf.concat([disc_input, tiled_edge], axis=3)

    # Only the 'sngan' option (hinge loss) is implemented.
    if FLAGS.gan != 'sngan':
        raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
    scores = self.build_gan_discriminator(
        disc_input, training=training, reuse=reuse)
    pos, neg = tf.split(scores, 2)
    losses['g_loss'], losses['d_loss'] = gan_hinge_loss(pos, neg)

    if summary:
        # Gradient magnitude of each loss w.r.t. the predicted image.
        gradients_summary(losses['g_loss'], predicted, name='g_loss')
        gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
        gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')

    # Generator objective: weighted adversarial term, plus L1 if enabled.
    total_g = FLAGS.gan_loss_alpha * losses['g_loss']
    if FLAGS.ae_loss:
        total_g += losses['ae_loss']
    losses['g_loss'] = total_g

    trainable = tf.GraphKeys.TRAINABLE_VARIABLES
    g_vars = tf.get_collection(trainable, 'inpaint_net')
    d_vars = tf.get_collection(trainable, 'discriminator')
    return g_vars, d_vars, losses