def G_paper(self, z, last_resolution, current_resolution, name='G_paper', reuse=False):
    """Build the generator graph up to ``current_resolution``.

    The latent ``z`` is reshaped to ``[-1, 1, 1, 512]``, expanded by a
    transposed conv (kernel 4, stride 4) and a 3x3 conv, run through the
    blocks logged as "Restore block", and finally through one extra block
    logged as "Add block".  Unless ``current_resolution`` is 64, the output of
    the added block is blended with the resized RGB output of the previous
    block using the factor returned by ``progressive_kt``.

    Args:
        z: Latent input; must be reshapeable to ``[-1, 1, 1, 512]``.
        last_resolution: Previous resolution, one of 4, 8, ..., 512.
        current_resolution: Must equal ``last_resolution * 2``.
        name: Variable scope name used for all three scope sections.
        reuse: Reuse flag forwarded to ``tf.variable_scope``.

    Returns:
        tensor of image
    """
    assert last_resolution in [4, 8, 16, 32, 64, 128, 256, 512]
    assert current_resolution == last_resolution * 2
    # Channel count for a block at input resolution x, capped at MAX_C.
    get_cnum = lambda x: int(min(MAX_C, 2 ** (13 - np.log2(x))))
    x = z
    # with tf.variable_scope(name, reuse=(reuse or current_resolution!=8)):
    with tf.variable_scope(name, reuse=(reuse or False)):
        # [-1, 4, 4, 512]
        x = tf.reshape(x, [-1, 1, 1, 512])
        x = tf.layers.conv2d_transpose(
            x, 512, 4, 4, padding="same", activation=act, name='deconv_in')
        x = tf.layers.conv2d(
            x, 512, 3, padding='same', activation=act, name='conv_in')
    block_resolution = 4
    # with tf.variable_scope(name, reuse=True):
    with tf.variable_scope(name, reuse=(reuse or False)):
        # One iteration per block below the one added further down.
        for i in range(int(np.log2(current_resolution) - 3)):
            cnum = get_cnum(block_resolution)
            logger.info('Restore block, input resolution: {}, cnum: {}, '
                        'output resolution: {}.'.format(
                            block_resolution, cnum, block_resolution*2))
            x = resize(x, 2)
            block_resolution *= 2
            x = nn_block(x, cnum, name='block%s' % block_resolution)
        # NOTE(review): the previous-resolution RGB output is skipped only
        # when current_resolution == 64 -- confirm this constant is intended.
        if current_resolution != 64:
            last_x = tf.layers.conv2d(
                x, 3, 1, padding='same', name='%s_out' % block_resolution)
    with tf.variable_scope(name, reuse=(reuse or False)):
        cnum = get_cnum(block_resolution)
        logger.info('Add block, input resolution: {}, cnum: {}, '
                    'output resolution: {}.'.format(
                        block_resolution, cnum, block_resolution*2))
        x = resize(x, 2)
        block_resolution *= 2
        x = nn_block(x, cnum, name='block%s' % block_resolution)
        x = tf.layers.conv2d(
            x, 3, 1, padding='same', name='%s_out' % block_resolution)
        kt = progressive_kt('%s_kt' % block_resolution)
        # Blend the new block's output with the resized previous output.
        if current_resolution != 64:
            x = kt * x + (1. - kt) * resize(last_x, 2)
    return x
def gated_deconv(x, cnum, name='upsample', padding='SAME', training=True):
    """Nearest-neighbour resize followed by a 3x3, stride-1 gated conv.

    Args:
        x: Input feature tensor.
        cnum: Channel number passed to ``gated_conv``.
        name: Variable scope name; the conv layer is named ``name + '_conv'``.
        padding: Padding mode forwarded to ``gated_conv``.
        training: Training flag forwarded to ``gated_conv``.

    Returns:
        The output of ``gated_conv`` applied to the resized input.
    """
    conv_name = name + '_conv'
    with tf.variable_scope(name):
        # resize() is called with its default scale.
        upsampled = resize(x, func=tf.image.resize_nearest_neighbor)
        out = gated_conv(
            upsampled, cnum, 3, 1, name=conv_name,
            padding=padding, training=training)
    return out
def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
    """Build the inference graph and return the completed image batch.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object. ``MAX_DELTA_HEIGHT`` and
            ``MAX_DELTA_WIDTH`` are overwritten with 0 on it.
        bbox: Hole bounding box; when None one is drawn with
            ``random_bbox(config)``.
        name: Prefix for the mask op name (``name + 'mask_c'``).

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    config.MAX_DELTA_HEIGHT = 0
    config.MAX_DELTA_WIDTH = 0
    if bbox is None:
        bbox = random_bbox(config)
    mask = bbox2mask(bbox, config, name=name + 'mask_c')
    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos * (1. - mask)
    # inpaint; the offset flow was only used by a visualization list that
    # nothing consumed, so it is discarded here.
    x1, x2, _ = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=True,
        training=False, padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
    else:
        batch_predicted = x2
    # apply mask and reconstruct
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    return batch_complete
def gen_deconv(x, cnum, ksize, stride, rate, name='upsample', padding='same', training=True):
    """Define deconv for generator.

    The deconv is defined to be a x2 resize_nearest_neighbor operation with
    additional gen_conv operation.

    Args:
        x: Input.
        cnum: Channel number.
        ksize: Must be 3; accepted only to mirror the gen_conv signature.
        stride: Must be 1; accepted only to mirror the gen_conv signature.
        rate: Must be None; accepted only to mirror the gen_conv signature.
        name: Name of layers.
        padding: Padding mode forwarded to gen_conv.
        training: If current graph is for training or inference, used for bn.

    Returns:
        tuple: ``(x, layer)`` exactly as returned by ``gen_conv``.

    Raises:
        AssertionError: If ``ksize``, ``stride`` or ``rate`` has any other
            value than the fixed one listed above.
    """
    # Just using these as options to keep the signature the same as gen_conv.
    assert ksize == 3
    assert stride == 1
    # Identity check: ``== None`` can be fooled by objects overriding __eq__.
    assert rate is None
    with tf.variable_scope(name):
        x = resize(x, func=tf.image.resize_nearest_neighbor)
        x, layer = gen_conv(
            x, cnum, 3, 1, name=name+'_conv',
            padding=padding, training=training)
    return x, layer
def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
    """Upsample with nearest-neighbour resize, then apply ``gen_conv``.

    Args:
        x: Input.
        cnum: Channel number.
        name: Variable scope name; the conv is named ``name + '_conv'``.
        padding: Padding mode forwarded to ``gen_conv``.
        training: If current graph is for training or inference, used for bn.

    Returns:
        tf.Tensor: output
    """
    conv_name = name + '_conv'
    with tf.variable_scope(name):
        upsampled = resize(x, func=tf.image.resize_nearest_neighbor)
        result = gen_conv(
            upsampled, cnum, 3, 1, name=conv_name,
            padding=padding, training=training)
    return result
def build_infer_graph(self, batch_data, config, name='val'):
    """Build the inference graph and return the completed image batch.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object. ``MAX_DELTA_HEIGHT`` and
            ``MAX_DELTA_WIDTH`` are overwritten with 0 on it.
        name: Prefix for the mask op name and the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    config.MAX_DELTA_HEIGHT = 0
    config.MAX_DELTA_WIDTH = 0
    mask = bbox2mask(config, name=name + 'mask_c')
    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos * (1. - mask)
    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=True,
        training=False, padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    # apply mask and reconstruct
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    # Panels are concatenated along axis 2 for the image summary.
    viz_img = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4,
                   func=tf.image.resize_nearest_neighbor))
    images_summary(
        tf.concat(viz_img, axis=2),
        name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
    return batch_complete
def build_infer_graph(self, batch_data, batch_mask, batch_guide, config, name='val'):
    """Build the validation graph for externally supplied masks and guides.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        batch_mask: Mask applied to the images.
        batch_guide: Guide tensor forwarded to ``build_inpaint_net``.
        config: Configuration object. ``MAX_DELTA_HEIGHT`` and
            ``MAX_DELTA_WIDTH`` are overwritten with 0 on it.
        name: Prefix of the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    config.MAX_DELTA_HEIGHT = 0
    config.MAX_DELTA_WIDTH = 0
    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos*(1.-batch_mask)
    # inpaint
    coarse, refined, offset_flow = self.build_inpaint_net(
        batch_incomplete, batch_mask, batch_guide, config,
        reuse=True, training=False, padding=config.PADDING)
    if not config.PRETRAIN_COARSE_NETWORK:
        prediction = refined
        logger.info('Set batch_predicted to x2.')
    else:
        prediction = coarse
        logger.info('Set batch_predicted to x1.')
    # apply mask and reconstruct
    batch_complete = prediction*batch_mask + batch_incomplete*(1.-batch_mask)
    # global image visualization
    panels = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        flow_panel = resize(
            offset_flow, scale=4, func=tf.image.resize_nearest_neighbor)
        panels.append(flow_panel)
    summary_name = name+'_raw_incomplete_complete'
    images_summary(tf.concat(panels, axis=2), summary_name, config.VIZ_MAX_OUT)
    return batch_complete
def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
    """Build the inference graph for a network that returns one prediction.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object. ``MAX_DELTA_HEIGHT`` and
            ``MAX_DELTA_WIDTH`` are overwritten with 0 on it.
        bbox: Hole bounding box; when None one is drawn with
            ``random_bbox(config)``.
        name: Prefix for the mask op name and the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    config.MAX_DELTA_HEIGHT = 0
    config.MAX_DELTA_WIDTH = 0
    if bbox is None:
        bbox = random_bbox(config)
    mask = bbox2mask(bbox, config, name=name+'mask_c')
    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos*(1.-mask)
    # inpaint; this variant of build_inpaint_net returns two values.
    x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=True,
        training=False, padding=config.PADDING)
    batch_predicted = x2
    logger.info('Set batch_predicted to x2.')
    # apply mask and reconstruct
    batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
    # global image visualization
    viz_img = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4,
                   func=tf.image.resize_nearest_neighbor))
    images_summary(
        tf.concat(viz_img, axis=2),
        name+'_raw_incomplete_complete', config.VIZ_MAX_OUT)
    return batch_complete
