def call(self, x, mask, training):
    """Inpaint network (Keras-style forward pass).

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        training: part of the Keras ``call`` signature; not read here.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow): the coarse prediction and
        the refined prediction (both tanh outputs, so in [-1, 1]) and the
        flow returned by ``contextual_attention``.
    """
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    # Network input: image + a constant-ones plane + the mask plane.
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    # stage1: coarse prediction
    x = self.s1_1(x)
    # Mask resized (via resize_mask_like) to match the s1_1 feature map;
    # reused by the contextual attention in stage 2.
    mask_s = resize_mask_like(mask, x)
    x = self.s1_2(x)
    x = tf.keras.activations.tanh(x)
    x_stage1 = x

    # stage2: paste the coarse result into the mask region, keep the rest.
    x = x * mask + xin[:, :, :, 0:3] * (1. - mask)
    x.set_shape(xin[:, :, :, 0:3].get_shape().as_list())
    xnow = x

    # conv branch
    x_hallu = self.s2(xnow)

    # attention branch
    x = self.attn(xnow)
    x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
    pm = self.attn_2(x)

    # merge both branches and decode
    x = tf.concat([x_hallu, pm], axis=3)
    x_stage2 = self.final(x)
    x_stage2 = tf.keras.activations.tanh(x_stage2)
    return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Build the two-stage (coarse, then refine) inpainting generator.

    Base width is 48 channels; both stages end in tanh.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to every gen_conv / gen_deconv via arg_scope.
        padding: forwarded to every gen_conv / gen_deconv via arg_scope.
        name: variable scope wrapping all layers.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow): coarse and refined
        predictions (tanh outputs, so in [-1, 1]) and the flow returned by
        ``contextual_attention``.
    """
    raw = x
    plane = tf.ones_like(raw)[:, :, :, 0:1]
    # network input: image channels + constant-ones plane + mask plane
    net = tf.concat([raw, plane, plane*mask], axis=3)
    cnum = 48

    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # ---------------- stage 1 (coarse) ----------------
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'conv1'),
                (2*cnum, 3, 2, 'conv2_downsample'),
                (2*cnum, 3, 1, 'conv3'),
                (4*cnum, 3, 2, 'conv4_downsample'),
                (4*cnum, 3, 1, 'conv5'),
                (4*cnum, 3, 1, 'conv6')):
            net = gen_conv(net, filters, ksize, stride, name=layer)
        # mask resized to the conv6 feature map; used by stage-2 attention
        mask_s = resize_mask_like(mask, net)
        for dilation, layer in ((2, 'conv7_atrous'), (4, 'conv8_atrous'),
                                (8, 'conv9_atrous'), (16, 'conv10_atrous')):
            net = gen_conv(net, 4*cnum, 3, rate=dilation, name=layer)
        net = gen_conv(net, 4*cnum, 3, 1, name='conv11')
        net = gen_conv(net, 4*cnum, 3, 1, name='conv12')
        net = gen_deconv(net, 2*cnum, name='conv13_upsample')
        net = gen_conv(net, 2*cnum, 3, 1, name='conv14')
        net = gen_deconv(net, cnum, name='conv15_upsample')
        net = gen_conv(net, cnum//2, 3, 1, name='conv16')
        net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
        x_stage1 = tf.nn.tanh(net)

        # ---------------- stage 2 (refine) ----------------
        # keep known pixels; fill the mask region with the coarse prediction
        merged = x_stage1*mask + raw[:, :, :, 0:3]*(1.-mask)
        merged.set_shape(raw[:, :, :, 0:3].get_shape().as_list())

        # convolutional ("hallucination") branch
        hallu = merged
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'xconv1'),
                (cnum, 3, 2, 'xconv2_downsample'),
                (2*cnum, 3, 1, 'xconv3'),
                (2*cnum, 3, 2, 'xconv4_downsample'),
                (4*cnum, 3, 1, 'xconv5'),
                (4*cnum, 3, 1, 'xconv6')):
            hallu = gen_conv(hallu, filters, ksize, stride, name=layer)
        for dilation, layer in ((2, 'xconv7_atrous'), (4, 'xconv8_atrous'),
                                (8, 'xconv9_atrous'), (16, 'xconv10_atrous')):
            hallu = gen_conv(hallu, 4*cnum, 3, rate=dilation, name=layer)

        # contextual-attention branch
        att = merged
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'pmconv1'),
                (cnum, 3, 2, 'pmconv2_downsample'),
                (2*cnum, 3, 1, 'pmconv3'),
                (4*cnum, 3, 2, 'pmconv4_downsample'),
                (4*cnum, 3, 1, 'pmconv5')):
            att = gen_conv(att, filters, ksize, stride, name=layer)
        att = gen_conv(att, 4*cnum, 3, 1, name='pmconv6',
                       activation=tf.nn.relu)
        att, offset_flow = contextual_attention(att, att, mask_s, 3, 1,
                                                rate=2)
        att = gen_conv(att, 4*cnum, 3, 1, name='pmconv9')
        att = gen_conv(att, 4*cnum, 3, 1, name='pmconv10')

        # merge branches and decode
        out = tf.concat([hallu, att], axis=3)
        out = gen_conv(out, 4*cnum, 3, 1, name='allconv11')
        out = gen_conv(out, 4*cnum, 3, 1, name='allconv12')
        out = gen_deconv(out, 2*cnum, name='allconv13_upsample')
        out = gen_conv(out, 2*cnum, 3, 1, name='allconv14')
        out = gen_deconv(out, cnum, name='allconv15_upsample')
        out = gen_conv(out, cnum//2, 3, 1, name='allconv16')
        out = gen_conv(out, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.nn.tanh(out)
        return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Build the two-stage (coarse, then refine) inpainting generator.

    Args:
        x: incomplete image (presumably in [-1, 1] -- TODO confirm).
        mask: mask; where it is 1 the stage-1 prediction replaces x.
        config: accepted for interface compatibility; not read here.
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        name: variable scope wrapping every layer.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow); both predictions are
        clipped to [-1, 1], offset_flow comes from contextual_attention.
    """
    image = x
    plane = tf.ones_like(image)[:, :, :, 0:1]
    net = tf.concat([image, plane, plane * mask], axis=3)
    cnum = 32

    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # ---- stage 1: coarse prediction ----
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'conv1'),
                (2 * cnum, 3, 2, 'conv2_downsample'),
                (2 * cnum, 3, 1, 'conv3'),
                (4 * cnum, 3, 2, 'conv4_downsample'),
                (4 * cnum, 3, 1, 'conv5'),
                (4 * cnum, 3, 1, 'conv6')):
            net = gen_conv(net, filters, ksize, stride, name=layer)
        mask_s = resize_mask_like(mask, net)
        for dilation, layer in ((2, 'conv7_atrous'), (4, 'conv8_atrous'),
                                (8, 'conv9_atrous'), (16, 'conv10_atrous')):
            net = gen_conv(net, 4 * cnum, 3, rate=dilation, name=layer)
        net = gen_conv(net, 4 * cnum, 3, 1, name='conv11')
        net = gen_conv(net, 4 * cnum, 3, 1, name='conv12')
        net = gen_deconv(net, 2 * cnum, name='conv13_upsample')
        net = gen_conv(net, 2 * cnum, 3, 1, name='conv14')
        net = gen_deconv(net, cnum, name='conv15_upsample')
        net = gen_conv(net, cnum // 2, 3, 1, name='conv16')
        net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
        x_stage1 = tf.clip_by_value(net, -1., 1.)

        # ---- stage 2: refine, starting from the pasted coarse result ----
        merged = x_stage1 * (mask) + image * (1. - mask)
        merged.set_shape(image.get_shape().as_list())
        stage2_in = tf.concat([merged, plane, plane * mask], axis=3)

        # convolutional branch
        hallu = stage2_in
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'xconv1'),
                (cnum, 3, 2, 'xconv2_downsample'),
                (2 * cnum, 3, 1, 'xconv3'),
                (2 * cnum, 3, 2, 'xconv4_downsample'),
                (4 * cnum, 3, 1, 'xconv5'),
                (4 * cnum, 3, 1, 'xconv6')):
            hallu = gen_conv(hallu, filters, ksize, stride, name=layer)
        for dilation, layer in ((2, 'xconv7_atrous'), (4, 'xconv8_atrous'),
                                (8, 'xconv9_atrous'), (16, 'xconv10_atrous')):
            hallu = gen_conv(hallu, 4 * cnum, 3, rate=dilation, name=layer)

        # contextual-attention branch
        att = stage2_in
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'pmconv1'),
                (cnum, 3, 2, 'pmconv2_downsample'),
                (2 * cnum, 3, 1, 'pmconv3'),
                (4 * cnum, 3, 2, 'pmconv4_downsample'),
                (4 * cnum, 3, 1, 'pmconv5')):
            att = gen_conv(att, filters, ksize, stride, name=layer)
        att = gen_conv(att, 4 * cnum, 3, 1, name='pmconv6',
                       activation=tf.nn.relu)
        att, offset_flow = contextual_attention(att, att, mask_s, 3, 1,
                                                rate=2)
        att = gen_conv(att, 4 * cnum, 3, 1, name='pmconv9')
        att = gen_conv(att, 4 * cnum, 3, 1, name='pmconv10')

        # merge branches and decode
        out = tf.concat([hallu, att], axis=3)
        out = gen_conv(out, 4 * cnum, 3, 1, name='allconv11')
        out = gen_conv(out, 4 * cnum, 3, 1, name='allconv12')
        out = gen_deconv(out, 2 * cnum, name='allconv13_upsample')
        out = gen_conv(out, 2 * cnum, 3, 1, name='allconv14')
        out = gen_deconv(out, cnum, name='allconv15_upsample')
        out = gen_conv(out, cnum // 2, 3, 1, name='allconv16')
        # NOTE(review): unlike every other layer, allconv17 calls
        # tf.layers.conv2d directly with padding='SAME' hard-coded, so the
        # `padding` argument (applied via arg_scope) does not reach it.
        out = tf.layers.conv2d(out, 3, 3, 1, activation=None,
                               dilation_rate=1, padding='SAME',
                               name='allconv17')
        x_stage2 = tf.clip_by_value(out, -1., 1.)
        return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Inpaint network.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: accepted for interface compatibility; not read here.
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        name: variable scope name.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow): coarse and refined
        predictions, both clipped to [-1, 1], and the flow returned by
        ``contextual_attention``.
    """
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    # two stage network
    cnum = 32
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # stage1
        x = gen_conv(x, cnum, 5, 1, name='conv1')
        x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv6')
        mask_s = resize_mask_like(mask, x)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv12')
        x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv14')
        x = gen_deconv(x, cnum, name='conv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='conv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x

        # stage2, paste result as input
        x = x * mask + xin * (1. - mask)
        # NOTE(review): the original author flagged this set_shape with a
        # comment meaning "there is a problem here"; the cause is not
        # visible from this function.
        x.set_shape(xin.get_shape().as_list())

        # conv branch
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
        x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
        x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
        x_hallu = x

        # attention branch
        x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
        x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu)
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
        pm = x

        # merge branches and decode
        x = tf.concat([x_hallu, pm], axis=3)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
        x = gen_deconv(x, cnum, name='allconv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
        return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Inpaint network with an optional gradient branch in stage 2.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: object read for ``ADD_GRADIENT_BRANCH``; ``None`` (the
            default) disables the gradient branch.
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        name: variable scope name.

    Returns:
        With the gradient branch enabled:
            (x_stage1, x_stage2, x_gb, grad_x_stage1, offset_flow)
        otherwise:
            (x_stage1, x_stage2, offset_flow)
        x_stage1, x_stage2 and x_gb are clipped to [-1, 1].
    """
    # A missing config (the default) means "no gradient branch".
    use_gradient_branch = config is not None and config.ADD_GRADIENT_BRANCH

    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    # two stage network
    cnum = 32
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # stage1
        x = gen_conv(x, cnum, 5, 1, name='conv1')
        x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv6')
        mask_s = resize_mask_like(mask, x)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv12')
        x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv14')
        x = gen_deconv(x, cnum, name='conv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='conv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x

        # stage2, paste result as input
        x = x * mask + xin * (1. - mask)
        x_gb_in = x  # input of the gradient branch
        x.set_shape(xin.get_shape().as_list())

        # conv branch; x_feat1..3 are tapped for fusion into the gradient
        # branch
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
        x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
        x_feat1 = x
        x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
        x_feat2 = x
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
        x_hallu = x
        x_feat3 = x

        # attention branch
        x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
        x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu)
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
        pm = x

        branches = [x_hallu, pm]
        if use_gradient_branch:
            # gradient branch
            # NOTE: layer names (including the 'fushion'/'fushiuon'
            # spellings) are variable-scope names; renaming them would
            # presumably break existing checkpoints.
            grad_x_stage1 = self.get_grad.get_gradient_tf(x_gb_in)
            x = gen_conv(grad_x_stage1, cnum, 5, 1, name='gbconv1')
            x = gen_conv(x, 2 * cnum, 3, 2, name='gbconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='gbconv3')
            # fuse the inpainting features of the main network
            x = tf.concat([x, x_feat1], axis=3)
            x = gen_conv(x, 2 * cnum, 3, 1, name='gbfushion1')
            x = gen_conv(x, 4 * cnum, 3, 2, name='gbconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv6')
            x = tf.concat([x, x_feat2], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='gbfushiuon2')
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='gbconv7_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='gbconv8_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='gbconv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='gbconv10_atrous')
            x = tf.concat([x, x_feat3], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='gbfushiuon3')
            gb = x
            # separate decoder producing the gradient-branch image x_gb
            x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv12')
            x = gen_deconv(x, 2 * cnum, name='gbconv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='gbconv14')
            x = gen_deconv(x, cnum, name='gbconv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='gbconv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='gbconv17')
            x_gb = tf.clip_by_value(x, -1., 1.)
            branches.append(gb)

        # shared stage-2 decoder over the concatenated branch features
        x = tf.concat(branches, axis=3)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
        x = gen_deconv(x, cnum, name='allconv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)

    if use_gradient_branch:
        return x_stage1, x_stage2, x_gb, grad_x_stage1, offset_flow
    return x_stage1, x_stage2, offset_flow
def build_v2_inpaint_net(self, x, mask, config=None, reuse=False,
                         training=True, padding='SAME', name='inpaint_net'):
    """Build the gated-convolution variant of the two-stage generator.

    All hidden layers use gated_conv / gated_deconv; only the two output
    layers (conv17, allconv17) are plain gen_conv without activation.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: accepted for interface compatibility; not read here.
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to gated_conv / gated_deconv / gen_conv via
            ``arg_scope``.
        padding: forwarded the same way as ``training``.
        name: variable scope name.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow); both predictions are
        clipped to [-1, 1].
    """
    image = x
    plane = tf.ones_like(image)[:, :, :, 0:1]
    net = tf.concat([image, plane, plane*mask], axis=3)
    # Narrower than the 32-wide ungated variants; the original comment
    # speculated about a ~25% weight reduction "according to paper".
    cnum = 24

    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gated_conv, gated_deconv, gen_conv],
                      training=training, padding=padding):
        # ---- stage 1: coarse prediction ----
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'conv1'),
                (2*cnum, 3, 2, 'conv2_downsample'),
                (2*cnum, 3, 1, 'conv3'),
                (4*cnum, 3, 2, 'conv4_downsample'),
                (4*cnum, 3, 1, 'conv5'),
                (4*cnum, 3, 1, 'conv6')):
            net = gated_conv(net, filters, ksize, stride, name=layer)
        mask_s = resize_mask_like(mask, net)
        for dilation, layer in ((2, 'conv7_atrous'), (4, 'conv8_atrous'),
                                (8, 'conv9_atrous'), (16, 'conv10_atrous')):
            net = gated_conv(net, 4*cnum, 3, rate=dilation, name=layer)
        net = gated_conv(net, 4*cnum, 3, 1, name='conv11')
        net = gated_conv(net, 4*cnum, 3, 1, name='conv12')
        net = gated_deconv(net, 2*cnum, name='conv13_upsample')
        net = gated_conv(net, 2*cnum, 3, 1, name='conv14')
        net = gated_deconv(net, cnum, name='conv15_upsample')
        net = gated_conv(net, cnum//2, 3, 1, name='conv16')
        # output layer: plain conv, no gating, no activation
        net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
        x_stage1 = tf.clip_by_value(net, -1., 1.)

        # ---- stage 2: refine the pasted coarse result ----
        merged = x_stage1*mask + image*(1.-mask)
        merged.set_shape(image.get_shape().as_list())
        stage2_in = tf.concat([merged, plane, plane*mask], axis=3)

        # convolutional branch
        hallu = stage2_in
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'xconv1'),
                (cnum, 3, 2, 'xconv2_downsample'),
                (2*cnum, 3, 1, 'xconv3'),
                (2*cnum, 3, 2, 'xconv4_downsample'),
                (4*cnum, 3, 1, 'xconv5'),
                (4*cnum, 3, 1, 'xconv6')):
            hallu = gated_conv(hallu, filters, ksize, stride, name=layer)
        for dilation, layer in ((2, 'xconv7_atrous'), (4, 'xconv8_atrous'),
                                (8, 'xconv9_atrous'), (16, 'xconv10_atrous')):
            hallu = gated_conv(hallu, 4*cnum, 3, rate=dilation, name=layer)

        # contextual-attention branch
        att = stage2_in
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'pmconv1'),
                (cnum, 3, 2, 'pmconv2_downsample'),
                (2*cnum, 3, 1, 'pmconv3'),
                (4*cnum, 3, 2, 'pmconv4_downsample'),
                (4*cnum, 3, 1, 'pmconv5')):
            att = gated_conv(att, filters, ksize, stride, name=layer)
        att = gated_conv(att, 4*cnum, 3, 1, name='pmconv6',
                         activation=tf.nn.relu)
        att, offset_flow = contextual_attention(att, att, mask_s, 3, 1,
                                                rate=2)
        att = gated_conv(att, 4*cnum, 3, 1, name='pmconv9')
        att = gated_conv(att, 4*cnum, 3, 1, name='pmconv10')

        # merge branches and decode
        out = tf.concat([hallu, att], axis=3)
        out = gated_conv(out, 4*cnum, 3, 1, name='allconv11')
        out = gated_conv(out, 4*cnum, 3, 1, name='allconv12')
        out = gated_deconv(out, 2*cnum, name='allconv13_upsample')
        out = gated_conv(out, 2*cnum, 3, 1, name='allconv14')
        out = gated_deconv(out, cnum, name='allconv15_upsample')
        out = gated_conv(out, cnum//2, 3, 1, name='allconv16')
        # output layer: plain conv, no gating, no activation
        out = gen_conv(out, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(out, -1., 1.)
        return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Build a two-stage generator with two parallel dilated stacks in stage 1.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: accepted for interface compatibility; not read here.
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        name: variable scope name.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow); both predictions are
        clipped to [-1, 1].
    """
    image = x
    plane = tf.ones_like(image)[:, :, :, 0:1]
    net = tf.concat([image, plane, plane * mask], axis=3)
    # base width: every layer outputs a multiple of cnum channels
    cnum = 32

    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):

        def dilated_stack(inp, specs):
            # Chain 3x3 gen_conv layers of width 4*cnum over the given
            # (dilation rate, layer name) pairs.
            out = inp
            for dilation, layer in specs:
                out = gen_conv(out, 4 * cnum, 3, rate=dilation, name=layer)
            return out

        # ---- stage 1 encoder (this variant has no conv6) ----
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'conv1'),
                (2 * cnum, 3, 2, 'conv2_downsample'),
                (2 * cnum, 3, 1, 'conv3'),
                (4 * cnum, 3, 2, 'conv4_downsample'),
                (4 * cnum, 3, 1, 'conv5')):
            net = gen_conv(net, filters, ksize, stride, name=layer)
        mask_s = resize_mask_like(mask, net)

        # parallel dilated convolutions: two stacks fed by the same
        # features, concatenated along the channel axis
        top = dilated_stack(net, ((1, 'conv7_atrous_top'),
                                  (2, 'conv8_atrous_top'),
                                  (4, 'conv9_atrous_top'),
                                  (8, 'conv10_atrous_top')))
        bottom = dilated_stack(net, ((2, 'conv7_atrous_bottom'),
                                     (4, 'conv8_atrous_bottom'),
                                     (8, 'conv9_atrous_bottom'),
                                     (16, 'conv10_atrous_bottom')))
        net = tf.concat([top, bottom], axis=3)

        # ---- stage 1 decoder ----
        net = gen_conv(net, 4 * cnum, 3, 1, name='conv11')
        net = gen_conv(net, 4 * cnum, 3, 1, name='conv12')
        net = gen_deconv(net, 2 * cnum, name='conv13_upsample')
        net = gen_conv(net, 2 * cnum, 3, 1, name='conv14')
        net = gen_deconv(net, cnum, name='conv15_upsample')
        net = gen_conv(net, cnum // 2, 3, 1, name='conv16')
        net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
        x_stage1 = tf.clip_by_value(net, -1., 1.)

        # ---- stage 2: refine the pasted coarse result ----
        merged = x_stage1 * mask + image * (1. - mask)
        merged.set_shape(image.get_shape().as_list())
        stage2_in = tf.concat([merged, plane, plane * mask], axis=3)

        # convolutional branch
        hallu = stage2_in
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'xconv1'),
                (cnum, 3, 2, 'xconv2_downsample'),
                (2 * cnum, 3, 1, 'xconv3'),
                (2 * cnum, 3, 2, 'xconv4_downsample'),
                (4 * cnum, 3, 1, 'xconv5'),
                (4 * cnum, 3, 1, 'xconv6')):
            hallu = gen_conv(hallu, filters, ksize, stride, name=layer)
        hallu = dilated_stack(hallu, ((2, 'xconv7_atrous'),
                                      (4, 'xconv8_atrous'),
                                      (8, 'xconv9_atrous'),
                                      (16, 'xconv10_atrous')))

        # contextual-attention branch
        att = stage2_in
        for filters, ksize, stride, layer in (
                (cnum, 5, 1, 'pmconv1'),
                (cnum, 3, 2, 'pmconv2_downsample'),
                (2 * cnum, 3, 1, 'pmconv3'),
                (4 * cnum, 3, 2, 'pmconv4_downsample'),
                (4 * cnum, 3, 1, 'pmconv5')):
            att = gen_conv(att, filters, ksize, stride, name=layer)
        att = gen_conv(att, 4 * cnum, 3, 1, name='pmconv6',
                       activation=tf.nn.relu)
        att, offset_flow = contextual_attention(att, att, mask_s, 3, 1,
                                                rate=2)
        att = gen_conv(att, 4 * cnum, 3, 1, name='pmconv9')
        att = gen_conv(att, 4 * cnum, 3, 1, name='pmconv10')

        # merge branches and decode
        out = tf.concat([hallu, att], axis=3)
        out = gen_conv(out, 4 * cnum, 3, 1, name='allconv11')
        out = gen_conv(out, 4 * cnum, 3, 1, name='allconv12')
        out = gen_deconv(out, 2 * cnum, name='allconv13_upsample')
        out = gen_conv(out, 2 * cnum, 3, 1, name='allconv14')
        out = gen_deconv(out, cnum, name='allconv15_upsample')
        out = gen_conv(out, cnum // 2, 3, 1, name='allconv16')
        out = gen_conv(out, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(out, -1., 1.)
        return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net',
                      exclusionmask=None):
    """Inpaint network with optional gating and multi-resolution attention.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: required despite the ``None`` default; reads MULTIRES,
            GATING, LEVELS, BATCH_SIZE, PATCH_KSIZE, PATCH_STRIDE and
            PATCH_RATE (plus ATTENTION_MASK when the exclusion mask is on).
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
        name: variable scope name.
        exclusionmask: currently unused, because ``hasmask`` is hard-coded
            to False below.

    Returns:
        Tuple (x_stage1, x_stage2, flows): coarse and refined predictions,
        both clipped to [-1, 1], and the list of flows from every
        contextual_attention call (coarsest pyramid level first, the
        full-resolution call last).
    """
    multires = config.MULTIRES
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    # TODO: should become `exclusionmask is not None`; hard-coded off for
    # now, so every exclusion-mask path below is inactive.
    hasmask = False
    if hasmask:
        # 1.0 where the first channel of exclusionmask is below 0.5
        exclusionmask = tf.cast(
            tf.less(exclusionmask[:, :, :, 0:1], 0.5), tf.float32)

    use_gating = config.GATING
    # two stage network; the gated variant uses a narrower base width
    cnum = 24 if use_gating else 32
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # stage1
        x = gen_conv(x, cnum, 5, 1, name='conv1', gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample',
                     gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv3', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv5', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv6', gating=use_gating)
        mask_s = resize_mask_like(mask, x)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv11', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv12', gating=use_gating)
        x = gen_deconv(x, 2 * cnum, name='conv13_upsample',
                       gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv14', gating=use_gating)
        x = gen_deconv(x, cnum, name='conv15_upsample', gating=use_gating)
        x = gen_conv(x, cnum // 2, 3, 1, name='conv16', gating=use_gating)
        # output layer is never gated
        x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x

        # stage2, paste result as input
        x = x * mask + xin * (1. - mask)
        x.set_shape(xin.get_shape().as_list())

        # conv branch
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        x = gen_conv(xnow, cnum, 5, 1, name='xconv1', gating=use_gating)
        x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample',
                     gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3', gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous',
                     gating=use_gating)
        x_hallu = x

        # attention branch
        x = gen_conv(xnow, cnum, 5, 1, name='pmconv1', gating=use_gating)
        x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample',
                     gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample',
                     gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu, gating=use_gating)

        flows = []
        use_attentionmask = hasmask and config.ATTENTION_MASK
        if use_attentionmask:
            ex_mask_s = resize_mask_like(exclusionmask, x)

        if multires:
            # Scale the feature map down, run contextual attention from the
            # coarsest level up, and after each 2x upsample paste only the
            # mask region into the finer level's features.
            logger.info('USING MULTIRES')
            logger.info('original x shape: %s', x.shape)
            logger.info('original mask shape: %s', mask_s.shape)
            x_multi = [x]
            mask_multi = [mask_s]
            if use_attentionmask:
                exclusion_mask_multi = [ex_mask_s]
            for i in range(config.LEVELS - 1):
                x = resize(x, scale=0.5)
                x_multi.append(x)
                mask_multi.append(resize_mask_like(mask_s, x))
                if use_attentionmask:
                    exclusion_mask_multi.append(
                        resize_mask_like(ex_mask_s, x))
                    logger.info('exclusionmask shape: %s',
                                exclusion_mask_multi[i + 1].shape)
                logger.info('x shape: %s', x_multi[i + 1].shape)
                logger.info('mask shape: %s', mask_multi[i + 1].shape)
            # process coarsest level first
            x_multi.reverse()
            mask_multi.reverse()
            if use_attentionmask:
                exclusion_mask_multi.reverse()
            for i in range(config.LEVELS - 1):
                if use_attentionmask:
                    totalmask = mask_multi[i] + exclusion_mask_multi[i]
                    logger.info('total mask shape: %s', totalmask.shape)
                else:
                    totalmask = tf.tile(mask_multi[i],
                                        [config.BATCH_SIZE, 1, 1, 1])
                x, flow = contextual_attention(
                    x, x, totalmask, ksize=config.PATCH_KSIZE,
                    stride=config.PATCH_STRIDE, rate=config.PATCH_RATE)
                flows.append(flow)
                # TODO: look into using deconv instead of just upsampling
                x = resize(x, scale=2)
                # keep the finer level's features where mask_multi is 0
                x = (x * mask_multi[i + 1]
                     + x_multi[i + 1] * (1. - mask_multi[i + 1]))
                logger.info('upsampled x shape: %s', x.shape)

        # full-resolution contextual attention (runs in both modes)
        if use_attentionmask:
            attention_mask = mask_s + ex_mask_s
        else:
            attention_mask = tf.tile(mask_s, [config.BATCH_SIZE, 1, 1, 1])
        x, offset_flow = contextual_attention(
            x, x, attention_mask, ksize=config.PATCH_KSIZE,
            stride=config.PATCH_STRIDE, rate=config.PATCH_RATE)
        flows.append(offset_flow)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10', gating=use_gating)
        pm = x

        # join branches together
        x = tf.concat([x_hallu, pm], axis=3)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11', gating=use_gating)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12', gating=use_gating)
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample',
                       gating=use_gating)
        x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14', gating=use_gating)
        x = gen_deconv(x, cnum, name='allconv15_upsample', gating=use_gating)
        x = gen_conv(x, cnum // 2, 3, 1, name='allconv16', gating=use_gating)
        # output layer is never gated
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
        return x_stage1, x_stage2, flows
def build_inpaint_net(self, x, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Inpaint network with switchable gated convolutions and ablations.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: required despite the ``None`` default; reads
            GATED_CONVOLUTIONS, NO_HALLUC and NO_ATTENTION.
        reuse: forwarded to ``tf.variable_scope``.
        training: forwarded to the conv / deconv layers via ``arg_scope``.
        padding: forwarded to the conv / deconv layers via ``arg_scope``.
        name: variable scope name.

    Returns:
        Tuple (x_stage1, x_stage2, offset_flow): coarse and refined
        predictions, both clipped to [-1, 1], and the flow returned by
        ``contextual_attention``.
    """
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    used_conv = gen_gated_conv if config.GATED_CONVOLUTIONS else gen_conv

    # two stage network
    cnum = 32
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([used_conv, gen_deconv],
                      training=training, padding=padding):
        # stage1
        x = used_conv(x, cnum, 5, 1, name='conv1')
        x = used_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
        x = used_conv(x, 2 * cnum, 3, 1, name='conv3')
        x = used_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
        x = used_conv(x, 4 * cnum, 3, 1, name='conv5')
        x = used_conv(x, 4 * cnum, 3, 1, name='conv6')
        mask_s = resize_mask_like(mask, x)
        x = used_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
        x = used_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
        x = used_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
        x = used_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
        x = used_conv(x, 4 * cnum, 3, 1, name='conv11')
        x = used_conv(x, 4 * cnum, 3, 1, name='conv12')
        x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
        x = used_conv(x, 2 * cnum, 3, 1, name='conv14')
        x = gen_deconv(x, cnum, name='conv15_upsample')
        x = used_conv(x, cnum // 2, 3, 1, name='conv16')
        x = used_conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x

        # stage2, paste result as input
        x = x * mask + xin * (1. - mask)
        x.set_shape(xin.get_shape().as_list())

        # conv branch
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        x = used_conv(xnow, cnum, 5, 1, name='xconv1')
        x = used_conv(x, cnum, 3, 2, name='xconv2_downsample')
        x = used_conv(x, 2 * cnum, 3, 1, name='xconv3')
        x = used_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
        x = used_conv(x, 4 * cnum, 3, 1, name='xconv5')
        x = used_conv(x, 4 * cnum, 3, 1, name='xconv6')
        x = used_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
        x = used_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
        x = used_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
        x = used_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
        if config.NO_HALLUC:
            # Turn off the hallucination pathway for the ablation study.
            # Multiplying by zeros_like keeps x in the graph (zero gradient)
            # and, unlike tf.zeros(shape=x.shape), also works when x's
            # static shape has unknown dimensions and matches x's dtype.
            x_hallu = tf.multiply(x, tf.zeros_like(x))
        else:
            x_hallu = x

        # attention branch
        x = used_conv(xnow, cnum, 5, 1, name='pmconv1')
        x = used_conv(x, cnum, 3, 2, name='pmconv2_downsample')
        x = used_conv(x, 2 * cnum, 3, 1, name='pmconv3')
        x = used_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
        x = used_conv(x, 4 * cnum, 3, 1, name='pmconv5')
        x = used_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                      activation=tf.nn.relu)
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        x = used_conv(x, 4 * cnum, 3, 1, name='pmconv9')
        x = used_conv(x, 4 * cnum, 3, 1, name='pmconv10')
        if config.NO_ATTENTION:
            # Turn off the attention pathway for the ablation study (same
            # zeroing scheme as the hallucination pathway above).
            pm = tf.multiply(x, tf.zeros_like(x))
        else:
            pm = x

        # merge branches and decode
        x = tf.concat([x_hallu, pm], axis=3)
        x = used_conv(x, 4 * cnum, 3, 1, name='allconv11')
        x = used_conv(x, 4 * cnum, 3, 1, name='allconv12')
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
        x = used_conv(x, 2 * cnum, 3, 1, name='allconv14')
        x = gen_deconv(x, cnum, name='allconv15_upsample')
        x = used_conv(x, cnum // 2, 3, 1, name='allconv16')
        x = used_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
        return x_stage1, x_stage2, offset_flow
def __call__(self, x, mask, return_offset=False):
    """Run the two-stage inpainting generator.

    Channels are concatenated on axis 1 throughout.

    Args:
        x: incomplete image.
        mask: hole mask; channel 0 is the hole indicator. If the mask has a
            single channel, a zero plane is appended in place of an edge image.
        return_offset: forwarded to ``contextual_attention`` as
            ``return_flow``.

    Returns:
        Tuple ``(x_stage1, x_stage2, offset_flow)``. Both predictions are
        clipped to [-1, 1].
    """
    original = x
    if mask.shape[1] == 1:
        # No edge image supplied: pad the mask with an all-zero channel.
        mask = F.concat([mask, self.xp.zeros_like(x[:, :1])])
    hole = mask[:, :1]

    def run_layers(h, layer_names):
        # Apply the named child links of this chain one after another.
        for layer_name in layer_names:
            h = getattr(self, layer_name)(h)
        return h

    # ---- stage 1: coarse network --------------------------------------
    h = run_layers(F.concat([original, mask], axis=1), (
        'conv1', 'conv2_downsample', 'conv3', 'conv4_downsample',
        'conv5', 'conv6',
    ))
    # Hole mask at the encoder's resolution, used by contextual attention.
    mask_s = resize_mask_like(hole, h)
    h = run_layers(h, (
        'conv7_atrous', 'conv8_atrous', 'conv9_atrous', 'conv10_atrous',
        'conv11', 'conv12', 'conv13_upsample', 'conv14',
        'conv15_upsample', 'conv16', 'conv17',
    ))
    x_stage1 = F.clip(h, -1., 1.)

    # ---- stage 2: paste the coarse result into the hole, then refine ---
    composite = x_stage1 * hole + original * (1. - hole)
    xnow = F.concat([composite, mask], axis=1)

    x_hallu = run_layers(xnow, (
        'xconv1', 'xconv2_downsample', 'xconv3', 'xconv4_downsample',
        'xconv5', 'xconv6', 'xconv7_atrous', 'xconv8_atrous',
        'xconv9_atrous', 'xconv10_atrous',
    ))

    attn_in = run_layers(xnow, (
        'pmconv1', 'pmconv2_downsample', 'pmconv3', 'pmconv4_downsample',
        'pmconv5', 'pmconv6',
    ))
    attn_out, offset_flow = contextual_attention(
        attn_in, attn_in, mask_s, 3, 1, rate=2, return_flow=return_offset)
    pm = run_layers(attn_out, ('pmconv9', 'pmconv10'))

    h = run_layers(F.concat([x_hallu, pm], axis=1), (
        'allconv11', 'allconv12', 'allconv13_upsample', 'allconv14',
        'allconv15_upsample', 'allconv16', 'allconv17',
    ))
    x_stage2 = F.clip(h, -1., 1.)
    return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, config=None, reuse=False, training=True,
                      padding='SAME', name='inpaint_net'):
    """Inpaint network whose stage-2 attention branch is guided by SURF features.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        config: unused by this function; kept for interface compatibility.
        reuse: NOTE(review): ignored. The variable scope always uses
            ``tf.AUTO_REUSE``, so repeated calls share variables regardless
            of this flag.
        training: forwarded to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        padding: forwarded to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        name: variable scope name.

    Returns:
        Tuple ``(x_stage1, x_stage2, offset_flow)``: the coarse and refined
        predictions (both clipped to [-1, 1]) and the flow returned by
        ``surface_attention``.
    """
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    # Network input: image, a constant-ones plane and the mask (axis 3).
    x = tf.concat([x, ones_x, ones_x * mask], axis=3)

    # two stage network
    cnum = 32
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # ---- stage 1: coarse network ---------------------------------
        # TODO (original author): the input is meant to become
        # map_function(RGB, Disparity).
        x = gen_conv(x, cnum, 5, 1, name='conv1')
        x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv6')
        # Mask resized to this feature map, consumed by surface_attention.
        mask_s = resize_mask_like(mask, x)
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='conv12')
        x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='conv14')
        x = gen_deconv(x, cnum, name='conv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='conv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x

        # ---- stage 2: paste the coarse result into the hole ----------
        x = x * mask + xin * (1. - mask)
        x.set_shape(xin.get_shape().as_list())
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        # Grayscale of the composite, computed once and reused for both the
        # SURF features and the attention-branch input.
        gray = tf.image.rgb_to_grayscale(x)
        # (gray + 1) * 127.5 rescales [-1, 1] values to [0, 255].
        x_surf = surf_conv((gray + 1.) * 127.5)
        xnow_surf = tf.concat([gray, x_surf, ones_x, ones_x * mask], axis=3)

        # conv branch
        x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
        x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
        x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
        x_hallu = x

        # attention branch (fed with the grayscale + SURF input)
        x = gen_conv(xnow_surf, cnum, 5, 1, name='pmconv1')
        x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu)
        x, offset_flow = surface_attention(x, x, mask_s, 3, 1, rate=2)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
        pm = x

        # fuse both branches and decode
        x = tf.concat([x_hallu, pm], axis=3)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
        x = gen_deconv(x, cnum, name='allconv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
    return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, x2, mask, config=None, reuse=False,
                      training=True, padding='SAME', name='inpaint_net'):
    """Inpaint network variant that bypasses stage 1 and refines ``x2``.

    The coarse (stage-1) network is disabled in this variant. The externally
    supplied prediction ``x2`` is pasted into the hole of ``x``, and the
    composite goes straight into the stage-2 refinement network.

    Args:
        x: incomplete image, [-1, 1]
        x2: prediction used to fill the hole region. Presumably the same
            shape as ``x`` (TODO confirm with callers).
        mask: mask region {0, 1}
        config: unused; kept for interface compatibility.
        reuse: passed to ``tf.variable_scope``.
        training: forwarded to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        padding: forwarded to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        name: variable scope name.

    Returns:
        Tuple ``(x_stage1, x_stage2, offset_flow)``:
        - ``x_stage1`` is the composite ``x2 * mask + x * (1 - mask)`` fed
          to stage 2.
        - ``x_stage2`` is the refined prediction clipped to [-1, 1].
        - ``offset_flow`` is the flow returned by ``contextual_attention``.
    """
    xin = x
    ones_x = tf.ones_like(x)[:, :, :, 0:1]

    cnum = 32
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # Stage 1 is bypassed: paste x2 into the hole and use the composite
        # as the stage-2 input. It is also returned as x_stage1.
        x = x2 * mask + xin * (1. - mask)
        x_stage1 = x
        x.set_shape(xin.get_shape().as_list())

        # conv branch
        xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
        x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
        x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
        x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
        x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
        x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
        x_hallu = x
        # Mask resized to the branch feature map, consumed by attention.
        mask_s = resize_mask_like(mask, x)

        # attention branch (same input as the conv branch)
        x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
        x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
        x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu)
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
        x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
        pm = x

        # fuse both branches and decode
        x = tf.concat([x_hallu, pm], axis=3)
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
        x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
        x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
        x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
        x = gen_deconv(x, cnum, name='allconv15_upsample')
        x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
    return x_stage1, x_stage2, offset_flow