def downsampleMask_graph(self, FLAGS, batch_data, mask, downsample_rate, name='val'):
    """Build an inference graph for a caller-supplied mask.

    Unlike the fixed ``/ 127.5 - 1.`` rescaling used elsewhere, the input is
    divided by a mean computed from the unmasked data and scaled by
    ``downsample_rate ** 2``.

    Args:
        FLAGS: Options object; ``guided``, ``edge_threshold``, ``padding`` and
            ``viz_max_out`` are read.
        batch_data: Images, or an ``(images, edge)`` pair when
            ``FLAGS.guided`` is set.
        mask: Mask multiplied into the images.
        downsample_rate: Factor whose square scales the computed mean.
        name: Prefix of the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    if FLAGS.guided:
        batch_data, edge = batch_data
        # Keep the first channel, rescale by 255 and binarize by threshold.
        edge = edge[:, :, :, 0:1] / 255.
        edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
    # mask = brush_stroke_mask(name='mask_c')
    # regular_mask = bbox2mask(bbox, name='mask_c')
    # irregular_mask = brush_stroke_mask(name='mask_c')
    # mask = tf.cast(
    #     tf.logical_or(
    #         tf.cast(irregular_mask, tf.bool),
    #         tf.cast(regular_mask, tf.bool),
    #     ),
    #     tf.float32
    # )
    # In fact the input image is still being normalized; this needs to change.
    # modified
    # batch_pos = batch_data / 127.5 - 1.
    # NOTE(review): the inner reduce_mean collapses axis 1 and the outer one
    # collapses axis 2 of the already-reduced tensor -- confirm these are the
    # intended axes and that `mean` broadcasts against batch_data as expected.
    mean = tf.reduce_mean(tf.reduce_mean(batch_data * (1. - mask), 1), 2) * downsample_rate * downsample_rate
    batch_pos = batch_data / mean - 1
    batch_incomplete = batch_pos * (1. - mask)
    if FLAGS.guided:
        edge = edge * mask
        xin = tf.concat([batch_incomplete, edge], axis=3)
    else:
        xin = batch_incomplete
    # inpaint
    x1, x2, offset_flow = self.build_inpaint_net(xin, mask, reuse=True, training=False, padding=FLAGS.padding)
    batch_predicted = x2
    # apply mask and reconstruct
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    # global image visualization
    if FLAGS.guided:
        viz_img = [batch_pos, batch_incomplete + edge, batch_complete]
    else:
        viz_img = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4, func=tf.image.resize_bilinear))
    images_summary(tf.concat(viz_img, axis=2), name + '_raw_incomplete_complete', FLAGS.viz_max_out)
    return batch_complete
def build_infer_graph(self, batch_data, config, bbox=None, name='val', exclusionmask=None):
    """Build the inference graph with a bordered, multi-flow visualization.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object. ``MAX_DELTA_HEIGHT`` and
            ``MAX_DELTA_WIDTH`` are overwritten with 0 on it.
        bbox: Hole bounding box; when None one is drawn with
            ``random_bbox(config)``.
        name: Prefix for the mask op name and the image summary name.
        exclusionmask: Accepted but not used by this function.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    config.MAX_DELTA_HEIGHT = 0
    config.MAX_DELTA_WIDTH = 0
    if bbox is None:
        bbox = random_bbox(config)
    mask = bbox2mask(bbox, config, name=name + 'mask_c')
    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos * (1. - mask)
    # inpaint
    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=True,
        training=False, padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    # apply mask and reconstruct
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    # global image visualization: a zero strip of size 5 along axis 2
    # separates the panels.
    img_size = [dim for dim in batch_incomplete.shape]
    img_size[2] = 5
    border = tf.zeros(tf.TensorShape(img_size))
    viz_img = [
        border, batch_pos, border, batch_incomplete, border,
        batch_complete, border
    ]
    if not config.PRETRAIN_COARSE_NETWORK:
        batch_complete_coarse = x1 * mask + batch_incomplete * (1. - mask)
        viz_img.append(batch_complete_coarse)
    if offset_flow is not None:
        # offset_flow is iterated here; each successive flow is resized with
        # half the scale of the previous one.
        scale = 2 << len(offset_flow)
        for flow in offset_flow:
            viz_img.append(
                resize(flow, scale=scale,
                       func=tf.image.resize_nearest_neighbor))
            viz_img.append(border)
            scale >>= 1
    images_summary(
        tf.concat(viz_img, axis=2),
        name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
    return batch_complete
def resize_mask_like(mask, x):
    """Resize mask by a fixed factor of 1/4 using dynamic shapes.

    A shape-matching nearest-neighbour resize used to be computed first, but
    its result was overwritten before use; only the 1/4-scale resize ever
    reached the caller, so the dead call has been dropped.

    Args:
        mask: Original mask.
        x: Unused; kept so existing callers keep working.

    Returns:
        tf.Tensor: resized mask
    """
    # Jaya's Code
    mask_resize = resize(mask, scale=1./4, dynamic=True)
    return mask_resize
def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
    """Build the inference graph, optionally with the gradient branch.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object. ``MAX_DELTA_HEIGHT`` and
            ``MAX_DELTA_WIDTH`` are overwritten with 0 on it;
            ``ADD_GRADIENT_BRANCH`` selects how many outputs the network
            returns.
        bbox: Hole bounding box; when None one is drawn with
            ``random_bbox(config)``.
        name: Prefix for the mask op name and the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    config.MAX_DELTA_HEIGHT = 0
    config.MAX_DELTA_WIDTH = 0
    if bbox is None:
        bbox = random_bbox(config)
    mask = bbox2mask(bbox, config, name=name + 'mask_c')
    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos * (1. - mask)
    # inpaint: the call is the same in both modes, only the arity differs.
    net_outputs = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=True,
        training=False, padding=config.PADDING)
    if not config.ADD_GRADIENT_BRANCH:
        x1, x2, offset_flow = net_outputs
    else:
        # The fourth output is not needed at inference time.
        x1, x2, fake_gb, _, offset_flow = net_outputs
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    # apply mask and reconstruct
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    # global image visualization
    viz_img = [batch_pos, batch_incomplete, batch_complete]
    if config.ADD_GRADIENT_BRANCH:
        gb_gt = self.get_grad.get_gradient_tf(batch_pos)
        batch_complete_gb = fake_gb * mask + gb_gt * (1. - mask)
        viz_img.extend([gb_gt, batch_complete_gb])
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4,
                   func=tf.image.resize_nearest_neighbor))
    images_summary(
        tf.concat(viz_img, axis=2),
        name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
    return batch_complete
def resize_mask_like(mask, x):
    """Resize ``mask`` to the static size of axes 1 and 2 of ``x``.

    Args:
        mask: Original mask.
        x: Tensor whose static shape supplies the target size.

    Returns:
        tf.Tensor: resized mask
    """
    target_size = x.get_shape().as_list()[1:3]
    return resize(
        mask,
        to_shape=target_size,
        func=tf.compat.v1.image.resize_nearest_neighbor)
def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
    """Build the inference graph with a combined box + brush-stroke mask.

    Args:
        FLAGS: Options object; ``guided``, ``edge_threshold``, ``padding`` and
            ``viz_max_out`` are read.
        batch_data: Images, or an ``(images, edge)`` pair when
            ``FLAGS.guided`` is set.
        bbox: Bounding box forwarded to ``bbox2mask``.
        name: Prefix of the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    if FLAGS.guided:
        batch_data, edge = batch_data
        # Keep the first channel, rescale by 255 and binarize by threshold.
        edge = edge[:, :, :, 0:1] / 255.
        edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
    # NOTE(review): bbox defaults to None and is passed straight through --
    # confirm bbox2mask accepts that.
    regular_mask = bbox2mask(bbox, name='mask_c')
    irregular_mask = brush_stroke_mask(name='mask_c')
    # A pixel is masked when either the box or the brush stroke covers it.
    # (A third brush-stroke mask used to be generated before these two and
    # was overwritten unread; it has been removed.)
    mask = tf.cast(
        tf.logical_or(
            tf.cast(irregular_mask, tf.bool),
            tf.cast(regular_mask, tf.bool),
        ),
        tf.float32
    )

    batch_pos = batch_data / 127.5 - 1.
    batch_incomplete = batch_pos*(1.-mask)
    if FLAGS.guided:
        edge = edge * mask
        xin = tf.concat([batch_incomplete, edge], axis=3)
    else:
        xin = batch_incomplete
    # inpaint
    x1, x2, offset_flow = self.build_inpaint_net(
        xin, mask, reuse=True,
        training=False, padding=FLAGS.padding)
    batch_predicted = x2
    # apply mask and reconstruct
    batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
    # global image visualization
    if FLAGS.guided:
        viz_img = [
            batch_pos,
            batch_incomplete + edge,
            batch_complete]
    else:
        viz_img = [batch_pos, batch_incomplete, batch_complete]
    if offset_flow is not None:
        viz_img.append(
            resize(offset_flow, scale=4,
                   func=tf.image.resize_bilinear))
    images_summary(
        tf.concat(viz_img, axis=2),
        name+'_raw_incomplete_complete', FLAGS.viz_max_out)
    return batch_complete
def build_infer_graph(self, batch_data, batch_mask, batch_guide, config, name='val'):
    """Build the validation graph, generating a mask when none is given.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        batch_mask: Mask to apply; when None, ``random_ff_mask(parts=8)`` is
            used instead.
        batch_guide: Overwritten with an all-ones single-channel tensor
            before use.
        config: Configuration object (``PADDING``,
            ``PRETRAIN_COARSE_NETWORK`` and ``VIZ_MAX_OUT`` are read).
        name: Prefix of the image summary name.

    Returns:
        The prediction inside the mask combined with the input outside it.
    """
    batch_pos = batch_data / 127.5 - 1.
    if batch_mask is None:
        batch_mask = random_ff_mask(parts=8)
    batch_incomplete = batch_pos * (1. - batch_mask)
    # First channel of an all-ones tensor shaped like the mask.
    single_channel_ones = tf.ones_like(batch_mask)[:, :, :, 0:1]
    batch_mask = single_channel_ones * batch_mask
    batch_guide = single_channel_ones
    coarse, refined, offset_flow = self.build_inpaint_net(
        batch_incomplete, batch_mask, batch_guide, config,
        reuse=True, training=False, padding=config.PADDING)
    if not config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = refined
        logger.info('Set batch_predicted to x2.')
    else:
        batch_predicted = coarse
        logger.info('Set batch_predicted to x1.')
    # apply mask and complete image
    batch_complete = batch_predicted * batch_mask + batch_incomplete * (
        1. - batch_mask)
    # global image visualization
    panels = [batch_pos, batch_incomplete, batch_predicted, batch_complete]
    if offset_flow is not None:
        flow_panel = resize(
            offset_flow, scale=4, func=tf.image.resize_nearest_neighbor)
        panels.append(flow_panel)
    summary_name = name + '_raw_incomplete_complete'
    images_summary(tf.concat(panels, axis=2), summary_name, config.VIZ_MAX_OUT)
    return batch_complete
def resize_mask_like(mask, x):
    """Resize ``mask`` to the static size of axes 1 and 2 of ``x``.

    Prints a banner before resizing and the resulting shape dims afterwards.

    Args:
        mask: Original mask.
        x: Tensor whose static shape supplies the target size.

    Returns:
        tf.Tensor: resized mask
    """
    print('*******************\n***************resize_mask_like***************\n************************')
    target_size = x.get_shape().as_list()[1:3]
    resized = resize(
        mask, to_shape=target_size,
        func=tf.image.resize_nearest_neighbor)
    print("resized mask_s", (resized.shape.dims))
    return resized
def build_graph_with_losses(
        self, FLAGS, batch_data, training=True, summary=False,
        reuse=False):
    """Build the training graph and its generator/discriminator losses.

    Args:
        FLAGS: Options object; ``guided``, ``edge_threshold``, ``padding``,
            ``l1_loss_alpha``, ``viz_max_out``, ``gan_with_mask``,
            ``batch_size``, ``gan``, ``gan_loss_alpha`` and ``ae_loss`` are
            read.
        batch_data: Images, or an ``(images, edge)`` pair when
            ``FLAGS.guided`` is set.
        training: Forwarded to the inpaint net and the discriminator.
        summary: When true, scalar, image and gradient summaries are added.
        reuse: Forwarded to the inpaint net and the discriminator.

    Returns:
        ``(g_vars, d_vars, losses)``: trainable variables collected under the
        ``'inpaint_net'`` and ``'discriminator'`` scopes, and a dict with the
        keys ``'ae_loss'``, ``'g_loss'`` and ``'d_loss'``.

    Raises:
        NotImplementedError: If ``FLAGS.gan`` is not ``'sngan'``.
    """
    if FLAGS.guided:
        batch_data, edge = batch_data
        # Keep the first channel, rescale by 255 and binarize by threshold.
        edge = edge[:, :, :, 0:1] / 255.
        edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
    batch_pos = batch_data / 127.5 - 1.
    # generate mask, 1 represents masked point
    bbox = random_bbox(FLAGS)
    regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
    irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
    # Union of the box mask and the brush-stroke mask.
    mask = tf.cast(
        tf.logical_or(
            tf.cast(irregular_mask, tf.bool),
            tf.cast(regular_mask, tf.bool),
        ),
        tf.float32
    )

    batch_incomplete = batch_pos*(1.-mask)
    if FLAGS.guided:
        edge = edge * mask
        xin = tf.concat([batch_incomplete, edge], axis=3)
    else:
        xin = batch_incomplete
    x1, x2, offset_flow = self.build_inpaint_net(
        xin, mask, reuse=reuse, training=training,
        padding=FLAGS.padding)
    batch_predicted = x2
    losses = {}
    # apply mask and complete image
    batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
    # L1 reconstruction loss over the whole image for both network outputs.
    losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
    losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
    if summary:
        scalar_summary('losses/ae_loss', losses['ae_loss'])
        if FLAGS.guided:
            viz_img = [
                batch_pos,
                batch_incomplete + edge,
                batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_bilinear))
        images_summary(
            tf.concat(viz_img, axis=2),
            'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

    # gan: real and completed batches are stacked along the batch axis.
    batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
    if FLAGS.gan_with_mask:
        # presumably mask has batch size 1 and is tiled to match the stacked
        # batch -- TODO confirm against bbox2mask / brush_stroke_mask
        batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])], axis=3)
    if FLAGS.guided:
        # conditional GANs
        batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
    # Only the 'sngan' hinge-loss discriminator is supported.
    if FLAGS.gan == 'sngan':
        pos_neg = self.build_gan_discriminator(batch_pos_neg, training=training, reuse=reuse)
        pos, neg = tf.split(pos_neg, 2)
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        losses['g_loss'] = g_loss
        losses['d_loss'] = d_loss
    else:
        raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
    if summary:
        # summary the magnitude of gradients from different losses w.r.t. predicted image
        gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
        gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
        # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
        gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
    losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
    if FLAGS.ae_loss:
        losses['g_loss'] += losses['ae_loss']
    g_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
    d_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    return g_vars, d_vars, losses
def build_graph_with_losses(self, batch_data, batch_mask, batch_guide, config, training=True, summary=False, reuse=False):
    """Build the training graph and its losses for the ``sn_pgan`` setup.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        batch_mask: Ignored; replaced by ``random_ff_mask(parts=8)``.
        batch_guide: Ignored; replaced by an all-ones single-channel tensor.
        config: Configuration object (``PADDING``,
            ``PRETRAIN_COARSE_NETWORK``, ``COARSE_L1_ALPHA``,
            ``VIZ_MAX_OUT``, ``BATCH_SIZE``, ``GAN_WITH_MASK``,
            ``GAN_WITH_GUIDE``, ``GAN``, ``GAN_LOSS_ALPHA`` and
            ``AE_LOSS_ALPHA`` are read).
        training: Forwarded to the inpaint net and the discriminator.
        summary: When true, scalar, image and gradient summaries are added.
        reuse: Forwarded to the inpaint net and the discriminator.

    Returns:
        ``(g_vars, d_vars, losses)``: trainable variables collected under the
        ``'inpaint_net'`` and ``'discriminator'`` scopes, and the loss dict.
    """
    batch_pos = batch_data / 127.5 - 1.
    # The mask argument is discarded in favour of a freshly generated mask.
    batch_mask = random_ff_mask(parts=8)
    batch_incomplete = batch_pos * (1. - batch_mask)
    # First channel of an all-ones tensor shaped like the mask.
    ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
    batch_mask = ones_x * batch_mask
    batch_guide = ones_x
    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, batch_mask, batch_guide, config,
        reuse=reuse, training=training, padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    losses = {}
    # apply mask and complete image
    batch_complete = batch_predicted * batch_mask + batch_incomplete * (
        1. - batch_mask)
    # Local patch is removed in the gated convolution. It's now AE + GAN loss
    # (https://github.com/JiahuiYu/generative_inpainting/issues/62)
    # L1 error weighted by (1 - mask), then normalized by mean(1 - mask).
    losses['ae_loss'] = config.COARSE_L1_ALPHA * tf.reduce_mean(
        tf.abs(batch_pos - x1) * (1. - batch_mask))
    losses['ae_loss'] += config.COARSE_L1_ALPHA * tf.reduce_mean(
        tf.abs(batch_pos - x2) * (1. - batch_mask))
    losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)
    if summary:
        scalar_summary('losses/ae_loss', losses['ae_loss'])
        viz_img = [
            batch_pos, batch_incomplete, batch_predicted, batch_complete
        ]  # I have included the predicted image as well to see the reconstructed image.
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(
            tf.concat(viz_img, axis=2),
            'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)
    # gan: real and completed batches are stacked along the batch axis.
    batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
    # presumably the mask has batch size 1 and is tiled to match the stacked
    # batch -- TODO confirm against random_ff_mask
    batch_mask_all = tf.tile(batch_mask, [config.BATCH_SIZE * 2, 1, 1, 1])
    if config.GAN_WITH_MASK:
        batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)
    if config.GAN_WITH_GUIDE:
        batch_pos_neg = tf.concat([
            batch_pos_neg,
            tf.tile(batch_guide, [config.BATCH_SIZE * 2, 1, 1, 1])
        ], axis=3)
    # sn-pgan
    # NOTE(review): for any other config.GAN value 'g_loss' is never set and
    # the scaling below raises KeyError.
    if config.GAN == 'sn_pgan':
        # sn path gan
        pos_neg = self.build_sn_pgan_discriminator(
            batch_pos_neg, training=training, reuse=reuse)
        pos_global, neg_global = tf.split(pos_neg, 2)
        # SNPGAN Loss
        g_loss_global, d_loss_global = gan_sn_pgan_loss(
            pos_global, neg_global, name='gan/global_gan')
        losses['g_loss'] = g_loss_global
        losses['d_loss'] = d_loss_global
    if summary and not config.PRETRAIN_COARSE_NETWORK:
        # summary the magnitude of gradients from different losses w.r.t. predicted image
        gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
        gradients_summary(losses['d_loss'], batch_predicted, name='d_loss')
        gradients_summary(losses['ae_loss'], x1, name='ae_loss_x1')
        gradients_summary(losses['ae_loss'], x2, name='ae_loss_x2')
    losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
    losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
    logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'inpaint_net')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'discriminator')
    return g_vars, d_vars, losses