def build_inpaint_net(self, x, mask, reuse=False, training=True,
                      padding='SAME', name='inpaint_net'):
    """Inpaint network with a RefineNet-style stage 1 and a DeepFill stage 2.

    Args:
        x: incomplete image, [-1, 1]
        mask: mask region {0, 1}
        reuse: passed to ``tf.variable_scope``.
        training: forwarded to ``gen_conv``/``gen_deconv`` via ``arg_scope``.
        padding: forwarded to ``gen_conv``/``gen_deconv`` via ``arg_scope``
            (the original author notes that 'SAME' pads automatically).
        name: variable scope name.

    Returns:
        Tuple ``(x_stage1, x_stage2, offset_flow)``:
        - ``x_stage1`` and ``x_stage2`` are the coarse and refined
          predictions, each passed through ``tanh``.
        - ``offset_flow`` is the flow from the final stage-2
          ``contextual_attention`` call.
    """
    def _residual_refine(x, skip, ch, prefix):
        # Concat a skip feature, then six 3x3 convs with a residual add
        # around conv4-conv5. Layers are named f'{prefix}conv1'..'conv6'.
        x = tf.concat([x, skip], axis=3)
        x = gen_conv(x, ch, 3, 1, name=prefix + 'conv1')
        x = gen_conv(x, ch, 3, 1, name=prefix + 'conv2')
        x = gen_conv(x, ch, 3, 1, name=prefix + 'conv3')
        shortcut = x
        x = gen_conv(x, ch, 3, 1, name=prefix + 'conv4')
        x = gen_conv(x, ch, 3, 1, name=prefix + 'conv5')
        x = tf.add(x, shortcut)  # residual connection
        return gen_conv(x, ch, 3, 1, name=prefix + 'conv6')

    def _conv_and_attention(xnow, prefix):
        # DeepFill conv (hallucination) + contextual-attention branches,
        # fused and decoded up to the second upsampling layer. Returns
        # (features, offset_flow). Uses cnum and mask_s from the enclosing
        # scope.
        x = gen_conv(xnow, cnum, 5, 1, name=prefix + 'xconv1')
        x = gen_conv(x, cnum, 3, 2, name=prefix + 'xconv2_downsample')
        x = gen_conv(x, 2*cnum, 3, 1, name=prefix + 'xconv3')
        x = gen_conv(x, 2*cnum, 3, 2, name=prefix + 'xconv4_downsample')
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'xconv5')
        shortcut = x
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'xconv6')
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'xconv7')
        x = tf.add(x, shortcut)  # residual connection
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'xconv8')
        x = gen_conv(x, 4*cnum, 3, rate=2, name=prefix + 'xconv9_atrous')
        x = gen_conv(x, 4*cnum, 3, rate=4, name=prefix + 'xconv10_atrous')
        x = gen_conv(x, 4*cnum, 3, rate=8, name=prefix + 'xconv11_atrous')
        x = gen_conv(x, 4*cnum, 3, rate=16, name=prefix + 'xconv12_atrous')
        x_hallu = x

        x = gen_conv(xnow, cnum, 5, 1, name=prefix + 'pmconv1')
        x = gen_conv(x, cnum, 3, 2, name=prefix + 'pmconv2_downsample')
        x = gen_conv(x, 2*cnum, 3, 1, name=prefix + 'pmconv3')
        x = gen_conv(x, 4*cnum, 3, 2, name=prefix + 'pmconv4_downsample')
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'pmconv5')
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'pmconv6',
                     activation=tf.nn.relu)
        x, flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'pmconv9')
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'pmconv10')
        pm = x

        x = tf.concat([x_hallu, pm], axis=3)
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'allconv11')
        x = gen_conv(x, 4*cnum, 3, 1, name=prefix + 'allconv12')
        x = gen_deconv(x, 2*cnum, name=prefix + 'allconv13_upsample')
        x = gen_conv(x, 2*cnum, 3, 1, name=prefix + 'allconv14')
        x = gen_deconv(x, cnum, name=prefix + 'allconv15_upsample')
        return x, flow

    xin = x
    offset_flow = None
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    # Concatenate image, a constant-ones plane and the mask on axis 3.
    x = tf.concat([x, ones_x, ones_x*mask], axis=3)

    # two stage network
    cnum = 48
    with tf.variable_scope(name, reuse=reuse), \
            arg_scope([gen_conv, gen_deconv],
                      training=training, padding=padding):
        # ---- stage 1 encoder; skip features kept for the refine units --
        x = gen_conv(x, cnum, 5, 1, name='conv1')
        x = gen_conv(x, 2*cnum, 3, 1, name='conv2')
        conv2 = x
        x = gen_conv(x, 4*cnum, 5, 2, name='conv3')
        x = gen_conv(x, 4*cnum, 3, 1, name='conv4')
        conv4 = x
        x = gen_conv(x, 8*cnum, 3, 2, name='conv5')
        x = gen_conv(x, 8*cnum, 3, 1, name='conv6')
        conv6 = x
        # Resize the mask to conv6's feature map for contextual attention.
        # Original author's note: whether this is the right place for the
        # resize is still to be confirmed.
        mask_s = resize_mask_like(mask, x)
        x = gen_conv(x, 8*cnum, 3, 2, name='conv7')
        x = gen_conv(x, 8*cnum, 3, 1, name='conv8')
        conv8 = x
        x = gen_conv(x, 16*cnum, 3, 2, name='conv9')
        x = gen_conv(x, 16*cnum, 3, 1, name='conv10')
        x = gen_deconv(x, 8*cnum, name='convd')

        # ---- stage 1 decoder: residual refine units + attention --------
        # refinenet1
        x = _residual_refine(x, conv8, 8*cnum, 'r1')
        x = gen_deconv(x, 4*cnum, name='r1convd')
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        # refinenet2
        x = _residual_refine(x, conv6, 4*cnum, 'r2')
        x = gen_deconv(x, 2*cnum, name='r2convd')
        # refinenet3
        x = _residual_refine(x, conv4, 2*cnum, 'r3')
        x = gen_deconv(x, cnum, name='r3convd')

        # Inner conv + attention pass on the decoded features. It is fed
        # with the features directly; no mask channels are concatenated.
        x, offset_flow = _conv_and_attention(x, '1')
        # The usual allconv16/allconv17 output layers are omitted here;
        # refinenet4 produces the 3-channel output instead.
        x = tf.nn.tanh(x)

        # refinenet4
        x = _residual_refine(x, conv2, cnum, 'r4')
        x = gen_conv(x, 3, 3, 1, activation=None, name='r4conv7')
        x = tf.nn.tanh(x)
        x_stage1 = x

        # ---- stage 2: paste result as input ----------------------------
        x = x*mask + xin[:, :, :, 0:3]*(1.-mask)
        x.set_shape(xin[:, :, :, 0:3].get_shape().as_list())
        x, offset_flow = _conv_and_attention(x, '')
        x = gen_conv(x, cnum//2, 3, 1, name='allconv16')
        x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
        x = tf.nn.tanh(x)
        x_stage2 = x
    return x_stage1, x_stage2, offset_flow