def build_graph_with_losses(self, batch_data, config, training=True, summary=False, reuse=False):
    """Build the training graph and its losses for the ``snpatch_gan`` setup.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object (``PADDING``,
            ``PRETRAIN_COARSE_NETWORK``, ``COARSE_L1_ALPHA``,
            ``VIZ_MAX_OUT``, ``BATCH_SIZE``, ``GAN_WITH_MASK``, ``GAN``,
            ``GLOBAL_WGAN_LOSS_ALPHA``, ``WGAN_GP_LAMBDA``,
            ``GAN_LOSS_ALPHA``, ``L1_LOSS_ALPHA``, ``AE_LOSS`` and
            ``AE_LOSS_ALPHA`` are read).
        training: Forwarded to the inpaint net and the discriminator.
        summary: When true, scalar, image and gradient summaries are added.
        reuse: Forwarded to the inpaint net and the discriminator.

    Returns:
        ``(g_vars, d_vars, losses)``: trainable variables collected under the
        ``'inpaint_net'`` and ``'discriminator'`` scopes, and the loss dict.
    """
    batch_pos = batch_data / 127.5 - 1.
    mask = bbox2mask(config, name='mask_c')
    batch_incomplete = batch_pos * (1. - mask)
    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=reuse,
        training=training, padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    losses = {}
    # apply mask and complete image
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    # local patches
    local_patch_batch_pos = local_patch(batch_pos, mask)
    local_patch_batch_predicted = local_patch(batch_predicted, mask)
    local_patch_x1 = local_patch(x1, mask)
    local_patch_x2 = local_patch(x2, mask)
    local_patch_batch_complete = local_patch(batch_complete, mask)
    l1_alpha = config.COARSE_L1_ALPHA
    # The spatial discounting mask is deliberately not applied here.
    losses['l1_loss'] = l1_alpha * tf.reduce_mean(
        tf.abs(local_patch_batch_pos - local_patch_x1))
    if not config.PRETRAIN_COARSE_NETWORK:
        losses['l1_loss'] += tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x2))
    # L1 error weighted by (1 - mask), then normalized by mean(1 - mask).
    losses['ae_loss'] = l1_alpha * tf.reduce_mean(
        tf.abs(batch_pos - x1) * (1. - mask))
    if not config.PRETRAIN_COARSE_NETWORK:
        losses['ae_loss'] += tf.reduce_mean(
            tf.abs(batch_pos - x2) * (1. - mask))
    losses['ae_loss'] /= tf.reduce_mean(1. - mask)
    if summary:
        scalar_summary('losses/l1_loss', losses['l1_loss'])
        scalar_summary('losses/ae_loss', losses['ae_loss'])
        viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(
            tf.concat(viz_img, axis=2),
            'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)
    # gan: real and completed batches are stacked along the batch axis.
    batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
    local_patch_batch_pos_neg = tf.concat(
        [local_patch_batch_pos, local_patch_batch_complete], 0)
    if config.GAN_WITH_MASK:
        batch_pos_neg = tf.concat([
            batch_pos_neg,
            tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
        ], axis=3)
    if config.GAN == 'snpatch_gan':
        pos_neg = self.build_SNGAN_discriminator(
            local_patch_batch_pos_neg, training=training, reuse=reuse)
        pos, neg = tf.split(pos_neg, 2)
        sn_gloss, sn_dloss = self.gan_hinge_loss(
            pos, neg, name="gan/hinge_loss")
        losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * sn_gloss
        losses['d_loss'] = sn_dloss
        # Gradient penalty: interpolate between the same real and completed
        # local patches the discriminator was fed above.  (This call used to
        # reference the undefined names a1/a2 and raised NameError.)
        interpolates = random_interpolates(
            local_patch_batch_pos, local_patch_batch_complete)
        dout = self.build_SNGAN_discriminator(interpolates, reuse=True)
        # NOTE(review): the full mask is passed here while the interpolates
        # come from local patches -- confirm the shapes agree.
        penalty = gradients_penalty(interpolates, dout, mask=mask)
        losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty
        losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(sn_gloss, batch_predicted, name='g_loss_local')
            scalar_summary('convergence/d_loss', losses['d_loss'])
            scalar_summary('convergence/local_d_loss', sn_dloss)
            scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
            scalar_summary('gan_wgan_loss/gp_penalty_local', penalty)
    if summary and not config.PRETRAIN_COARSE_NETWORK:
        # summary the magnitude of gradients from different losses w.r.t.
        # predicted image
        gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
        gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
        gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
        gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
        gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
        gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
        gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
    # While pretraining the coarse network the GAN term is dropped entirely.
    if config.PRETRAIN_COARSE_NETWORK:
        losses['g_loss'] = 0
    else:
        losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
    losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
    logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
    logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
    if config.AE_LOSS:
        losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'inpaint_net')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'discriminator')
    return g_vars, d_vars, losses
def build_graph_with_losses(self, batch_data, config, training=True, summary=False, reuse=False):
    """Build the training graph and its losses for the ``wgan_gp`` setup.

    Args:
        batch_data: Input images; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object (``PADDING``,
            ``PRETRAIN_COARSE_NETWORK``, ``COARSE_L1_ALPHA``,
            ``VIZ_MAX_OUT``, ``BATCH_SIZE``, ``GAN_WITH_MASK``, ``GAN``,
            ``GLOBAL_WGAN_LOSS_ALPHA``, ``WGAN_GP_LAMBDA``,
            ``GAN_LOSS_ALPHA``, ``L1_LOSS_ALPHA``, ``AE_LOSS`` and
            ``AE_LOSS_ALPHA`` are read).
        training: Forwarded to the inpaint net and the discriminator.
        summary: When true, scalar, image and gradient summaries are added.
        reuse: Forwarded to the inpaint net and the discriminator.

    Returns:
        ``(g_vars, d_vars, losses)``: trainable variables collected under the
        ``'inpaint_net'`` and ``'discriminator'`` scopes, and the loss dict.
    """
    batch_pos = batch_data / 127.5 - 1.
    # generate mask, 1 represents masked point
    bbox = random_bbox(config)
    mask = bbox2mask(bbox, config, name='mask_c')
    batch_incomplete = batch_pos*(1.-mask)
    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=reuse, training=training,
        padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    losses = {}
    # apply mask and complete image
    batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
    # local patches
    local_patch_batch_pos = local_patch(batch_pos, bbox)
    # NOTE(review): local_patch_batch_predicted is not read again below.
    local_patch_batch_predicted = local_patch(batch_predicted, bbox)
    local_patch_x1 = local_patch(x1, bbox)
    local_patch_x2 = local_patch(x2, bbox)
    local_patch_batch_complete = local_patch(batch_complete, bbox)
    local_patch_mask = local_patch(mask, bbox)
    l1_alpha = config.COARSE_L1_ALPHA
    # Patch L1 loss weighted by the spatial discounting mask; only the x1
    # term is scaled by l1_alpha.
    losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config))
    if not config.PRETRAIN_COARSE_NETWORK:
        losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2)*spatial_discounting_mask(config))
    # L1 error weighted by (1 - mask), then normalized by mean(1 - mask).
    losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask))
    if not config.PRETRAIN_COARSE_NETWORK:
        losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-mask))
    losses['ae_loss'] /= tf.reduce_mean(1.-mask)
    if summary:
        scalar_summary('losses/l1_loss', losses['l1_loss'])
        scalar_summary('losses/ae_loss', losses['ae_loss'])
        viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(
            tf.concat(viz_img, axis=2),
            'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)

    # gan: real and completed batches are stacked along the batch axis.
    batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
    # local deterministic patch
    local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)
    if config.GAN_WITH_MASK:
        batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
    # wgan with gradient penalty
    # NOTE(review): for any other config.GAN value 'g_loss' is never set, so
    # the non-pretrain path below raises KeyError.
    if config.GAN == 'wgan_gp':
        # seperate gan
        pos_neg_local, pos_neg_global = self.build_wgan_discriminator(local_patch_batch_pos_neg, batch_pos_neg, training=training, reuse=reuse)
        pos_local, neg_local = tf.split(pos_neg_local, 2)
        pos_global, neg_global = tf.split(pos_neg_global, 2)
        # wgan loss
        g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
        g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global, name='gan/global_gan')
        losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
        losses['d_loss'] = d_loss_global + d_loss_local
        # gp: interpolate between real and completed samples at both scales.
        interpolates_local = random_interpolates(local_patch_batch_pos, local_patch_batch_complete)
        interpolates_global = random_interpolates(batch_pos, batch_complete)
        dout_local, dout_global = self.build_wgan_discriminator(
            interpolates_local, interpolates_global, reuse=True)
        # apply penalty
        penalty_local = gradients_penalty(interpolates_local, dout_local, mask=local_patch_mask)
        penalty_global = gradients_penalty(interpolates_global, dout_global, mask=mask)
        losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
        losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
            gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
            scalar_summary('convergence/d_loss', losses['d_loss'])
            scalar_summary('convergence/local_d_loss', d_loss_local)
            scalar_summary('convergence/global_d_loss', d_loss_global)
            scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
            scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
            scalar_summary('gan_wgan_loss/gp_penalty_global', penalty_global)

    if summary and not config.PRETRAIN_COARSE_NETWORK:
        # summary the magnitude of gradients from different losses w.r.t. predicted image
        gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
        gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
        gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
        gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
        gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
        gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
        gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
    # While pretraining the coarse network the GAN term is dropped entirely.
    if config.PRETRAIN_COARSE_NETWORK:
        losses['g_loss'] = 0
    else:
        losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
    losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
    logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
    logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
    if config.AE_LOSS:
        losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
    g_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
    d_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    return g_vars, d_vars, losses
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True, fuse=True):
    """ Contextual attention layer implementation.

    Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        f: Input feature to match (foreground).
        b: Input feature for match (background).
        mask: Input mask for b, indicating patches not available.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
        rate: Dilation for matching.
        fuse_k: Size of the identity kernel used to fuse matching scores.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference
            (not read inside this function).
        fuse: Whether to apply the score-fusing convolutions.

    Returns:
        tuple: ``(y, flow)`` -- the reconstructed feature, whose static shape
        is set to that of ``f``, and the image produced by
        ``flow_to_image_tf`` from the attention offsets.
    """
    # get shapes
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()
    # extract patches from background with stride and rate
    kernel = 2 * rate
    raw_w = tf.extract_image_patches(
        b, [1, kernel, kernel, 1], [1, rate * stride, rate * stride, 1],
        [1, 1, 1, 1], padding='SAME')
    raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # downscaling foreground option: downscaling both foreground and
    # background for matching and use original background for reconstruction.
    f = resize(f, scale=1. / rate, func=tf.image.resize_nearest_neighbor)
    b = resize(b, to_shape=[int(raw_int_bs[1] / rate), int(raw_int_bs[2] / rate)],
               func=tf.image.resize_nearest_neighbor
               )  # https://github.com/tensorflow/tensorflow/issues/11651
    if mask is not None:
        mask = resize(mask, scale=1. / rate, func=tf.image.resize_nearest_neighbor)
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    # One group per batch element; requires a static batch size.
    f_groups = tf.split(f, int_fs[0], axis=0)
    # from t(H*W*C) to w(b*k*k*c*h*w)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    w = tf.extract_image_patches(
        b, [1, ksize, ksize, 1], [1, stride, stride, 1],
        [1, 1, 1, 1], padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # process mask
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(
        mask, [1, ksize, ksize, 1], [1, stride, stride, 1],
        [1, 1, 1, 1], padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    m = m[0]
    # mm is 1 for patches whose mask values average to exactly 0, else 0.
    mm = tf.cast(
        tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.),
        tf.float32)
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    # Identity matrix reshaped into a k x k single-channel conv kernel.
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # conv for compare
        wi = wi[0]
        # L2-normalize each patch; the norm is floored at 1e-4.
        wi_normed = wi / tf.maximum(
            tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")

        # conv implementation for fuse scores to encourage large patches
        if fuse:
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            # Swap the two spatial axes of both maps before the second pass.
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1] * bs[2]])

        # softmax to match
        yi *= mm  # mask
        yi = tf.nn.softmax(yi * scale, 3)
        yi *= mm  # mask

        # Flat index of the best-matching patch, split into (row, col).
        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
        # deconv for patch pasting
        # 3.1 paste center
        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(
            yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0),
            strides=[1, rate, rate, 1]) / 4.
        y.append(yi)
        offsets.append(offset)
    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])
    # case1: visualize optical flow: minus current position
    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    # to flow image
    flow = flow_to_image_tf(offsets)
    # # case2: visualize which pixels are attended
    # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
    return y, flow
def resize_mask_like(mask, x):
    """Resize ``mask`` to the static height/width of ``x``.

    Args:
        mask: Tensor passed to ``resize`` as the input to rescale.
        x: Tensor whose static shape supplies the target size; dimensions
            1 and 2 of ``x.get_shape()`` are used.

    Returns:
        The output of ``resize`` for ``mask`` at the target size, using
        nearest-neighbour sampling.
    """
    target_hw = x.get_shape().as_list()[1:3]
    return resize(
        mask, to_shape=target_hw, func=tf.image.resize_nearest_neighbor)
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True,
                         fuse=True):
    """Contextual attention: rebuild ``f`` from patches of ``b``.

    Patches of the (downscaled) background ``b`` are used as convolution
    filters to score every foreground location of ``f``; the scores are
    soft-maxed and used to paste full-resolution background patches back
    with a transposed convolution.

    Args:
        f: Foreground feature tensor to match.
        b: Background feature tensor that supplies the patches.
        mask: Optional mask for ``b``; patches whose mask values are not
            all zero are excluded from matching. Defaults to all zeros.
        ksize: Patch size used for matching.
        stride: Stride used when extracting patches from ``b``.
        rate: Downscaling factor applied to ``f``/``b`` (and ``mask``)
            before matching.
        fuse_k: Size of the identity kernel used to fuse scores.
        softmax_scale: Multiplier applied to scores before the softmax.
        training: Not used inside this function.
        fuse: Whether to run the score-fusion convolutions.

    Returns:
        Tuple ``(y, flow)``: ``y`` is the reconstructed tensor (its static
        shape is set to that of the input ``f``) and ``flow`` is the image
        produced by ``flow_to_image_tf`` from the best-match offsets.
    """
    # Shapes of the inputs before any downscaling (dynamic and static).
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()

    # Full-resolution background patches; these are the kernels pasted back
    # by the transposed convolution at the end of the loop.
    kernel = 2 * rate
    raw_w = tf.extract_image_patches(
        b, [1, kernel, kernel, 1], [1, rate * stride, rate * stride, 1],
        [1, 1, 1, 1], padding='SAME')
    raw_w = tf.reshape(
        raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    # -> [batch, kernel, kernel, channels, n_patches]
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])

    # Matching is done on inputs downscaled by `rate`.
    f = resize(f, scale=1. / rate, func=tf.image.resize_nearest_neighbor)
    b = resize(b,
               to_shape=[int(raw_int_bs[1] / rate),
                         int(raw_int_bs[2] / rate)],
               func=tf.image.resize_nearest_neighbor)
    if mask is not None:
        mask = resize(mask, scale=1. / rate,
                      func=tf.image.resize_nearest_neighbor)

    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    # One tensor per batch element; the static batch size is the split
    # count, so it presumably has to be known at graph-construction time.
    f_groups = tf.split(f, int_fs[0], axis=0)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()

    # Matching patches from the downscaled background.
    # NOTE(review): the reshape uses f's batch size and channel count for
    # b's patches, so f and b are assumed to agree on both -- confirm.
    w = tf.extract_image_patches(
        b, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1],
        padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    # -> [batch, ksize, ksize, channels, n_patches]
    w = tf.transpose(w, [0, 2, 3, 4, 1])

    # Mask patches: mm is 1.0 for patches whose mask is entirely zero and
    # 0.0 otherwise.
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(
        mask, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1],
        padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])
    m = m[0]
    mm = tf.cast(
        tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.),
        tf.float32)

    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    # Identity kernel: convolving the score matrix with it sums scores
    # along the diagonal.
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # Score every foreground location against every L2-normalised
        # background patch (1e-4 guards against division by zero).
        wi = wi[0]
        wi_normed = wi / tf.maximum(
            tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1],
                          padding="SAME")

        # Fuse scores, once in the original layout and once with both
        # height/width pairs swapped, then swap back.
        # NOTE(review): the reshapes use bs[1] * bs[2] as the number of
        # background patches, which only matches the patch extraction above
        # when stride == 1 -- confirm.
        if fuse:
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1] * bs[2]])

        # Masked softmax over background patches: masked patches are zeroed
        # before and after the softmax.
        yi *= mm
        yi = tf.nn.softmax(yi * scale, 3)
        yi *= mm

        # Index of the best-matching patch, decoded into (row, col).
        # NOTE(review): the index runs over background patches but is
        # decoded with fs[2]; this is only consistent when f and b have the
        # same spatial size -- confirm.
        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)

        # Paste full-resolution patches weighted by the attention scores.
        # NOTE(review): the division by 4 presumably averages overlapping
        # patch contributions (kernel = 2 * rate, stride = rate) -- confirm.
        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(
            yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0),
            strides=[1, rate, rate, 1]) / 4.
        y.append(yi)
        offsets.append(offset)

    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])

    # Turn absolute match positions into offsets relative to each position.
    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]),
                    [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]),
                    [bs[0], bs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    flow = flow_to_image_tf(offsets)
    # Bring the flow image back to the input resolution (bilinear here).
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_bilinear)
    return y, flow
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True,
                         fuse=True):
    """Contextual attention layer implementation.

    Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Patches of the (downscaled) background ``b`` are used as convolution
    filters to score every foreground location of ``f``; the scores are
    soft-maxed and used to paste full-resolution background patches back
    with a transposed convolution.

    Args:
        f: Input feature to match (foreground).
        b: Input feature for match (background).
        mask: Input mask for ``b``, indicating patches not available.
            Patches whose mask values are not all zero are excluded.
            Defaults to an all-zero mask.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from ``b``.
        rate: Downscaling factor applied to ``f``/``b`` (and ``mask``)
            before matching.
        fuse_k: Size of the identity kernel used to fuse scores.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference.
            Not used inside this function.
        fuse: Whether to run the score-fusion convolutions.

    Returns:
        Tuple ``(y, flow)``: ``y`` is the reconstructed tensor (its static
        shape is set to that of the input ``f``) and ``flow`` is the image
        produced by ``flow_to_image_tf`` from the best-match offsets.
    """
    # Example shapes quoted in the comments below were recorded from one
    # run with f and b of shape [2, 64, 64, 128], rate=2 and ksize=3.

    # get shapes (before any downscaling)
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()

    # extract patches from background with stride and rate; these
    # full-resolution patches are the kernels used for pasting.
    kernel = 2 * rate
    raw_w = tf.extract_image_patches(
        b, [1, kernel, kernel, 1], [1, rate * stride, rate * stride, 1],
        [1, 1, 1, 1], padding='SAME')
    raw_w = tf.reshape(
        raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    # transpose to b*k*k*c*hw, e.g. (2, 4, 4, 128, 1024)
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])

    # downscaling foreground option: downscaling both foreground and
    # background for matching and use original background for
    # reconstruction. e.g. (2, 64, 64, 128) -> (2, 32, 32, 128)
    f = resize(f, scale=1. / rate, func=tf.image.resize_nearest_neighbor)
    # https://github.com/tensorflow/tensorflow/issues/11651
    b = resize(b,
               to_shape=[int(raw_int_bs[1] / rate),
                         int(raw_int_bs[2] / rate)],
               func=tf.image.resize_nearest_neighbor)
    if mask is not None:
        # e.g. (1, 32, 32, 1)
        mask = resize(mask, scale=1. / rate,
                      func=tf.image.resize_nearest_neighbor)

    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    # tf.split with an integer count cuts axis 0 into that many equal
    # pieces, i.e. one [1, H, W, C] tensor per batch element.
    f_groups = tf.split(f, int_fs[0], axis=0)

    # from t(H*W*C) to w(b*k*k*c*h*w)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    # e.g. (2, 32, 32, 1152)
    w = tf.extract_image_patches(
        b, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1],
        padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    # transpose to b*k*k*c*hw, e.g. (2, 3, 3, 128, 1024)
    w = tf.transpose(w, [0, 2, 3, 4, 1])

    # process mask: mm is 1.0 for patches whose mask is entirely zero and
    # 0.0 otherwise.
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    # e.g. (1, 32, 32, 9)
    m = tf.extract_image_patches(
        mask, [1, ksize, ksize, 1], [1, stride, stride, 1], [1, 1, 1, 1],
        padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])
    # e.g. (3, 3, 1, 1024)
    m = m[0]
    # e.g. (1, 1, 1, 1024)
    mm = tf.cast(
        tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.),
        tf.float32)

    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # conv for compare: normalize each background patch, then use the
        # patches as filters. yi is e.g. (1, 32, 32, 1024).
        wi = wi[0]
        wi_normed = wi / tf.maximum(
            tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1],
                          padding="SAME")

        # conv implementation for fuse scores to encourage large patches
        if fuse:
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1] * bs[2]])

        # softmax to match; masked patches are zeroed before and after.
        yi *= mm  # mask
        yi = tf.nn.softmax(yi * scale, 3)
        yi *= mm  # mask

        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)

        # deconv for patch pasting
        # 3.1 paste center
        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(
            yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0),
            strides=[1, rate, rate, 1]) / 4.
        y.append(yi)
        offsets.append(offset)

    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])

    # case1: visualize optical flow: minus current position
    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]),
                    [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]),
                    [bs[0], bs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    # to flow image
    flow = flow_to_image_tf(offsets)
    # # case2: visualize which pixels are attended
    # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
    if rate != 1:
        flow = resize(flow, scale=rate,
                      func=tf.image.resize_nearest_neighbor)
    return y, flow
def build_graph_with_losses(self, batch_data, batch_mask, batch_guide,
                            config, training=True, summary=False,
                            reuse=False):
    """Build the inpainting graph and its losses (sn_pgan variant).

    Args:
        batch_data: Image batch; rescaled here with ``/ 127.5 - 1.``.
        batch_mask: Mask where 1 represents a masked point, or None to
            draw one with ``random_ff_mask(config)``.
        batch_guide: Overwritten inside this function before it is read.
        config: Configuration object; the attributes read are visible
            below (PADDING, PRETRAIN_COARSE_NETWORK, GAN, ...).
        training: Forwarded to the generator and discriminator builders.
        summary: When true, scalar/image/gradient summaries are added.
        reuse: Forwarded to the generator and discriminator builders.

    Returns:
        Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
        'inpaint_net' and 'discriminator' scopes and a dict of loss
        tensors.
    """
    batch_pos = batch_data / 127.5 - 1.
    # generate mask, 1 represents masked point
    if batch_mask is None:
        batch_mask = random_ff_mask(config)
    else:
        # A caller-supplied mask is used as-is.
        pass
    batch_incomplete = batch_pos * (1. - batch_mask)
    # Single-channel tensor of ones shaped like the mask; multiplying by
    # it leaves the mask values unchanged.
    ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
    batch_mask = ones_x * batch_mask
    # NOTE(review): the batch_guide argument is discarded here and
    # replaced with all ones -- confirm this is intended.
    batch_guide = ones_x
    # NOTE(review): batch_mask is passed twice (second and third
    # positional argument) -- confirm against the signature of the
    # build_inpaint_net this method is paired with.
    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, batch_mask, batch_mask, config, reuse=reuse,
        training=training, padding=config.PADDING)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    losses = {}
    # apply mask and complete image
    batch_complete = batch_predicted * batch_mask + batch_incomplete * (
        1. - batch_mask)
    # local patches (mask_patch is defined elsewhere; presumably it
    # restricts a tensor to the masked region -- TODO confirm)
    local_patch_batch_pos = mask_patch(batch_pos, batch_mask)
    local_patch_batch_predicted = mask_patch(batch_predicted, batch_mask)
    local_patch_x1 = mask_patch(x1, batch_mask)
    local_patch_x2 = mask_patch(x2, batch_mask)
    local_patch_batch_complete = mask_patch(batch_complete, batch_mask)
    # NOTE(review): local_patch_batch_predicted and
    # local_patch_batch_complete are not read again in this function.
    # local patch l1 loss hole+out same as partial convolution
    l1_alpha = config.COARSE_L1_ALPHA
    losses['l1_loss'] = l1_alpha * tf.reduce_mean(
        tf.abs(local_patch_batch_pos - local_patch_x1))
    if not config.PRETRAIN_COARSE_NETWORK:
        losses['l1_loss'] += tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x2))
    # ae_loss: L1 error on the unmasked region, divided by the mean of
    # (1 - mask).
    losses['ae_loss'] = l1_alpha * tf.reduce_mean(
        tf.abs(batch_pos - x1) * (1. - batch_mask))
    if not config.PRETRAIN_COARSE_NETWORK:
        losses['ae_loss'] += tf.reduce_mean(
            tf.abs(batch_pos - x2) * (1. - batch_mask))
    losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)
    if summary:
        scalar_summary('losses/l1_loss', losses['l1_loss'])
        scalar_summary('losses/ae_loss', losses['ae_loss'])
        viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_nearest_neighbor))
        # The panels are concatenated along axis 2.
        images_summary(tf.concat(viz_img, axis=2),
                       'raw_incomplete_predicted_complete',
                       config.VIZ_MAX_OUT)
    # gan: real images first, completed images second, along the batch axis
    batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
    if config.MASKFROMFILE:
        batch_mask_all = tf.tile(batch_mask, [2, 1, 1, 1])
    else:
        batch_mask_all = tf.tile(batch_mask,
                                 [config.BATCH_SIZE * 2, 1, 1, 1])
        batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
    if config.GAN_WITH_MASK:
        batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)
    if config.GAN_WITH_GUIDE:
        batch_pos_neg = tf.concat([
            batch_pos_neg,
            tf.tile(batch_guide, [config.BATCH_SIZE * 2, 1, 1, 1])
        ], axis=3)
    # sn-pgan with gradient penalty
    if config.GAN == 'sn_pgan':
        # sn patch gan
        pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg,
                                                   training=training,
                                                   reuse=reuse)
        pos_global, neg_global = tf.split(pos_neg, 2)
        g_loss_global, d_loss_global = gan_sn_pgan_loss(
            pos_global, neg_global, name='gan/global_gan')
        losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global
        losses['d_loss'] = d_loss_global
        # gp
        # Random Interpolate between true and false
        interpolates_global = random_interpolates(
            tf.concat([batch_pos, batch_mask], axis=3),
            tf.concat([batch_complete, batch_mask], axis=3))
        dout_global = self.build_sn_pgan_discriminator(interpolates_global,
                                                       reuse=True)
        # apply penalty
        penalty_global = gradients_penalty(interpolates_global,
                                           dout_global, mask=batch_mask)
        losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty_global
        # gp_loss is reported separately and is not added to d_loss here:
        #losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(g_loss_global, batch_predicted,
                              name='g_loss_global')
            scalar_summary('convergence/d_loss', losses['d_loss'])
            scalar_summary('convergence/global_d_loss', d_loss_global)
            scalar_summary('gan_sn_pgan_loss/gp_loss', losses['gp_loss'])
            scalar_summary('gan_sn_pgan_loss/gp_penalty_global',
                           penalty_global)
    # NOTE(review): losses['g_loss'] is only assigned in the 'sn_pgan'
    # branch above. For any other config.GAN the reads below raise
    # KeyError unless PRETRAIN_COARSE_NETWORK is set.
    if summary and not config.PRETRAIN_COARSE_NETWORK:
        # summary the magnitude of gradients from different losses w.r.t.
        # predicted image
        gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
        gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
        gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
        gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
        gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
        gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
        gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
    # Final generator loss: weighted GAN term (dropped when pretraining
    # the coarse network) plus weighted L1 and optional AE terms.
    if config.PRETRAIN_COARSE_NETWORK:
        losses['g_loss'] = 0
    else:
        losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
    losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
    logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
    logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
    if config.AE_LOSS:
        losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'inpaint_net')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'discriminator')
    return g_vars, d_vars, losses
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net',
                      exclusionmask=None):
    """Inpaint network (coarse stage followed by a two-branch refinement).

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: Configuration object. Required despite the ``None``
            default; MULTIRES, GATING, LEVELS, BATCH_SIZE, PATCH_KSIZE,
            PATCH_STRIDE, PATCH_RATE (and ATTENTION_MASK) are read.
        reuse: Passed to ``tf.variable_scope``.
        training: Passed to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        padding: Passed to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        name: Variable scope name.
        exclusionmask: Currently ignored because ``hasmask`` is hard-coded
            to False (see TODO below).

    Returns:
        Tuple ``(x_stage1, x_stage2, flows)``: the coarse and refined
        predictions, both clipped to [-1, 1], and the list of flow tensors
        returned by each ``contextual_attention`` call.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        # Fail early with a clear message instead of an AttributeError on
        # the first attribute access below.
        raise ValueError('build_inpaint_net requires a config object; '
                         'got None')
    multires = config.MULTIRES
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    hasmask = False  #TODO: #exclusionmask is not None
    if hasmask:
        exclusionmask = tf.cast(
            tf.less(exclusionmask[:, :, :, 0:1], 0.5), tf.float32)
        #x = tf.concat([x, exclusionmask], axis=3)

    use_gating = config.GATING

    # two stage network
    cnum = 24 if use_gating else 32
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # stage1
        x = gen_conv(x, cnum, 5, 1, name='conv1', gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample',
                     gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv3', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv5', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv6', gating=use_gating)
        # Mask at the resolution of the downsampled features; reused by
        # the attention branch in stage 2.
        mask_s = resize_mask_like(mask, x)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv11', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv12', gating=use_gating)
        x = gen_deconv(x, 2 * cnum, name='conv13_upsample',
                       gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv14', gating=use_gating)
        x = gen_deconv(x, cnum, name='conv15_upsample', gating=use_gating)
        x = gen_conv(x, cnum // 2, 3, 1, name='conv16', gating=use_gating)
        x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x
        # return x_stage1, None, None

        # stage2, paste result as input
        # x = tf.stop_gradient(x)
        x = x * mask + xin * (1. - mask)
        x.set_shape(xin.get_shape().as_list())
        # conv branch
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        #if hasmask:
        #    xnow = tf.concat([xnow, exclusionmask], axis=3)
        x = gen_conv(xnow, cnum, 5, 1, name='xconv1', gating=use_gating)
        x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample',
                     gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3', gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous',
                     gating=use_gating)
        x_hallu = x

        # attention branch
        x = gen_conv(xnow, cnum, 5, 1, name='pmconv1', gating=use_gating)
        x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample',
                     gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu, gating=use_gating)

        flows = []

        use_attentionmask = hasmask and config.ATTENTION_MASK
        if use_attentionmask:
            ex_mask_s = resize_mask_like(exclusionmask, x)

        if multires:
            # scale down feature map, run contextual attention, scale up
            # and paste inpainted region into original feature map
            logger.info('USING MULTIRES')
            logger.info('original x shape: ' + str(x.shape))
            logger.info('original mask shape: ' + str(mask_s.shape))
            x_multi = [x]
            mask_multi = [mask_s]
            if use_attentionmask:
                exclusion_mask_multi = [ex_mask_s]
            # Build the pyramid by halving the features LEVELS - 1 times.
            for i in range(config.LEVELS - 1):
                #x = gen_conv(x, 4*cnum, 3, 2, name='pyramid_downsample_'+str(i+1))
                x = resize(x, scale=0.5)
                x_multi.append(x)
                mask_multi.append(resize_mask_like(mask_s, x))
                if use_attentionmask:
                    exclusion_mask_multi.append(
                        resize_mask_like(ex_mask_s, x))
                    logger.info('exclusionmask shape: ' +
                                str(exclusion_mask_multi[i + 1].shape))
                logger.info('x shape: ' + str(x_multi[i + 1].shape))
                logger.info('mask shape: ' + str(mask_multi[i + 1].shape))

            # Coarsest level first. x already holds the coarsest features.
            x_multi.reverse()
            mask_multi.reverse()
            if use_attentionmask:
                exclusion_mask_multi.reverse()

            for i in range(config.LEVELS - 1):
                if use_attentionmask:
                    totalmask = mask_multi[i] + exclusion_mask_multi[i]
                    logger.info('total mask shape: ' +
                                str(totalmask.shape))
                else:
                    totalmask = tf.tile(mask_multi[i],
                                        [config.BATCH_SIZE, 1, 1, 1])
                x, flow = contextual_attention(x, x, totalmask,
                                               ksize=config.PATCH_KSIZE,
                                               stride=config.PATCH_STRIDE,
                                               rate=config.PATCH_RATE)
                #x, flow = contextual_attention(x, x, mask_multi[i], ksize=3, stride=1, rate=1)
                flows.append(flow)
                #TODO: look into using deconv instead of just upsampling
                x = resize(x, scale=2)
                # Keep the attention result only inside the mask; outside
                # it, use the next-finer level's own features.
                x = x * mask_multi[i + 1] + x_multi[i + 1] * (
                    1. - mask_multi[i + 1])
                logger.info('upsampled x shape: ' + str(x.shape))

        # Attention at the finest (original feature) resolution.
        x, offset_flow = contextual_attention(
            x, x,
            tf.tile(mask_s, [config.BATCH_SIZE, 1, 1, 1])
            if not use_attentionmask else mask_s + ex_mask_s,
            ksize=config.PATCH_KSIZE,
            stride=config.PATCH_STRIDE,
            rate=config.PATCH_RATE)
        flows.append(offset_flow)

        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10', gating=use_gating)
        pm = x
        x = tf.concat([x_hallu, pm], axis=3)  #join branches together

        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12',
                     gating=use_gating)
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample',
                       gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14',
                     gating=use_gating)
        x = gen_deconv(x, cnum, name='allconv15_upsample',
                       gating=use_gating)
        x = gen_conv(x, cnum // 2, 3, 1, name='allconv16',
                     gating=use_gating)
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
    return x_stage1, x_stage2, flows
def build_graph_with_losses(self, batch_data, config, training=True,
                            summary=False, reuse=False, exclusionmask=None,
                            mask=None):
    """Build the inpainting graph and its losses (wgan_gp / sngan variant).

    Args:
        batch_data: Image batch; rescaled here with ``/ 127.5 - 1.``.
        config: Configuration object; the attributes read are visible
            below (GAN, INVERT_MASK, PADDING, PRETRAIN_COARSE_NETWORK, ...).
        training: Forwarded to the generator and discriminator builders.
        summary: When true, scalar/image/gradient summaries are added.
        reuse: Forwarded to the generator and discriminator builders.
        exclusionmask: Optional tensor; its first channel is thresholded
            at 0.5 into ``loss_mask``, which multiplies the images and
            loss terms.
        mask: Optional mask tensor; its first channel is thresholded at
            0.5. When None, a random bbox mask is generated.

    Returns:
        Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
        'inpaint_net' and 'discriminator' scopes and a dict of loss
        tensors.

    Raises:
        Exception: If ``config.GAN == 'wgan_gp'`` but a mask was supplied,
            so no local patch (bbox) is available.
    """
    batch_pos = batch_data / 127.5 - 1.
    # generate mask, 1 represents masked point
    # Local patches need a bbox, so they are only enabled for a generated
    # mask combined with wgan_gp.
    use_local_patch = False
    if mask is None:
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        if config.GAN == 'wgan_gp':
            use_local_patch = True
    else:
        #bbox = (0, 0, config.IMG_SHAPES[0], config.IMG_SHAPES[1])
        # Binarise the first channel: values above 0.5 become 1.0.
        mask = tf.cast(tf.less(0.5, mask[:, :, :, 0:1]), tf.float32)
        if config.INVERT_MASK:
            mask = 1 - mask

    batch_incomplete = batch_pos * (1. - mask)

    if exclusionmask is not None:
        if config.INVERT_EXCLUSIONMASK:
            loss_mask = tf.cast(tf.less(0.5, exclusionmask[:, :, :, 0:1]),
                                tf.float32)  #keep white parts
        else:
            loss_mask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                tf.float32)  #keep black parts
        # Both the network input and the ground truth are zeroed outside
        # loss_mask.
        batch_incomplete = batch_incomplete * loss_mask
        batch_pos = batch_pos * loss_mask

    x1, x2, offset_flow = self.build_inpaint_net(
        batch_incomplete, mask, config, reuse=reuse, training=training,
        padding=config.PADDING, exclusionmask=exclusionmask)
    if config.PRETRAIN_COARSE_NETWORK:
        batch_predicted = x1
        logger.info('Set batch_predicted to x1.')
    else:
        batch_predicted = x2
        logger.info('Set batch_predicted to x2.')
    losses = {}
    # apply mask and complete image
    batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
    if exclusionmask is not None:
        batch_complete = batch_complete * loss_mask

    l1_alpha = config.COARSE_L1_ALPHA
    # local patches
    if use_local_patch:
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        # NOTE(review): loss_mask is derived from the full exclusionmask
        # while the other factors here are local patches -- confirm the
        # shapes are compatible when both features are used together.
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x1) *
            spatial_discounting_mask(config) *
            (loss_mask if exclusionmask is not None else 1))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x2) *
                spatial_discounting_mask(config) *
                (loss_mask if exclusionmask is not None else 1))

        # ae_loss: L1 error on the unmasked region, divided by the mean
        # of (1 - mask).
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask) *
            (loss_mask if exclusionmask is not None else 1))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask) *
                (loss_mask if exclusionmask is not None else 1))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)
    else:
        # Without local patches the L1 loss covers the whole image and no
        # ae_loss entry is created.
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) *
            (loss_mask if exclusionmask is not None else 1))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) *
                (loss_mask if exclusionmask is not None else 1))

    if summary:
        scalar_summary('losses/l1_loss', losses['l1_loss'])
        if use_local_patch:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
        # Zero-valued separator with dimension 2 set to 5, placed between
        # the panels concatenated along axis 2 below.
        img_size = [dim for dim in batch_incomplete.shape]
        img_size[2] = 5
        border = tf.zeros(tf.TensorShape(img_size))
        viz_img = [
            batch_pos, border, batch_incomplete, border, batch_complete,
            border
        ]
        if not config.PRETRAIN_COARSE_NETWORK:
            batch_complete_coarse = x1 * mask + batch_incomplete * (1. -
                                                                   mask)
            viz_img.append(batch_complete_coarse)
            viz_img.append(border)
        if offset_flow is not None:
            # offset_flow is iterated as a sequence of flow tensors; the
            # upscaling factor starts at 2 << len(offset_flow) and halves
            # for each subsequent flow.
            scale = 2 << len(offset_flow)
            for flow in offset_flow:
                viz_img.append(
                    resize(flow, scale=scale,
                           func=tf.image.resize_nearest_neighbor))
                viz_img.append(border)
                scale >>= 1
        images_summary(tf.concat(viz_img, axis=2),
                       'raw_incomplete_predicted_complete',
                       config.VIZ_MAX_OUT)

    # gan: real images first, completed images second, along the batch axis
    batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
    # local deterministic patch
    if config.GAN_WITH_MASK:
        batch_pos_neg = tf.concat([
            batch_pos_neg,
            tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
        ], axis=3)
    # wgan with gradient penalty
    if config.GAN == 'wgan_gp':
        if not use_local_patch:
            raise Exception('wgan_gp requires global and local patch')
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        # separate gan
        pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
            local_patch_batch_pos_neg, batch_pos_neg, training=training,
            reuse=reuse)
        pos_local, neg_local = tf.split(pos_neg_local, 2)
        pos_global, neg_global = tf.split(pos_neg_global, 2)
        # wgan loss
        g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local,
                                                   name='gan/local_gan')
        g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                     neg_global,
                                                     name='gan/global_gan')
        losses[
            'g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
        losses['d_loss'] = d_loss_global + d_loss_local
        # gp
        interpolates_local = random_interpolates(
            local_patch_batch_pos, local_patch_batch_complete)
        interpolates_global = random_interpolates(batch_pos, batch_complete)
        dout_local, dout_global = self.build_wgan_discriminator(
            interpolates_local, interpolates_global, reuse=True)
        # apply penalty
        penalty_local = gradients_penalty(interpolates_local, dout_local,
                                          mask=local_patch_mask)
        penalty_global = gradients_penalty(interpolates_global,
                                           dout_global, mask=mask)
        losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                     penalty_global)
        losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(g_loss_local, batch_predicted,
                              name='g_loss_local')
            gradients_summary(g_loss_global, batch_predicted,
                              name='g_loss_global')
            scalar_summary('convergence/d_loss', losses['d_loss'])
            scalar_summary('convergence/local_d_loss', d_loss_local)
            scalar_summary('convergence/global_d_loss', d_loss_global)
            scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
            scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
            scalar_summary('gan_wgan_loss/gp_penalty_global',
                           penalty_global)
    elif config.GAN == 'sngan':
        # NOTE(review): use_local_patch is only ever set when
        # config.GAN == 'wgan_gp', so this guard cannot fire as written.
        if use_local_patch:
            raise Exception(
                'sngan incompatible with global and local patch')
        pos_neg = self.build_sngan_discriminator(batch_pos_neg,
                                                 name='discriminator',
                                                 reuse=reuse)
        pos, neg = tf.split(pos_neg, 2)
        g_loss, d_loss = gan_hinge_loss(pos, neg)
        losses['g_loss'] = g_loss
        losses['d_loss'] = d_loss
        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(g_loss, batch_predicted, name='g_loss')
            scalar_summary('convergence/d_loss', losses['d_loss'])
    else:
        # Any other GAN setting: no adversarial term; note that no
        # 'd_loss' entry is created in this case.
        losses['g_loss'] = 0

    if summary and not config.PRETRAIN_COARSE_NETWORK:
        # summary the magnitude of gradients from different losses w.r.t.
        # predicted image
        gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
        gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
        gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
        gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
        gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
        if use_local_patch:
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
    # Final generator loss: weighted GAN term (dropped when pretraining
    # the coarse network) plus weighted L1 and optional AE terms.
    if config.PRETRAIN_COARSE_NETWORK:
        losses['g_loss'] = 0
    else:
        losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
    losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
    logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
    logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
    # ae_loss only exists when local patches were built.
    if config.AE_LOSS and use_local_patch:
        losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'inpaint_net')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               'discriminator')
    return g_vars, d_vars, losses